aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--zenhttp/include/zenhttp/websocket.h44
-rw-r--r--zenhttp/websocketasio.cpp103
-rw-r--r--zenserver/cache/structuredcache.cpp156
-rw-r--r--zenserver/cache/structuredcache.h6
-rw-r--r--zenserver/zenserver.cpp7
5 files changed, 281 insertions, 35 deletions
diff --git a/zenhttp/include/zenhttp/websocket.h b/zenhttp/include/zenhttp/websocket.h
index 132dd1679..7ec0a8555 100644
--- a/zenhttp/include/zenhttp/websocket.h
+++ b/zenhttp/include/zenhttp/websocket.h
@@ -49,9 +49,37 @@ enum class WebSocketMessageType : uint8_t
kInvalid,
kNotification,
kRequest,
- kResponse
+ kStreamRequest,
+ kResponse,
+ kStreamResponse,
+ kStreamCompleteResponse,
+ kCount
};
+inline std::string_view
+ToString(WebSocketMessageType Type)
+{
+ switch (Type)
+ {
+ case WebSocketMessageType::kInvalid:
+ return std::string_view("Invalid");
+ case WebSocketMessageType::kNotification:
+ return std::string_view("Notification");
+ case WebSocketMessageType::kRequest:
+ return std::string_view("Request");
+ case WebSocketMessageType::kStreamRequest:
+ return std::string_view("StreamRequest");
+ case WebSocketMessageType::kResponse:
+ return std::string_view("Response");
+ case WebSocketMessageType::kStreamResponse:
+ return std::string_view("StreamResponse");
+ case WebSocketMessageType::kStreamCompleteResponse:
+ return std::string_view("StreamCompleteResponse");
+ default:
+ return std::string_view("Unknown");
+ };
+}
+
/**
* Web socket message.
*/
@@ -59,12 +87,12 @@ class WebSocketMessage
{
struct Header
{
- static constexpr uint32_t HeaderMagic = 0x7a776d68; // zwmh
+ static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh
uint64_t MessageSize{};
- uint32_t Magic{HeaderMagic};
+ uint32_t Magic{ExpectedMagic};
uint32_t CorrelationId{};
- uint32_t Crc32{};
+ uint32_t StatusCode{200u};
WebSocketMessageType MessageType{};
uint8_t Reserved[3] = {0};
@@ -82,11 +110,13 @@ public:
WebSocketId SocketId() const { return m_SocketId; }
void SetSocketId(WebSocketId Id) { m_SocketId = Id; }
- void SetMessageType(WebSocketMessageType MessageType);
- WebSocketMessageType MessageType() const { return m_Header.MessageType; }
uint64_t MessageSize() const { return m_Header.MessageSize; }
+ void SetMessageType(WebSocketMessageType MessageType);
void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; }
uint32_t CorrelationId() const { return m_Header.CorrelationId; }
+ uint32_t StatusCode() const { m_Header.StatusCode; }
+ void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; }
+ WebSocketMessageType MessageType() const { return m_Header.MessageType; }
const CbPackage& Body() const { return m_Body.value(); }
void SetBody(CbPackage&& Body);
@@ -123,6 +153,8 @@ protected:
WebSocketService() = default;
virtual void RegisterHandlers(WebSocketServer& Server) = 0;
+ void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse);
+ void SendStreamCompleteResponse(WebSocketId SocketId, uint32_t CorrelationId);
WebSocketServer& SocketServer()
{
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
index 966925d98..5407f8bc6 100644
--- a/zenhttp/websocketasio.cpp
+++ b/zenhttp/websocketasio.cpp
@@ -4,6 +4,7 @@
#include <zencore/base64.h>
#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
#include <zencore/intmath.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
@@ -11,6 +12,7 @@
#include <zencore/sha1.h>
#include <zencore/stream.h>
#include <zencore/string.h>
+#include <zencore/trace.h>
#include <chrono>
#include <optional>
@@ -386,6 +388,8 @@ private:
ParseMessageResult
WebSocketMessageParser::OnParseMessage(MemoryView Msg)
{
+ ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage");
+
const uint64_t PrevOffset = m_Stream.CurrentOffset();
if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
@@ -393,6 +397,7 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset();
m_Stream.Write(Msg.Left(RemaingHeaderSize));
+ Msg += RemaingHeaderSize;
if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
{
@@ -410,24 +415,26 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
.Reason = std::string("Invalid websocket message header")};
}
- Msg += RemaingHeaderSize;
+ if (m_Message.MessageSize() == 0)
+ {
+ return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ }
}
ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize);
- if (Msg.IsEmpty())
+ if (Msg.IsEmpty() == false)
{
- return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
+ m_Stream.Write(Msg.Left(RemaingMessageSize));
}
- const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
-
- m_Stream.Write(Msg.Left(RemaingMessageSize));
-
- const bool IsComplete = WebSocketMessage::HeaderSize + m_Message.MessageSize() == m_Stream.CurrentOffset();
+ auto Status = ParseMessageStatus::kContinue;
- if (IsComplete)
+ if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize())
{
+ Status = ParseMessageStatus::kDone;
+
BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize));
CbPackage Pkg;
@@ -441,8 +448,7 @@ WebSocketMessageParser::OnParseMessage(MemoryView Msg)
m_Message.SetBody(std::move(Pkg));
}
- return {.Status = IsComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue,
- .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
}
void
@@ -486,6 +492,7 @@ public:
WebSocketState Close();
MessageParser* Parser() { return m_MsgParser.get(); }
void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ std::mutex& WriteMutex() { return m_WriteMutex; }
private:
WebSocketId m_Id;
@@ -494,6 +501,7 @@ private:
std::atomic_uint32_t m_State;
std::unique_ptr<MessageParser> m_MsgParser;
asio::streambuf m_ReadBuffer;
+ std::mutex m_WriteMutex;
};
WebSocketState
@@ -635,13 +643,16 @@ WsServer::RegisterService(WebSocketService& Service)
bool
WsServer::Run()
{
+ static constexpr size_t ReceiveBufferSize = 256 << 10;
+ static constexpr size_t SendBufferSize = 256 << 10;
+
m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6());
m_Acceptor->set_option(asio::ip::v6_only(false));
m_Acceptor->set_option(asio::socket_base::reuse_address(true));
m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_Acceptor->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+ m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize));
+ m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize));
asio::error_code Ec;
m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec);
@@ -706,7 +717,10 @@ WsServer::SendNotification(WebSocketMessage&& Notification)
void
WsServer::SendResponse(WebSocketMessage&& Response)
{
- ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse);
+ ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse);
+
ZEN_ASSERT(Response.CorrelationId() != 0);
SendMessage(std::move(Response));
@@ -970,6 +984,7 @@ WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
switch (RoutedMessage.MessageType())
{
case WebSocketMessageType::kRequest:
+ case WebSocketMessageType::kStreamRequest:
{
CbObjectView Request = RoutedMessage.Body().GetObject();
std::string_view Method = Request["Method"].AsString();
@@ -995,7 +1010,7 @@ WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
if (Error || Handled == false)
{
- std::string ErrorText = Error ? Exception.what() : std::string("Not Found");
+ std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method);
ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText);
@@ -1063,18 +1078,21 @@ WsServer::SendMessage(WebSocketMessage&& Msg)
{
BinaryWriter Writer;
Msg.Save(Writer);
- IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
- async_write(Connection->Socket(),
- asio::buffer(Buffer.Data(), Buffer.Size()),
- [this, Connection, Buffer](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason '{}'", Ec.message());
+ ZEN_LOG_TRACE(LogWebSocket,
+ "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}",
+ ToString(Msg.MessageType()),
+ Connection->Id().Value(),
+ Msg.MessageSize(),
+ Msg.CorrelationId(),
+ NiceBytes(Writer.Size()));
- CloseConnection(Connection, Ec);
- }
- });
+ {
+ ZEN_TRACE_CPU("WS::SendMessage");
+ std::unique_lock _(Connection->WriteMutex());
+ ZEN_TRACE_CPU("WS::WriteSocketData");
+ asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size()));
+ }
}
}
@@ -1458,7 +1476,8 @@ std::atomic_uint32_t WebSocketId::NextId{1};
bool
WebSocketMessage::Header::IsValid() const
{
- return Magic == HeaderMagic && MessageSize > 0 && Crc32 > 0 && uint8_t(MessageType) > 0 && uint8_t(MessageType) < 4;
+ return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) &&
+ uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount);
}
std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1};
@@ -1490,6 +1509,12 @@ WebSocketMessage::Save(BinaryWriter& Writer)
if (m_Body.has_value())
{
+ const CbObject& Obj = m_Body.value().GetObject();
+ MemoryView View = Obj.GetBuffer().GetView();
+
+ const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All);
+ ZEN_ASSERT(ValidationResult == CbValidateError::None);
+
m_Body.value().Save(Writer);
}
@@ -1499,7 +1524,6 @@ WebSocketMessage::Save(BinaryWriter& Writer)
}
m_Header.MessageSize = Writer.Size() - HeaderSize;
- m_Header.Crc32 = 1; // TODO
Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize));
}
@@ -1529,6 +1553,31 @@ WebSocketService::Configure(WebSocketServer& Server)
RegisterHandlers(Server);
}
+void
+WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse)
+{
+ WebSocketMessage Message;
+
+ Message.SetMessageType(WebSocketMessageType::kStreamResponse);
+ Message.SetCorrelationId(CorrelationId);
+ Message.SetSocketId(SocketId);
+ Message.SetBody(std::move(StreamResponse));
+
+ SocketServer().SendResponse(std::move(Message));
+}
+
+void
+WebSocketService::SendStreamCompleteResponse(WebSocketId SocketId, uint32_t CorrelationId)
+{
+ WebSocketMessage Message;
+
+ Message.SetMessageType(WebSocketMessageType::kStreamCompleteResponse);
+ Message.SetCorrelationId(CorrelationId);
+ Message.SetSocketId(SocketId);
+
+ SocketServer().SendResponse(std::move(Message));
+}
+
std::unique_ptr<WebSocketServer>
WebSocketServer::Create(const WebSocketServerOptions& Options)
{
diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp
index 8ae531720..499329e94 100644
--- a/zenserver/cache/structuredcache.cpp
+++ b/zenserver/cache/structuredcache.cpp
@@ -162,6 +162,162 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request)
}
void
+HttpStructuredCacheService::RegisterHandlers(WebSocketServer& Server)
+{
+ Server.RegisterRequestHandler("GetBinaryCacheValue"sv, *this);
+ Server.RegisterRequestHandler("GetCacheValues"sv, *this);
+}
+
+bool
+HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage)
+{
+ CbObjectView Request = RequestMessage.Body().GetObject();
+
+ const auto Method = Request["Method"].AsString();
+ CbObjectView Params = Request["Params"sv].AsObjectView();
+
+ if (Method == "GetBinaryCacheValue"sv)
+ {
+ ZEN_TRACE_CPU("Z$::WS_GetBinaryCacheValue");
+
+ // CachePolicy Policy;
+ CbObjectView KeyObject = Params["Key"sv].AsObjectView();
+ CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash());
+
+ ZenCacheValue CacheValue;
+ const bool InLocalCache = m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue);
+
+ CbPackage Response;
+
+ if (InLocalCache)
+ {
+ m_CacheStats.HitCount++;
+
+ CbAttachment Attachment(SharedBuffer(CacheValue.Value));
+
+ CbObjectWriter ResponseObject;
+ ResponseObject.AddAttachment("Result", Attachment);
+ Response.AddAttachment(std::move(Attachment));
+ Response.SetObject(ResponseObject.Save());
+
+ ZenContentType ContentType = CacheValue.Value.GetContentType();
+
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", Key.Bucket, Key.Hash, NiceBytes(CacheValue.Value.Size()), ToString(ContentType));
+ }
+ else
+ {
+ m_CacheStats.MissCount++;
+
+ CbObjectWriter ResponseObject;
+ ResponseObject << "Error"sv
+ << "Not Found"sv;
+ Response.SetObject(ResponseObject.Save());
+
+ ZEN_DEBUG("MISS - '{}/{}' '{}'", Key.Bucket, Key.Hash, ToString(ZenContentType::kBinary));
+ }
+
+ WebSocketMessage ResponseMessage;
+ ResponseMessage.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMessage.SetCorrelationId(RequestMessage.CorrelationId());
+ ResponseMessage.SetSocketId(RequestMessage.SocketId());
+ ResponseMessage.SetBody(std::move(Response));
+
+ SocketServer().SendResponse(std::move(ResponseMessage));
+
+ return true;
+ }
+
+ if (Method == "GetCacheValues"sv)
+ {
+ ZEN_TRACE_CPU("Z$::WS_GetCacheValues");
+
+ const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText);
+
+ for (uint32_t RequestIdx = 0; CbFieldView RequestField : Params["Requests"sv])
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+ CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash());
+ std::string_view PolicyText = RequestObject["Policy"sv].AsString();
+ CachePolicy Policy = PolicyText.empty() ? DefaultPolicy : ParseCachePolicy(PolicyText);
+
+ CompressedBuffer Compressed;
+ bool InLocalCache = false;
+
+ if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+ {
+ ZenCacheValue CacheValue;
+ if (m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue))
+ {
+ Compressed = CompressedBuffer::FromCompressed(SharedBuffer(CacheValue.Value));
+ InLocalCache = true;
+ }
+ }
+
+ if (Compressed.IsNull() && EnumHasAllFlags(Policy, CachePolicy::QueryRemote))
+ {
+ if (auto UpstreamResult = m_UpstreamCache.GetCacheRecord({Key.Bucket, Key.Hash}, ZenContentType::kCompressedBinary);
+ UpstreamResult.Success)
+ {
+ Compressed = CompressedBuffer::FromCompressed(SharedBuffer(UpstreamResult.Value));
+
+ if (Compressed)
+ {
+ UpstreamResult.Value.SetContentType(ZenContentType::kCompressedBinary);
+ m_CacheStore.Put(Key.Bucket, Key.Hash, ZenCacheValue{UpstreamResult.Value});
+ }
+ }
+ }
+
+ CbPackage Response;
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginObject("Result"sv);
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIdx++);
+
+ const IoHash RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash());
+ const uint64_t RawSize = Compressed.GetRawSize();
+
+ if (Compressed)
+ {
+ ResponseObject.AddHash("RawHash"sv, RawHash);
+
+ if (EnumHasAllFlags(Policy, CachePolicy::SkipData))
+ {
+ ResponseObject.AddInteger("RawSize"sv, RawSize);
+ }
+ else
+ {
+ Response.AddAttachment(CbAttachment(std::move(Compressed)));
+ }
+ }
+
+ ResponseObject.EndObject();
+ Response.SetObject(ResponseObject.Save());
+
+ SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response));
+
+ if (RawSize > 0)
+ {
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}'", Key.Bucket, Key.Hash, NiceBytes(RawSize), ToString(ZenContentType::kCompressedBinary));
+ }
+ else
+ {
+ ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash);
+ }
+ }
+
+ SendStreamCompleteResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId());
+
+ return true;
+ }
+
+ return false;
+}
+
+void
HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Bucket)
{
switch (Request.RequestVerb())
diff --git a/zenserver/cache/structuredcache.h b/zenserver/cache/structuredcache.h
index 00c4260aa..39585f402 100644
--- a/zenserver/cache/structuredcache.h
+++ b/zenserver/cache/structuredcache.h
@@ -4,6 +4,7 @@
#include <zencore/stats.h>
#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
#include "monitoring/httpstats.h"
#include "monitoring/httpstatus.h"
@@ -61,7 +62,7 @@ namespace cache::detail {
*
*/
-class HttpStructuredCacheService : public HttpService, public IHttpStatsProvider, public IHttpStatusProvider
+class HttpStructuredCacheService : public HttpService, public IHttpStatsProvider, public IHttpStatusProvider, public WebSocketService
{
public:
HttpStructuredCacheService(ZenCacheStore& InCacheStore,
@@ -138,6 +139,9 @@ private:
/** HandleRpcGetCacheChunks Helper: Send response message containing all chunk results. */
void WriteGetCacheChunksResponse(std::vector<cache::detail::ChunkRequest>& Requests, zen::HttpServerRequest& HttpRequest);
+ virtual void RegisterHandlers(WebSocketServer& Server) override;
+ virtual bool HandleRequest(const WebSocketMessage& RequestMessage) override;
+
spdlog::logger& Log() { return m_Log; }
spdlog::logger& m_Log;
ZenCacheStore& m_CacheStore;
diff --git a/zenserver/zenserver.cpp b/zenserver/zenserver.cpp
index 78a62e202..3c7f9004d 100644
--- a/zenserver/zenserver.cpp
+++ b/zenserver/zenserver.cpp
@@ -211,7 +211,7 @@ public:
ServerOptions.WebSocketThreads > 0 ? uint32_t(ServerOptions.WebSocketThreads) : std::thread::hardware_concurrency();
m_WebSocket = zen::WebSocketServer::Create(
- {.Port = gsl::narrow<uint16_t>(ServerOptions.WebSocketPort), .ThreadCount = Max(ThreadCount, uint32_t(16))});
+ {.Port = gsl::narrow<uint16_t>(ServerOptions.WebSocketPort), .ThreadCount = Min(ThreadCount, std::thread::hardware_concurrency())});
}
// Setup authentication manager
@@ -819,6 +819,11 @@ ZenServer::InitializeStructuredCache(const ZenServerOptions& ServerOptions)
m_Http->RegisterService(*m_StructuredCacheService);
m_Http->RegisterService(*m_UpstreamService);
+
+ if (m_WebSocket)
+ {
+ m_WebSocket->RegisterService(*m_StructuredCacheService);
+ }
}
////////////////////////////////////////////////////////////////////////////////