// Copyright Epic Games, Inc. All Rights Reserved. #include #include "servers/httpasio.h" #include "servers/httpmulti.h" #include "servers/httpnull.h" #include "servers/httpsys.h" #include "zenhttp/httpplugin.h" #if ZEN_WITH_PLUGINS # include "transports/asiotransport.h" # include "transports/dlltransport.h" # include "transports/winsocktransport.h" #endif #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace zen { using namespace std::literals; std::string_view MapContentTypeToString(HttpContentType ContentType) { switch (ContentType) { default: case HttpContentType::kUnknownContentType: case HttpContentType::kBinary: return "application/octet-stream"sv; case HttpContentType::kText: return "text/plain"sv; case HttpContentType::kJSON: return "application/json"sv; case HttpContentType::kCbObject: return "application/x-ue-cb"sv; case HttpContentType::kCbPackage: return "application/x-ue-cbpkg"sv; case HttpContentType::kCbPackageOffer: return "application/x-ue-offer"sv; case HttpContentType::kCompressedBinary: return "application/x-ue-comp"sv; case HttpContentType::kYAML: return "text/yaml"sv; case HttpContentType::kHTML: return "text/html"sv; case HttpContentType::kJavaScript: return "application/javascript"sv; case HttpContentType::kCSS: return "text/css"sv; case HttpContentType::kPNG: return "image/png"sv; case HttpContentType::kIcon: return "image/x-icon"sv; case HttpContentType::kXML: return "application/xml"sv; case HttpContentType::kProtobuf: return "application/x-protobuf"sv; } } ////////////////////////////////////////////////////////////////////////// // // Note that in addition to MIME types we accept abbreviated versions, for // use in suffix parsing as well as for convenience when using curl static constinit uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); static constinit uint32_t HashJson = HashStringDjb2("json"sv); static constinit uint32_t HashApplicationJson = HashStringDjb2("application/json"sv); static constinit uint32_t HashApplicationProblemJson = HashStringDjb2("application/problem+json"sv); static constinit uint32_t HashYaml = HashStringDjb2("yaml"sv); static constinit uint32_t HashTextYaml = HashStringDjb2("text/yaml"sv); static constinit uint32_t HashText = HashStringDjb2("text/plain"sv); static constinit uint32_t HashApplicationCompactBinary = HashStringDjb2("application/x-ue-cb"sv); static constinit uint32_t HashCompactBinary = HashStringDjb2("ucb"sv); static constinit uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv); static constinit uint32_t HashCompactBinaryPackageShort = HashStringDjb2("cbpkg"sv); static constinit uint32_t HashCompactBinaryPackageOffer = HashStringDjb2("application/x-ue-offer"sv); static constinit uint32_t HashCompressedBinary = HashStringDjb2("application/x-ue-comp"sv); static constinit uint32_t HashHtml = HashStringDjb2("html"sv); static constinit uint32_t HashTextHtml = HashStringDjb2("text/html"sv); static constinit uint32_t HashJavaScript = HashStringDjb2("js"sv); static constinit uint32_t HashJavaScriptSourceMap = HashStringDjb2("map"sv); // actually .js.map static constinit uint32_t HashApplicationJavaScript = HashStringDjb2("application/javascript"sv); static constinit uint32_t HashCss = HashStringDjb2("css"sv); static constinit uint32_t HashTextCss = HashStringDjb2("text/css"sv); static constinit uint32_t HashPng = HashStringDjb2("png"sv); static constinit uint32_t HashImagePng = HashStringDjb2("image/png"sv); static constinit uint32_t HashIcon = HashStringDjb2("ico"sv); static constinit uint32_t HashImageIcon = HashStringDjb2("image/x-icon"sv); static constinit uint32_t HashXml = HashStringDjb2("application/xml"sv); static constinit uint32_t HashProtobuf = HashStringDjb2("application/x-protobuf"sv); std::once_flag InitContentTypeLookup; struct HashedTypeEntry { uint32_t Hash; HttpContentType Type; } TypeHashTable[] = { // clang-format off {HashBinary, HttpContentType::kBinary}, {HashApplicationCompactBinary, HttpContentType::kCbObject}, {HashCompactBinary, HttpContentType::kCbObject}, {HashCompactBinaryPackage, HttpContentType::kCbPackage}, {HashCompactBinaryPackageShort, HttpContentType::kCbPackage}, {HashCompactBinaryPackageOffer, HttpContentType::kCbPackageOffer}, {HashJson, HttpContentType::kJSON}, {HashApplicationJson, HttpContentType::kJSON}, {HashApplicationProblemJson, HttpContentType::kJSON}, {HashYaml, HttpContentType::kYAML}, {HashTextYaml, HttpContentType::kYAML}, {HashText, HttpContentType::kText}, {HashCompressedBinary, HttpContentType::kCompressedBinary}, {HashHtml, HttpContentType::kHTML}, {HashTextHtml, HttpContentType::kHTML}, {HashJavaScript, HttpContentType::kJavaScript}, {HashApplicationJavaScript, HttpContentType::kJavaScript}, {HashJavaScriptSourceMap, HttpContentType::kJavaScript}, {HashCss, HttpContentType::kCSS}, {HashTextCss, HttpContentType::kCSS}, {HashPng, HttpContentType::kPNG}, {HashImagePng, HttpContentType::kPNG}, {HashIcon, HttpContentType::kIcon}, {HashImageIcon, HttpContentType::kIcon}, {HashXml, HttpContentType::kXML}, {HashProtobuf, HttpContentType::kProtobuf}, // clang-format on }; HttpContentType ParseContentTypeImpl(const std::string_view& ContentTypeString) { if (!ContentTypeString.empty()) { size_t ContentEnd = ContentTypeString.find(';'); if (ContentEnd == std::string_view::npos) { ContentEnd = ContentTypeString.length(); } std::string_view ContentString(ContentTypeString.substr(0, ContentEnd)); const uint32_t CtHash = HashStringDjb2(ContentString); if (auto It = std::lower_bound(std::begin(TypeHashTable), std::end(TypeHashTable), CtHash, [](const HashedTypeEntry& Lhs, const uint32_t Rhs) { return Lhs.Hash < Rhs; }); It != std::end(TypeHashTable)) { if (It->Hash == CtHash) { return It->Type; } } } return HttpContentType::kUnknownContentType; } HttpContentType ParseContentTypeInit(const std::string_view& ContentTypeString) { std::call_once(InitContentTypeLookup, [] { std::sort(std::begin(TypeHashTable), std::end(TypeHashTable), [](const HashedTypeEntry& Lhs, const HashedTypeEntry& Rhs) { return Lhs.Hash < Rhs.Hash; }); // validate that there are no hash collisions uint32_t LastHash = 0; for (const auto& Item : TypeHashTable) { ZEN_ASSERT(LastHash != Item.Hash); LastHash = Item.Hash; } }); ParseContentType = ParseContentTypeImpl; return ParseContentTypeImpl(ContentTypeString); } HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString) = &ParseContentTypeInit; bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) { if (RangeHeader.empty()) { return false; } const size_t Count = Ranges.size(); std::size_t UnitDelim = RangeHeader.find_first_of('='); if (UnitDelim == std::string_view::npos) { return false; } // only bytes for now std::string_view Unit = RangeHeader.substr(0, UnitDelim); if (Unit != "bytes"sv) { return false; } std::string_view Tokens = RangeHeader.substr(UnitDelim); while (!Tokens.empty()) { // Skip =, Tokens = Tokens.substr(1); size_t Delim = Tokens.find_first_of(','); if (Delim == std::string_view::npos) { Delim = Tokens.length(); } std::string_view Token = Tokens.substr(0, Delim); Tokens = Tokens.substr(Delim); Delim = Token.find_first_of('-'); if (Delim == std::string_view::npos) { return false; } const auto Start = ParseInt(Token.substr(0, Delim)); const auto End = ParseInt(Token.substr(Delim + 1)); if (Start.has_value() && End.has_value() && End.value() > Start.value()) { Ranges.push_back({.Start = Start.value(), .End = End.value()}); } else if (Start) { Ranges.push_back({.Start = Start.value()}); } else if (End) { Ranges.push_back({.End = End.value()}); } } return Count != Ranges.size(); } ////////////////////////////////////////////////////////////////////////// const std::string_view ToString(HttpVerb Verb) { switch (Verb) { case HttpVerb::kGet: return "GET"sv; case HttpVerb::kPut: return "PUT"sv; case HttpVerb::kPost: return "POST"sv; case HttpVerb::kDelete: return "DELETE"sv; case HttpVerb::kHead: return "HEAD"sv; case HttpVerb::kCopy: return "COPY"sv; case HttpVerb::kOptions: return "OPTIONS"sv; default: return "???"sv; } } std::string_view ToString(HttpResponseCode HttpCode) { return ReasonStringForHttpResultCode(int(HttpCode)); } std::string_view ReasonStringForHttpResultCode(int HttpCode) { switch (HttpCode) { // 1xx Informational case 100: return "Continue"sv; case 101: return "Switching Protocols"sv; // 2xx Success case 200: return "OK"sv; case 201: return "Created"sv; case 202: return "Accepted"sv; case 204: return "No Content"sv; case 205: return "Reset Content"sv; case 206: return "Partial Content"sv; // 3xx Redirection case 300: return "Multiple Choices"sv; case 301: return "Moved Permanently"sv; case 302: return "Found"sv; case 303: return "See Other"sv; case 304: return "Not Modified"sv; case 305: return "Use Proxy"sv; case 306: return "Switch Proxy"sv; case 307: return "Temporary Redirect"sv; case 308: return "Permanent Redirect"sv; // 4xx Client errors case 400: return "Bad Request"sv; case 401: return "Unauthorized"sv; case 402: return "Payment Required"sv; case 403: return "Forbidden"sv; case 404: return "Not Found"sv; case 405: return "Method Not Allowed"sv; case 406: return "Not Acceptable"sv; case 407: return "Proxy Authentication Required"sv; case 408: return "Request Timeout"sv; case 409: return "Conflict"sv; case 410: return "Gone"sv; case 411: return "Length Required"sv; case 412: return "Precondition Failed"sv; case 413: return "Payload Too Large"sv; case 414: return "URI Too Long"sv; case 415: return "Unsupported Media Type"sv; case 416: return "Range Not Satisifiable"sv; case 417: return "Expectation Failed"sv; case 418: return "I'm a teapot"sv; case 421: return "Misdirected Request"sv; case 422: return "Unprocessable Entity"sv; case 423: return "Locked"sv; case 424: return "Failed Dependency"sv; case 425: return "Too Early"sv; case 426: return "Upgrade Required"sv; case 428: return "Precondition Required"sv; case 429: return "Too Many Requests"sv; case 431: return "Request Header Fields Too Large"sv; // 5xx Server errors case 500: return "Internal Server Error"sv; case 501: return "Not Implemented"sv; case 502: return "Bad Gateway"sv; case 503: return "Service Unavailable"sv; case 504: return "Gateway Timeout"sv; case 505: return "HTTP Version Not Supported"sv; case 506: return "Variant Also Negotiates"sv; case 507: return "Insufficient Storage"sv; case 508: return "Loop Detected"sv; case 510: return "Not Extended"sv; case 511: return "Network Authentication Required"sv; default: return "Unknown Result"sv; } } ////////////////////////////////////////////////////////////////////////// Ref HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) { ZEN_UNUSED(HttpServiceRequest); return Ref(); } ////////////////////////////////////////////////////////////////////////// HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service) { } HttpServerRequest::~HttpServerRequest() { } void HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbPackage Data) { std::vector ResponseBuffers = FormatPackageMessage(Data); return WriteResponse(ResponseCode, HttpContentType::kCbPackage, ResponseBuffers); } void HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data) { if (m_AcceptType == HttpContentType::kJSON) { ExtendableStringBuilder<1024> Sb; WriteResponse(ResponseCode, HttpContentType::kJSON, Data.ToJson(Sb).ToView()); } else if (m_AcceptType == HttpContentType::kYAML) { ExtendableStringBuilder<1024> Sb; WriteResponse(ResponseCode, HttpContentType::kYAML, Data.ToYaml(Sb).ToView()); } else { SharedBuffer Buf = Data.GetBuffer(); std::array Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers); } } void HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbArray Array) { if (m_AcceptType == HttpContentType::kJSON) { ExtendableStringBuilder<1024> Sb; WriteResponse(ResponseCode, HttpContentType::kJSON, Array.ToJson(Sb).ToView()); } else if (m_AcceptType == HttpContentType::kYAML) { ExtendableStringBuilder<1024> Sb; WriteResponse(ResponseCode, HttpContentType::kYAML, Array.ToYaml(Sb).ToView()); } else { SharedBuffer Buf = Array.GetBuffer(); std::array Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers); } } void HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString) { return WriteResponse(ResponseCode, ContentType, std::u8string_view{(char8_t*)ResponseString.data(), ResponseString.size()}); } void HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob) { std::array Buffers{Blob}; return WriteResponse(ResponseCode, ContentType, Buffers); } void HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) { std::span Segments = Payload.GetSegments(); eastl::fixed_vector Buffers; Buffers.reserve(Segments.size()); for (auto& Segment : Segments) { Buffers.push_back(Segment.AsIoBuffer()); } WriteResponse(ResponseCode, ContentType, std::span(begin(Buffers), end(Buffers))); } std::string HttpServerRequest::Decode(std::string_view PercentEncodedString) { size_t Length = PercentEncodedString.length(); std::string Decoded; Decoded.reserve(Length); size_t Offset = 0; while (Offset < Length) { char C = PercentEncodedString[Offset]; if (C == '%' && (Offset <= (Length - 3))) { std::string_view CharHash(&PercentEncodedString[Offset + 1], 2); uint8_t DecodedChar = 0; if (ParseHexBytes(CharHash, &DecodedChar)) { Decoded.push_back((char)DecodedChar); Offset += 3; } else { Decoded.push_back(C); Offset++; } } else { Decoded.push_back(C); Offset++; } } return Decoded; } HttpServerRequest::QueryParams HttpServerRequest::GetQueryParams() { QueryParams Params; const std::string_view QStr = QueryString(); const char* QueryIt = QStr.data(); const char* QueryEnd = QueryIt + QStr.size(); while (QueryIt != QueryEnd) { if (*QueryIt == '&') { ++QueryIt; continue; } size_t QueryLen = ptrdiff_t(QueryEnd - QueryIt); const std::string_view Query{QueryIt, QueryLen}; size_t DelimIndex = Query.find('&', 0); if (DelimIndex == std::string_view::npos) { DelimIndex = Query.size(); } std::string_view ThisQuery{QueryIt, DelimIndex}; size_t EqIndex = ThisQuery.find('=', 0); if (EqIndex != std::string_view::npos) { std::string_view Param{ThisQuery.data(), EqIndex}; ThisQuery.remove_prefix(EqIndex + 1); Params.KvPairs.emplace_back(Param, ThisQuery); } QueryIt += DelimIndex; } return Params; } Oid HttpServerRequest::SessionId() const { if (m_Flags & kHaveSessionId) { return m_SessionId; } m_SessionId = ParseSessionId(); m_Flags |= kHaveSessionId; return m_SessionId; } uint32_t HttpServerRequest::RequestId() const { if (m_Flags & kHaveRequestId) { return m_RequestId; } m_RequestId = ParseRequestId(); m_Flags |= kHaveRequestId; return m_RequestId; } CbObject HttpServerRequest::ReadPayloadObject() { if (IoBuffer Payload = ReadPayload()) { if (m_ContentType == HttpContentType::kJSON) { std::string Json(reinterpret_cast(Payload.GetData()), Payload.GetSize()); std::string Err; CbFieldIterator It = LoadCompactBinaryFromJson(Json, Err); if (Err.empty()) { return It.AsObject(); } return CbObject(); } CbValidateError ValidationError = CbValidateError::None; if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(Payload), ValidationError); ValidationError == CbValidateError::None) { return ResponseObject; } } return {}; } CbPackage HttpServerRequest::ReadPayloadPackage() { if (IoBuffer Payload = ReadPayload()) { return ParsePackageMessage(std::move(Payload)); } return {}; } ////////////////////////////////////////////////////////////////////////// void HttpRequestRouter::AddPattern(const char* Id, const char* Regex) { ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end()); ZEN_ASSERT(!m_IsFinalized); m_PatternMap.insert({Id, Regex}); } void HttpRequestRouter::AddMatcher(const char* Id, std::function&& Matcher) { ZEN_ASSERT(m_MatcherNameMap.find(Id) == m_MatcherNameMap.end()); ZEN_ASSERT(!m_IsFinalized); const int MatcherIndex = gsl::narrow_cast(m_MatcherFunctions.size()); m_MatcherFunctions.push_back(Matcher); m_MatcherNameMap.insert({Id, MatcherIndex}); } void HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs) { ZEN_ASSERT(!m_IsFinalized); if (ExtendableStringBuilder<128> ExpandedRegex; ProcessRegexSubstitutions(UriPattern, ExpandedRegex)) { // Regex route m_RegexHandlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), UriPattern); } else { // New-style regex-free route. More efficient and should be used for everything eventually int RegexLen = gsl::narrow_cast(strlen(UriPattern)); int i = 0; std::vector MatcherIndices; while (i < RegexLen) { if (UriPattern[i] == '{') { bool IsComplete = false; int PatternStart = i + 1; while (++i < RegexLen) { if (UriPattern[i] == '}') { std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) { // It's a match MatcherIndices.push_back(it->second); IsComplete = true; ++i; break; } else { throw std::runtime_error(fmt::format("unknown matcher pattern '{}' in URI pattern '{}'", Pattern, UriPattern)); } } } if (!IsComplete) { throw std::runtime_error(fmt::format("unterminated matcher pattern in URI pattern '{}'", UriPattern)); } } else { if (UriPattern[i] == '/') { throw std::runtime_error(fmt::format("unexpected '/' in literal segment of URI pattern '{}'", UriPattern)); } int SegmentStart = i; while (++i < RegexLen && UriPattern[i] != '/') ; std::string_view Segment(&UriPattern[SegmentStart], (i - SegmentStart)); int LiteralIndex = gsl::narrow_cast(m_Literals.size()); m_Literals.push_back(std::string(Segment)); MatcherIndices.push_back(-1 - LiteralIndex); } if (i < RegexLen && UriPattern[i] == '/') { ++i; // skip slash } } m_MatcherEndpoints.emplace_back(std::move(MatcherIndices), SupportedVerbs, std::move(HandlerFunc), UriPattern); } } std::string_view HttpRouterRequest::GetCapture(uint32_t Index) const { if (!m_CapturedSegments.empty()) { ZEN_ASSERT(Index < m_CapturedSegments.size()); return m_CapturedSegments[Index]; } ZEN_ASSERT(Index < m_Match.size()); const auto& Match = m_Match[Index]; return std::string_view(&*Match.first, Match.second - Match.first); } bool HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex) { size_t RegexLen = strlen(Regex); bool HasRegex = false; std::vector UnknownPatterns; for (size_t i = 0; i < RegexLen;) { bool matched = false; if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\'))) { // Might have a pattern reference - find closing brace for (size_t j = i + 1; j < RegexLen; ++j) { if (Regex[j] == '}') { std::string Pattern(&Regex[i + 1], j - i - 1); if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) { OutExpandedRegex.Append(it->second.c_str()); HasRegex = true; } else { UnknownPatterns.push_back(Pattern); } // skip ahead i = j + 1; matched = true; break; } } } if (!matched) { OutExpandedRegex.Append(Regex[i++]); } } if (HasRegex) { if (UnknownPatterns.size() > 0) { std::string UnknownList; for (const auto& Pattern : UnknownPatterns) { if (!UnknownList.empty()) { UnknownList += ", "; } UnknownList += "'"; UnknownList += Pattern; UnknownList += "'"; } throw std::runtime_error(fmt::format("unknown pattern(s) {} in regex route '{}'", UnknownList, Regex)); } return true; } return false; } bool HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) { if (!m_IsFinalized) { m_IsFinalized = true; } const HttpVerb Verb = Request.RequestVerb(); std::string_view Uri = Request.RelativeUri(); HttpRouterRequest RouterRequest(Request); // First try new-style matcher routes for (const MatcherEndpoint& Handler : m_MatcherEndpoints) { if ((Handler.Verbs & Verb) == Verb) { size_t UriPos = 0; const size_t UriLen = Uri.length(); const std::vector& Matchers = Handler.ComponentIndices; bool IsMatch = true; std::vector CapturedSegments; CapturedSegments.emplace_back(Uri); for (int MatcherIndex : Matchers) { if (UriPos >= UriLen) { IsMatch = false; break; } if (MatcherIndex < 0) { // Literal match int LitIndex = -MatcherIndex - 1; const std::string& LitStr = m_Literals[LitIndex]; size_t LitLen = LitStr.length(); if (Uri.substr(UriPos, LitLen) == LitStr) { UriPos += LitLen; } else { IsMatch = false; break; } } else { // Matcher function size_t SegmentStart = UriPos; while (UriPos < UriLen && Uri[UriPos] != '/') { ++UriPos; } std::string_view Segment = Uri.substr(SegmentStart, UriPos - SegmentStart); if (m_MatcherFunctions[MatcherIndex](Segment)) { CapturedSegments.push_back(Segment); } else { IsMatch = false; break; } } // Skip slash if (UriPos < UriLen && Uri[UriPos] == '/') { ++UriPos; } } if (IsMatch && UriPos == UriLen) { #if ZEN_WITH_OTEL if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } #endif RouterRequest.m_CapturedSegments = std::move(CapturedSegments); Handler.Handler(RouterRequest); return true; // Route matched } } } // Old-style regex routes for (const auto& Handler : m_RegexHandlers) { if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx)) { #if ZEN_WITH_OTEL if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } #endif Handler.Handler(RouterRequest); return true; // Route matched } } return false; // No route matched } ////////////////////////////////////////////////////////////////////////// int HttpServer::Initialize(int BasePort, std::filesystem::path DataDir) { return OnInitialize(BasePort, std::move(DataDir)); } void HttpServer::Run(bool IsInteractiveSession) { OnRun(IsInteractiveSession); } void HttpServer::RequestExit() { OnRequestExit(); } void HttpServer::Close() { OnClose(); } void HttpServer::RegisterService(HttpService& Service) { OnRegisterService(Service); m_KnownServices.push_back(&Service); } void HttpServer::EnumerateServices(std::function&& Callback) { // This doesn't take a lock because services should only be registered during // server initialization, before it starts accepting requests for (HttpService* Service : m_KnownServices) { Callback(*Service); } } void HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) { OnSetHttpRequestFilter(RequestFilter); } ////////////////////////////////////////////////////////////////////////// HttpRpcHandler::HttpRpcHandler() { } HttpRpcHandler::~HttpRpcHandler() { } void HttpRpcHandler::AddRpc(std::string_view RpcId, std::function HandlerFunction) { ZEN_UNUSED(RpcId, HandlerFunction); } ////////////////////////////////////////////////////////////////////////// Ref CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig& Config) { if (ServerClass == "asio"sv) { ZEN_INFO("using asio HTTP server implementation") return CreateHttpAsioServer(AsioConfig{.ThreadCount = Config.ThreadCount, .ForceLoopback = Config.ForceLoopback, .IsDedicatedServer = Config.IsDedicatedServer}); } #if ZEN_WITH_HTTPSYS else if (ServerClass == "httpsys"sv) { ZEN_INFO("using http.sys server implementation") return Ref(CreateHttpSysServer({.ThreadCount = Config.ThreadCount, .AsyncWorkThreadCount = Config.HttpSys.AsyncWorkThreadCount, .IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled, .IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled, .IsDedicatedServer = Config.IsDedicatedServer, .ForceLoopback = Config.ForceLoopback})); } #endif else if (ServerClass == "null"sv) { ZEN_INFO("using null HTTP server implementation") return Ref(new HttpNullServer); } else { ZEN_WARN("unknown HTTP server implementation '{}', falling back to default", ServerClass) #if ZEN_WITH_HTTPSYS return CreateHttpServerClass("httpsys"sv, Config); #else return CreateHttpServerClass("asio"sv, Config); #endif } } #if ZEN_WITH_PLUGINS Ref CreateHttpServerPlugin(const HttpServerPluginConfig& PluginConfig) { const std::string& PluginName = PluginConfig.PluginName; ZEN_INFO("using '{}' plugin HTTP server implementation", PluginName) if (PluginName.starts_with("builtin:"sv)) { # if 0 Ref Plugin = {}; if (PluginName == "builtin:winsock"sv) { Plugin = CreateSocketTransportPlugin(); } else if (PluginName == "builtin:asio"sv) { Plugin = CreateAsioTransportPlugin(); } else { ZEN_WARN("Unknown builtin plugin '{}'", PluginName) return {}; } ZEN_ASSERT(!Plugin.IsNull()); for (const std::pair& Option : PluginConfig.PluginOptions) { Plugin->Configure(Option.first.c_str(), Option.second.c_str()); } Ref Server{CreateHttpPluginServer()}; Server->AddPlugin(Plugin); return Server; # else ZEN_WARN("Builtin plugin '{}' is not supported", PluginName) return {}; # endif } Ref DllPlugin{CreateDllTransportPlugin()}; if (!DllPlugin->LoadDll(PluginName)) { return {}; } for (const std::pair& Option : PluginConfig.PluginOptions) { DllPlugin->ConfigureDll(PluginName, Option.first.c_str(), Option.second.c_str()); } Ref Server{CreateHttpPluginServer()}; Server->AddPlugin(DllPlugin); return Server; } #endif Ref CreateHttpServer(const HttpServerConfig& Config) { using namespace std::literals; #if ZEN_WITH_PLUGINS if (Config.PluginConfigs.empty()) { return CreateHttpServerClass(Config.ServerClass, Config); } else { Ref Server{new HttpMultiServer()}; Server->AddServer(CreateHttpServerClass(Config.ServerClass, Config)); for (const HttpServerPluginConfig& PluginConfig : Config.PluginConfigs) { Ref PluginServer = CreateHttpServerPlugin(PluginConfig); if (!PluginServer.IsNull()) { Server->AddServer(PluginServer); } } return Server; } #else return CreateHttpServerClass(Config.ServerClass, Config); #endif } ////////////////////////////////////////////////////////////////////////// bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref& PackageHandlerRef) { if (Request.RequestVerb() == HttpVerb::kPost) { if (Request.RequestContentType() == HttpContentType::kCbPackageOffer) { // The client is presenting us with a package attachments offer, we need // to filter it down to the list of attachments we need them to send in // the follow-up request PackageHandlerRef = Service.HandlePackageRequest(Request); if (PackageHandlerRef) { CbValidateError ValidationError = CbValidateError::None; if (CbObject OfferMessage = ValidateAndReadCompactBinaryObject(IoBuffer(Request.ReadPayload()), ValidationError); ValidationError == CbValidateError::None) { std::vector OfferCids; for (auto& CidEntry : OfferMessage["offer"]) { if (!CidEntry.IsHash()) { // Should yield bad request response? ZEN_WARN("found invalid entry in offer"); continue; } OfferCids.push_back(CidEntry.AsHash()); } ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size()); PackageHandlerRef->FilterOffer(OfferCids); ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size()); CbObjectWriter ResponseWriter; ResponseWriter.BeginArray("need"); for (const IoHash& Cid : OfferCids) { ResponseWriter.AddHash(Cid); } ResponseWriter.EndArray(); // Emit filter response Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); } else { Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, fmt::format("Invalid request payload: '{}'", ToString(ValidationError))); } return true; } } else if (Request.RequestContentType() == HttpContentType::kCbPackage) { // Process chunks in package request PackageHandlerRef = Service.HandlePackageRequest(Request); // TODO: this should really be done in a streaming fashion, currently this emulates // the intended flow from an API perspective if (PackageHandlerRef) { PackageHandlerRef->OnRequestBegin(); auto CreateBuffer = [&](const IoHash& Cid, uint64_t Size) -> IoBuffer { return PackageHandlerRef->CreateTarget(Cid, Size); }; CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer); PackageHandlerRef->OnRequestComplete(); } } } return false; } ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS TEST_CASE("http.common") { using namespace std::literals; struct TestHttpService : public HttpService { TestHttpService() = default; virtual const char* BaseUri() const override { return "/test"; } virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override { ZEN_UNUSED(HttpServiceRequest); } }; struct TestHttpServerRequest : public HttpServerRequest { TestHttpServerRequest(HttpService& Service, std::string_view Uri) : HttpServerRequest(Service) { m_Uri = Uri; } virtual IoBuffer ReadPayload() override { return IoBuffer(); } virtual bool IsLocalMachineRequest() const override { return false; } virtual std::string_view GetAuthorizationHeader() const override { return {}; } virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override { ZEN_UNUSED(ResponseCode, ContentType, Blobs); } virtual void WriteResponse(HttpResponseCode ResponseCode) override { ZEN_UNUSED(ResponseCode); } virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override { ZEN_UNUSED(ResponseCode, ContentType, ResponseString); } virtual void WriteResponseAsync(std::function&& ContinuationHandler) override { ZEN_UNUSED(ContinuationHandler); } virtual Oid ParseSessionId() const override { return Oid(); } virtual uint32_t ParseRequestId() const override { return 0; } }; SUBCASE("router-regex") { bool HandledA = false; bool HandledAA = false; std::vector Captures; auto Reset = [&] { Captures.clear(); HandledA = HandledAA = false; }; TestHttpService Service; HttpRequestRouter r; r.AddPattern("a", "([[:alpha:]]+)"); r.RegisterRoute( "{a}", [&](auto& Req) { HandledA = true; Captures = {std::string(Req.GetCapture(0))}; }, HttpVerb::kGet); r.RegisterRoute( "{a}/{a}", [&](auto& Req) { HandledAA = true; Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; }, HttpVerb::kGet); { Reset(); TestHttpServerRequest req(Service, "abc"sv); r.HandleRequest(req); CHECK(HandledA); CHECK(!HandledAA); REQUIRE_EQ(Captures.size(), 1); CHECK_EQ(Captures[0], "abc"sv); } { Reset(); TestHttpServerRequest req{Service, "abc/def"sv}; r.HandleRequest(req); CHECK(!HandledA); CHECK(HandledAA); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "abc"sv); CHECK_EQ(Captures[1], "def"sv); } { Reset(); TestHttpServerRequest req{Service, "123"sv}; r.HandleRequest(req); CHECK(!HandledA); } { Reset(); TestHttpServerRequest req{Service, "a123"sv}; r.HandleRequest(req); CHECK(!HandledA); } } SUBCASE("router-matcher") { bool HandledA = false; bool HandledAA = false; bool HandledAB = false; bool HandledAandB = false; std::vector Captures; auto Reset = [&] { HandledA = HandledAA = HandledAB = HandledAandB = false; Captures.clear(); }; TestHttpService Service; HttpRequestRouter r; r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0; }); r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0; }); r.RegisterRoute( "{a}", [&](auto& Req) { HandledA = true; Captures = {std::string(Req.GetCapture(1))}; }, HttpVerb::kGet); r.RegisterRoute( "{a}/{a}", [&](auto& Req) { HandledAA = true; Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; }, HttpVerb::kGet); r.RegisterRoute( "{a}/{b}", [&](auto& Req) { HandledAB = true; Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; }, HttpVerb::kGet); r.RegisterRoute( "{a}/and/{b}", [&](auto& Req) { HandledAandB = true; Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; }, HttpVerb::kGet); { Reset(); TestHttpServerRequest req{Service, "ab"sv}; r.HandleRequest(req); CHECK(HandledA); CHECK(!HandledAA); CHECK(!HandledAB); REQUIRE_EQ(Captures.size(), 1); CHECK_EQ(Captures[0], "ab"sv); } { Reset(); TestHttpServerRequest req{Service, "ab/def"sv}; r.HandleRequest(req); CHECK(!HandledA); CHECK(!HandledAA); CHECK(HandledAB); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); } { Reset(); TestHttpServerRequest req{Service, "ab/and/def"sv}; r.HandleRequest(req); CHECK(!HandledA); CHECK(!HandledAA); CHECK(!HandledAB); CHECK(HandledAandB); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); } { Reset(); TestHttpServerRequest req{Service, "123"sv}; r.HandleRequest(req); CHECK(!HandledA); CHECK(!HandledAA); CHECK(!HandledAB); } { Reset(); TestHttpServerRequest req{Service, "a123"sv}; r.HandleRequest(req); CHECK(HandledA); CHECK(!HandledAA); CHECK(!HandledAB); } } SUBCASE("content-type") { for (uint8_t i = 0; i < uint8_t(HttpContentType::kCOUNT); ++i) { HttpContentType Ct{i}; if (Ct != HttpContentType::kUnknownContentType) { CHECK_EQ(Ct, ParseContentType(MapContentTypeToString(Ct))); } } } } void http_forcelink() { } #endif } // namespace zen