diff options
| author | Martin Ridgers <[email protected]> | 2021-09-15 09:22:32 +0200 |
|---|---|---|
| committer | Martin Ridgers <[email protected]> | 2021-09-15 09:23:33 +0200 |
| commit | 8f5e773529858223beeecf5d1b69c23991df644e (patch) | |
| tree | 2c360c67e028f5ecd7368212b0adf8b23578ff9d | |
| parent | Use zen::Sleep() in timer.cpp's tests (diff) | |
| parent | Updated function service to new package management API (diff) | |
| download | zen-8f5e773529858223beeecf5d1b69c23991df644e.tar.xz zen-8f5e773529858223beeecf5d1b69c23991df644e.zip | |
Merge main
75 files changed, 4263 insertions, 2373 deletions
diff --git a/.gitignore b/.gitignore index 27ff65485..df6c1a9f8 100644 --- a/.gitignore +++ b/.gitignore @@ -225,3 +225,4 @@ TAGS tags .tags !tags/ +/compile_commands.json diff --git a/scripts/deploybuild.py b/scripts/deploybuild.py index 1bb052ef7..971f34ff9 100644 --- a/scripts/deploybuild.py +++ b/scripts/deploybuild.py @@ -107,7 +107,7 @@ crashpadtarget = os.path.join(target_bin_dir, "crashpad_handler.exe") try: shutil.copy(os.path.join(zenroot, "x64\Release\zenserver.exe"), os.path.join(target_bin_dir, "zenserver.exe")) shutil.copy(os.path.join(zenroot, "x64\Release\zenserver.pdb"), os.path.join(target_bin_dir, "zenserver.pdb")) - shutil.copy(os.path.join(zenroot, r'vcpkg_installed\x64-windows-static\tools\sentry-native\crashpad_handler.exe'), crashpadtarget) + shutil.copy(os.path.join(zenroot, r'vcpkg_installed\x64-windows-static\x64-windows-static\tools\sentry-native\crashpad_handler.exe'), crashpadtarget) P4.add(crashpadtarget).run() except Exception as e: print(f"Caught exception while copying: {e.args}") diff --git a/scripts/gb.exe b/scripts/gb.exe Binary files differnew file mode 100644 index 000000000..d965974d0 --- /dev/null +++ b/scripts/gb.exe diff --git a/vcpkg.json b/vcpkg.json index 4a0c41d3b..45910556a 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -21,6 +21,7 @@ "features": [ "lz4", "zstd" ] }, "sol2", - "sentry-native" + "sentry-native", + "uwebsockets" ] } @@ -17,7 +17,10 @@ add_requires( "vcpkg::curl", "vcpkg::zlib", "vcpkg::zstd", - "vcpkg::http-parser") + "vcpkg::http-parser", + "vcpkg::uwebsockets", + "vcpkg::usockets", + "vcpkg::libuv") add_rules("mode.debug", "mode.release") @@ -40,8 +43,15 @@ add_defines("USE_SENTRY=1") option("vfs") set_showmenu(true) - set_description("Enable or disable VFS functionality") - add_defines("WITH_VFS") + set_description("Enable VFS functionality") + add_defines("ZEN_WITH_VFS") +option_end() + +option("httpsys") + set_default(true) + set_showmenu(true) + set_description("Enable http.sys server") + add_defines("ZEN_WITH_HTTPSYS") option_end() add_defines("UNICODE", "_CONSOLE") @@ -53,6 +63,7 @@ set_languages("cxx20") set_symbols("debug") includes("zencore", "zencore-test") +includes("zenhttp") includes("zenstore", "zenutil") includes("zenserver", "zenserver-test") includes("zen") @@ -1,7 +1,7 @@ Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 16 -VisualStudioVersion = 16.0.28315.86 +# Visual Studio Version 17 +VisualStudioVersion = 17.0.31612.314 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{4EA55E5B-18A1-4E66-B821-44575BC11EA7}" ProjectSection(SolutionItems) = preProject @@ -45,44 +45,72 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "zenutil", "zenutil\zenutil. EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "zentest-appstub", "zentest-appstub\zentest-appstub.vcxproj", "{7FFC7E77-D038-44E9-8D84-41918C355F29}" EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "zenhttp", "zenhttp\zenhttp.vcxproj", "{8EEB3BE5-7001-46BF-AAFD-EDB7558AC012}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|x64 = Debug|x64 + Debug|x86 = Debug|x86 Release|x64 = Release|x64 + Release|x86 = Release|x86 EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution {D75BF9AB-C61E-4FFF-AD59-1563430F05E2}.Debug|x64.ActiveCfg = Debug|x64 {D75BF9AB-C61E-4FFF-AD59-1563430F05E2}.Debug|x64.Build.0 = Debug|x64 + {D75BF9AB-C61E-4FFF-AD59-1563430F05E2}.Debug|x86.ActiveCfg = Debug|x64 {D75BF9AB-C61E-4FFF-AD59-1563430F05E2}.Release|x64.ActiveCfg = Release|x64 {D75BF9AB-C61E-4FFF-AD59-1563430F05E2}.Release|x64.Build.0 = Release|x64 + {D75BF9AB-C61E-4FFF-AD59-1563430F05E2}.Release|x86.ActiveCfg = Release|x64 {C00173DF-B76E-4989-B576-FE2B780B2580}.Debug|x64.ActiveCfg = Debug|x64 {C00173DF-B76E-4989-B576-FE2B780B2580}.Debug|x64.Build.0 = Debug|x64 + {C00173DF-B76E-4989-B576-FE2B780B2580}.Debug|x86.ActiveCfg = Debug|x64 {C00173DF-B76E-4989-B576-FE2B780B2580}.Release|x64.ActiveCfg = Release|x64 {C00173DF-B76E-4989-B576-FE2B780B2580}.Release|x64.Build.0 = Release|x64 + {C00173DF-B76E-4989-B576-FE2B780B2580}.Release|x86.ActiveCfg = Release|x64 {8398D81C-B1B6-4327-82B1-06EACB8A144F}.Debug|x64.ActiveCfg = Debug|x64 {8398D81C-B1B6-4327-82B1-06EACB8A144F}.Debug|x64.Build.0 = Debug|x64 + {8398D81C-B1B6-4327-82B1-06EACB8A144F}.Debug|x86.ActiveCfg = Debug|x64 {8398D81C-B1B6-4327-82B1-06EACB8A144F}.Release|x64.ActiveCfg = Release|x64 {8398D81C-B1B6-4327-82B1-06EACB8A144F}.Release|x64.Build.0 = Release|x64 + {8398D81C-B1B6-4327-82B1-06EACB8A144F}.Release|x86.ActiveCfg = Release|x64 {CA7B9E04-A2D3-4A39-A7D7-FB156A2C6A48}.Debug|x64.ActiveCfg = Debug|x64 {CA7B9E04-A2D3-4A39-A7D7-FB156A2C6A48}.Debug|x64.Build.0 = Debug|x64 + {CA7B9E04-A2D3-4A39-A7D7-FB156A2C6A48}.Debug|x86.ActiveCfg = Debug|x64 {CA7B9E04-A2D3-4A39-A7D7-FB156A2C6A48}.Release|x64.ActiveCfg = Release|x64 {CA7B9E04-A2D3-4A39-A7D7-FB156A2C6A48}.Release|x64.Build.0 = Release|x64 + {CA7B9E04-A2D3-4A39-A7D7-FB156A2C6A48}.Release|x86.ActiveCfg = Release|x64 {2563249E-E695-4CC4-8FFA-335D07680C9D}.Debug|x64.ActiveCfg = Debug|x64 {2563249E-E695-4CC4-8FFA-335D07680C9D}.Debug|x64.Build.0 = Debug|x64 + {2563249E-E695-4CC4-8FFA-335D07680C9D}.Debug|x86.ActiveCfg = Debug|x64 {2563249E-E695-4CC4-8FFA-335D07680C9D}.Release|x64.ActiveCfg = Release|x64 {2563249E-E695-4CC4-8FFA-335D07680C9D}.Release|x64.Build.0 = Release|x64 + {2563249E-E695-4CC4-8FFA-335D07680C9D}.Release|x86.ActiveCfg = Release|x64 {26CBBAEB-14C1-4EFC-877D-80F48215651C}.Debug|x64.ActiveCfg = Debug|x64 {26CBBAEB-14C1-4EFC-877D-80F48215651C}.Debug|x64.Build.0 = Debug|x64 + {26CBBAEB-14C1-4EFC-877D-80F48215651C}.Debug|x86.ActiveCfg = Debug|x64 {26CBBAEB-14C1-4EFC-877D-80F48215651C}.Release|x64.ActiveCfg = Release|x64 {26CBBAEB-14C1-4EFC-877D-80F48215651C}.Release|x64.Build.0 = Release|x64 + {26CBBAEB-14C1-4EFC-877D-80F48215651C}.Release|x86.ActiveCfg = Release|x64 {77F8315D-B21D-4DB0-9A6F-2D3359F88A70}.Debug|x64.ActiveCfg = Debug|x64 {77F8315D-B21D-4DB0-9A6F-2D3359F88A70}.Debug|x64.Build.0 = Debug|x64 + {77F8315D-B21D-4DB0-9A6F-2D3359F88A70}.Debug|x86.ActiveCfg = Debug|x64 {77F8315D-B21D-4DB0-9A6F-2D3359F88A70}.Release|x64.ActiveCfg = Release|x64 {77F8315D-B21D-4DB0-9A6F-2D3359F88A70}.Release|x64.Build.0 = Release|x64 + {77F8315D-B21D-4DB0-9A6F-2D3359F88A70}.Release|x86.ActiveCfg = Release|x64 {7FFC7E77-D038-44E9-8D84-41918C355F29}.Debug|x64.ActiveCfg = Debug|x64 {7FFC7E77-D038-44E9-8D84-41918C355F29}.Debug|x64.Build.0 = Debug|x64 + {7FFC7E77-D038-44E9-8D84-41918C355F29}.Debug|x86.ActiveCfg = Debug|Win32 + {7FFC7E77-D038-44E9-8D84-41918C355F29}.Debug|x86.Build.0 = Debug|Win32 {7FFC7E77-D038-44E9-8D84-41918C355F29}.Release|x64.ActiveCfg = Release|x64 {7FFC7E77-D038-44E9-8D84-41918C355F29}.Release|x64.Build.0 = Release|x64 + {7FFC7E77-D038-44E9-8D84-41918C355F29}.Release|x86.ActiveCfg = Release|Win32 + {7FFC7E77-D038-44E9-8D84-41918C355F29}.Release|x86.Build.0 = Release|Win32 + {8EEB3BE5-7001-46BF-AAFD-EDB7558AC012}.Debug|x64.ActiveCfg = Debug|x64 + {8EEB3BE5-7001-46BF-AAFD-EDB7558AC012}.Debug|x64.Build.0 = Debug|x64 + {8EEB3BE5-7001-46BF-AAFD-EDB7558AC012}.Debug|x86.ActiveCfg = Debug|x64 + {8EEB3BE5-7001-46BF-AAFD-EDB7558AC012}.Release|x64.ActiveCfg = Release|x64 + {8EEB3BE5-7001-46BF-AAFD-EDB7558AC012}.Release|x64.Build.0 = Release|x64 + {8EEB3BE5-7001-46BF-AAFD-EDB7558AC012}.Release|x86.ActiveCfg = Release|x64 EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/zen/cmds/cache.cpp b/zen/cmds/cache.cpp index a9d017a2b..3f4a4cdc3 100644 --- a/zen/cmds/cache.cpp +++ b/zen/cmds/cache.cpp @@ -3,6 +3,7 @@ #include "cache.h" #include <zencore/filesystem.h> +#include <zenhttp/httpcommon.h> #include <zenserverprocess.h> #include <spdlog/spdlog.h> @@ -44,7 +45,7 @@ DropCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) Session.SetUrl({"{}/z$/{}"_format(m_HostName, m_BucketName)}); cpr::Response Result = Session.Delete(); - if (Result.status_code >= 200 && Result.status_code < 300) + if (zen::IsHttpSuccessCode(Result.status_code)) { spdlog::info("OK: dropped cache bucket '{}' from '{}'", m_BucketName, m_HostName); diff --git a/zen/zen.vcxproj b/zen/zen.vcxproj index 322b70850..4f0691fab 100644 --- a/zen/zen.vcxproj +++ b/zen/zen.vcxproj @@ -124,6 +124,9 @@ <ProjectReference Include="..\zencore\zencore.vcxproj"> <Project>{d75bf9ab-c61e-4fff-ad59-1563430f05e2}</Project> </ProjectReference> + <ProjectReference Include="..\zenhttp\zenhttp.vcxproj"> + <Project>{8eeb3be5-7001-46bf-aafd-edb7558ac012}</Project> + </ProjectReference> <ProjectReference Include="..\zenstore\zenstore.vcxproj"> <Project>{26cbbaeb-14c1-4efc-877d-80f48215651c}</Project> </ProjectReference> diff --git a/zencore/compactbinary.cpp b/zencore/compactbinary.cpp index 5fe7f272d..b508d8fe8 100644 --- a/zencore/compactbinary.cpp +++ b/zencore/compactbinary.cpp @@ -12,7 +12,7 @@ namespace zen { -const int DaysToMonth[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365}; +const int DaysToMonth[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365}; bool IsLeapYear(int Year) diff --git a/zencore/compactbinarypackage.cpp b/zencore/compactbinarypackage.cpp index 7880164f9..9a7e7c098 100644 --- a/zencore/compactbinarypackage.cpp +++ b/zencore/compactbinarypackage.cpp @@ -16,22 +16,45 @@ CbAttachment::CbAttachment(const CompressedBuffer& InValue) : CbAttachment(InVal { } -CbAttachment::CbAttachment(CompressedBuffer&& InValue) +CbAttachment::CbAttachment(const SharedBuffer& InValue) : CbAttachment(CompositeBuffer(InValue)) { - Value.emplace<CompressedBuffer>(std::move(InValue).MakeOwned()); } -CbAttachment::CbAttachment(const SharedBuffer& InValue) -: CbAttachment(InValue.IsNull() ? CompressedBuffer() - : CompressedBuffer::Compress(InValue, OodleCompressor::NotSet, OodleCompressionLevel::None)) +CbAttachment::CbAttachment(const SharedBuffer& InValue, [[maybe_unused]] const IoHash& InHash) +: CbAttachment(CompositeBuffer(InValue), InHash) { } -CbAttachment::CbAttachment(const SharedBuffer& InValue, [[maybe_unused]] const IoHash& InHash) -: CbAttachment(InValue.IsNull() ? CompressedBuffer() - : CompressedBuffer::Compress(InValue, OodleCompressor::NotSet, OodleCompressionLevel::None)) +CbAttachment::CbAttachment(const CompositeBuffer& InValue) : Value{std::in_place_type<BinaryValue>, InValue} +{ + if (std::get<BinaryValue>(Value).Buffer.IsNull()) + { + Value.emplace<nullptr_t>(); + } +} + +CbAttachment::CbAttachment(CompositeBuffer&& InValue) : Value{std::in_place_type<BinaryValue>, InValue} +{ + if (std::get<BinaryValue>(Value).Buffer.IsNull()) + { + Value.emplace<nullptr_t>(); + } +} + +CbAttachment::CbAttachment(CompositeBuffer&& InValue, const IoHash& Hash) : Value{std::in_place_type<BinaryValue>, InValue, Hash} +{ + if (std::get<BinaryValue>(Value).Buffer.IsNull()) + { + Value.emplace<nullptr_t>(); + } +} + +CbAttachment::CbAttachment(CompressedBuffer&& InValue) : Value(std::in_place_type<CompressedBuffer>, InValue) { - // This could be more efficient, and should at the very least try to validate the hash + if (std::get<CompressedBuffer>(Value).IsNull()) + { + Value.emplace<nullptr_t>(); + } } CbAttachment::CbAttachment(const CbObject& InValue, const IoHash* const InHash) @@ -70,114 +93,139 @@ CbAttachment::TryLoad(IoBuffer& InBuffer, BufferAllocator Allocator) bool CbAttachment::TryLoad(CbFieldIterator& Fields) { - const CbObjectView ObjectView = Fields.AsObjectView(); - if (Fields.HasError()) + if (const CbObjectView ObjectView = Fields.AsObjectView(); !Fields.HasError()) + { + // Is a null object or object not prefixed with a precomputed hash value + Value.emplace<CbObjectValue>(CbObject(ObjectView, Fields.GetOuterBuffer()), ObjectView.GetHash()); + ++Fields; + } + else if (const IoHash ObjectAttachmentHash = Fields.AsObjectAttachment(); !Fields.HasError()) + { + // Is an object + ++Fields; + const CbObjectView InnerObjectView = Fields.AsObjectView(); + if (Fields.HasError()) + { + return false; + } + Value.emplace<CbObjectValue>(CbObject(InnerObjectView, Fields.GetOuterBuffer()), ObjectAttachmentHash); + ++Fields; + } + else if (const IoHash BinaryAttachmentHash = Fields.AsBinaryAttachment(); !Fields.HasError()) + { + // Is an uncompressed binary blob + ++Fields; + MemoryView BinaryView = Fields.AsBinaryView(); + if (Fields.HasError()) + { + return false; + } + Value.emplace<BinaryValue>(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer()), BinaryAttachmentHash); + ++Fields; + } + else if (MemoryView BinaryView = Fields.AsBinaryView(); !Fields.HasError()) { - // Is a buffer - const MemoryView BinaryView = Fields.AsBinaryView(); if (BinaryView.GetSize() > 0) { + // Is a compressed binary blob Value.emplace<CompressedBuffer>( CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer())).MakeOwned()); - ++Fields; } else { + // Is an uncompressed empty binary blob + Value.emplace<BinaryValue>(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer()), IoHash::HashBuffer(nullptr, 0)); ++Fields; - Value.emplace<CompressedBuffer>(); } } else { - // It's an object - ++Fields; - IoHash Hash; - if (ObjectView) - { - Hash = Fields.AsObjectAttachment(); - ++Fields; - } - else - { - Hash = IoHash::HashBuffer(MemoryView{}); - } - Value.emplace<CbObjectValue>(CbObject(ObjectView, Fields->GetOuterBuffer()), Hash); + return false; } return true; } -bool -CbAttachment::TryLoad(BinaryReader& Reader, BufferAllocator Allocator) +static bool +TryLoad_ArchiveFieldIntoAttachment(CbAttachment& TargetAttachment, CbField&& Field, BinaryReader& Reader, BufferAllocator Allocator) { - CbField Field = LoadCompactBinary(Reader, Allocator); - const CbObjectView ObjectView = Field.AsObjectView(); - - if (Field.HasError()) + if (const CbObjectView ObjectView = Field.AsObjectView(); !Field.HasError()) { - // It's a buffer - const MemoryView BinaryView = Field.AsBinaryView(); - if (BinaryView.GetSize() > 0) + // Is a null object or object not prefixed with a precomputed hash value + TargetAttachment = CbAttachment(CbObject(ObjectView, std::move(Field)), ObjectView.GetHash()); + } + else if (const IoHash ObjectAttachmentHash = Field.AsObjectAttachment(); !Field.HasError()) + { + // Is an object + Field = LoadCompactBinary(Reader, Allocator); + if (!Field.IsObject()) { - Value.emplace<CompressedBuffer>( - CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView, Field.GetOuterBuffer())).MakeOwned()); + return false; } - else + TargetAttachment = CbAttachment(std::move(Field).AsObject(), ObjectAttachmentHash); + } + else if (const IoHash BinaryAttachmentHash = Field.AsBinaryAttachment(); !Field.HasError()) + { + // Is an uncompressed binary blob + Field = LoadCompactBinary(Reader, Allocator); + SharedBuffer Buffer = Field.AsBinary(); + if (Field.HasError()) { - Value.emplace<CompressedBuffer>(); + return false; } + TargetAttachment = CbAttachment(CompositeBuffer(Buffer), BinaryAttachmentHash); } - else + else if (SharedBuffer Buffer = Field.AsBinary(); !Field.HasError()) { - // It's an object - IoHash Hash; - if (ObjectView) + if (Buffer.GetSize() > 0) { - std::vector<uint8_t> HashBuffer; - CbField HashField = LoadCompactBinary(Reader, [&HashBuffer](uint64_t Size) -> UniqueBuffer { - HashBuffer.resize(Size); - return UniqueBuffer::MakeMutableView(HashBuffer.data(), Size); - }); - Hash = HashField.AsAttachment(); - if (HashField.HasError() || ObjectView.GetHash() != Hash) - { - // Error - return false; - } + // Is a compressed binary blob + TargetAttachment = CbAttachment(CompressedBuffer::FromCompressed(std::move(Buffer))); } else { - Hash = IoHash::HashBuffer(MemoryView()); + // Is an uncompressed empty binary blob + TargetAttachment = CbAttachment(CompositeBuffer(Buffer), IoHash::HashBuffer(nullptr, 0)); } - Value.emplace<CbObjectValue>(CbObject(ObjectView, Field.GetOuterBuffer()), Hash); + } + else + { + return false; } return true; } +bool +CbAttachment::TryLoad(BinaryReader& Reader, BufferAllocator Allocator) +{ + CbField Field = LoadCompactBinary(Reader, Allocator); + return TryLoad_ArchiveFieldIntoAttachment(*this, std::move(Field), Reader, Allocator); +} + void CbAttachment::Save(CbWriter& Writer) const { - if (const CbObjectValue* ObjectValue = std::get_if<CbObjectValue>(&Value)) + if (const CbObjectValue* ObjValue = std::get_if<CbObjectValue>(&Value)) { - Writer.AddObject(ObjectValue->Object); - if (ObjectValue->Object) + if (ObjValue->Object) { - Writer.AddObjectAttachment(ObjectValue->Hash); + Writer.AddObjectAttachment(ObjValue->Hash); } + Writer.AddObject(ObjValue->Object); } - else + else if (const BinaryValue* BinValue = std::get_if<BinaryValue>(&Value)) { - const CompressedBuffer& BufferValue = std::get<CompressedBuffer>(Value); - if (BufferValue.GetRawSize()) - { - Writer.AddBinary(BufferValue.GetCompressed()); - } - else // Null + if (BinValue->Buffer.GetSize() > 0) { - Writer.AddBinary(MemoryView()); + Writer.AddBinaryAttachment(BinValue->Hash); } + Writer.AddBinary(BinValue->Buffer); + } + else if (const CompressedBuffer* BufferValue = std::get_if<CompressedBuffer>(&Value)) + { + Writer.AddBinary(BufferValue->GetCompressed()); } } @@ -192,14 +240,19 @@ CbAttachment::Save(BinaryWriter& Writer) const bool CbAttachment::IsNull() const { - if (const CompressedBuffer* Buffer = std::get_if<CompressedBuffer>(&Value)) - { - return Buffer->IsNull(); - } - else - { - return false; - } + return std::holds_alternative<nullptr_t>(Value); +} + +bool +CbAttachment::IsBinary() const +{ + return std::holds_alternative<BinaryValue>(Value); +} + +bool +CbAttachment::IsCompressedBinary() const +{ + return std::holds_alternative<CompressedBuffer>(Value); } bool @@ -213,25 +266,42 @@ CbAttachment::GetHash() const { if (const CompressedBuffer* Buffer = std::get_if<CompressedBuffer>(&Value)) { - return Buffer->IsNull() ? IoHash::HashBuffer(MemoryView()) : IoHash::FromBLAKE3(Buffer->GetRawHash()); + return IoHash::FromBLAKE3(Buffer->GetRawHash()); } - else + + if (const BinaryValue* BinValue = std::get_if<BinaryValue>(&Value)) { - return std::get<CbObjectValue>(Value).Hash; + return BinValue->Hash; } + + if (const CbObjectValue* ObjectValue = std::get_if<CbObjectValue>(&Value)) + { + return ObjectValue->Hash; + } + + return IoHash::Zero; } -SharedBuffer -CbAttachment::AsBinary() const +CompositeBuffer +CbAttachment::AsCompositeBinary() const { - if (const CompressedBuffer* Buffer = std::get_if<CompressedBuffer>(&Value)) + if (const BinaryValue* BinValue = std::get_if<BinaryValue>(&Value)) { - return Buffer->Decompress(); + return BinValue->Buffer; } - else + + return CompositeBuffer::Null; +} + +SharedBuffer +CbAttachment::AsBinary() const +{ + if (const BinaryValue* BinValue = std::get_if<BinaryValue>(&Value)) { - return std::get<CbObjectValue>(Value).Object.GetBuffer(); + return BinValue->Buffer.Flatten(); } + + return {}; } CompressedBuffer @@ -241,12 +311,8 @@ CbAttachment::AsCompressedBinary() const { return *Buffer; } - else - { - return CompressedBuffer::Compress(std::get<CbObjectValue>(Value).Object.GetBuffer(), - OodleCompressor::NotSet, - OodleCompressionLevel::None); - } + + return CompressedBuffer::Null; } /** Access the attachment as compact binary. Defaults to a field iterator with no value on error. */ @@ -257,10 +323,8 @@ CbAttachment::AsObject() const { return ObjectValue->Object; } - else - { - return {}; - } + + return {}; } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -301,10 +365,7 @@ CbPackage::AddAttachment(const CbAttachment& Attachment, AttachmentResolver* Res if (It != Attachments.end() && *It == Attachment) { CbAttachment& Existing = *It; - if (Attachment.IsObject() && !Existing.IsObject()) - { - Existing = CbAttachment(CbObject(Existing.AsBinary()), Existing.GetHash()); - } + Existing = Attachment; } else { @@ -358,14 +419,14 @@ CbPackage::GatherAttachments(const CbObject& Value, AttachmentResolver Resolver) } else { - AddAttachment(CbAttachment(std::move(Buffer), Hash)); + AddAttachment(CbAttachment(std::move(Buffer))); } } }); } bool -CbPackage::TryLoad(IoBuffer& InBuffer, BufferAllocator Allocator, AttachmentResolver* Mapper) +CbPackage::TryLoad(IoBuffer InBuffer, BufferAllocator Allocator, AttachmentResolver* Mapper) { MemoryInStream InStream(InBuffer.Data(), InBuffer.Size()); BinaryReader Reader(InStream); @@ -377,6 +438,7 @@ bool CbPackage::TryLoad(CbFieldIterator& Fields) { *this = CbPackage(); + while (Fields) { if (Fields.IsNull()) @@ -384,43 +446,76 @@ CbPackage::TryLoad(CbFieldIterator& Fields) ++Fields; break; } - else if (Fields.IsBinary()) + else if (IoHash Hash = Fields.AsHash(); !Fields.HasError() && !Fields.IsAttachment()) { - CbAttachment Attachment; - Attachment.TryLoad(Fields); - AddAttachment(Attachment); - } - else - { - Object = Fields.AsObject(); - if (Fields->HasError()) + ++Fields; + CbObjectView ObjectView = Fields.AsObjectView(); + if (Fields.HasError() || Hash != ObjectView.GetHash()) { return false; } + Object = CbObject(ObjectView, Fields.GetOuterBuffer()); Object.MakeOwned(); + ObjectHash = Hash; ++Fields; - if (Object.CreateIterator()) - { - ObjectHash = Fields.AsObjectAttachment(); - if (Fields.HasError()) - { - return false; - } - ++Fields; - } - else + } + else + { + CbAttachment Attachment; + if (!Attachment.TryLoad(Fields)) { - Object.Reset(); + return false; } + AddAttachment(Attachment); } } - return true; } bool CbPackage::TryLoad(BinaryReader& Reader, BufferAllocator Allocator, AttachmentResolver* Mapper) { + // TODO: this needs to re-grow the ability to accept a reference to an attachment which is + // not embedded + + ZEN_UNUSED(Mapper); + +#if 1 + *this = CbPackage(); + for (;;) + { + CbField Field = LoadCompactBinary(Reader, Allocator); + if (!Field) + { + return false; + } + + if (Field.IsNull()) + { + return true; + } + else if (IoHash Hash = Field.AsHash(); !Field.HasError() && !Field.IsAttachment()) + { + Field = LoadCompactBinary(Reader, Allocator); + CbObjectView ObjectView = Field.AsObjectView(); + if (Field.HasError() || Hash != ObjectView.GetHash()) + { + return false; + } + Object = CbObject(ObjectView, Field.GetOuterBuffer()); + ObjectHash = Hash; + } + else + { + CbAttachment Attachment; + if (!TryLoad_ArchiveFieldIntoAttachment(Attachment, std::move(Field), Reader, Allocator)) + { + return false; + } + AddAttachment(Attachment); + } + } +#else uint8_t StackBuffer[64]; const auto StackAllocator = [&Allocator, &StackBuffer](uint64_t Size) -> UniqueBuffer { if (Size <= sizeof(StackBuffer)) @@ -494,6 +589,7 @@ CbPackage::TryLoad(BinaryReader& Reader, BufferAllocator Allocator, AttachmentRe } } } +#endif } void @@ -501,8 +597,8 @@ CbPackage::Save(CbWriter& Writer) const { if (Object) { + Writer.AddHash(ObjectHash); Writer.AddObject(Object); - Writer.AddObjectAttachment(ObjectHash); } for (const CbAttachment& Attachment : Attachments) { @@ -519,6 +615,136 @@ CbPackage::Save(BinaryWriter& StreamWriter) const Writer.Save(StreamWriter); } +////////////////////////////////////////////////////////////////////////// +// +// Legacy package serialization support +// + +namespace legacy { + + void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer) + { + if (Attachment.IsObject()) + { + CbObject Object = Attachment.AsObject(); + Writer.AddBinary(Object.GetBuffer()); + if (Object) + { + Writer.AddObjectAttachment(Attachment.GetHash()); + } + } + else if (Attachment.IsBinary()) + { + Writer.AddBinary(Attachment.AsBinary()); + Writer.AddBinaryAttachment(Attachment.GetHash()); + } + else if (Attachment.IsNull()) + { + Writer.AddBinary(MemoryView()); + } + else + { + ZEN_NOT_IMPLEMENTED("Compressed binary is not supported in this serialization format"); + } + } + + void SaveCbPackage(const CbPackage& Package, CbWriter& Writer) + { + if (const CbObject& RootObject = Package.GetObject()) + { + Writer.AddObject(RootObject); + Writer.AddObjectAttachment(Package.GetObjectHash()); + } + for (const CbAttachment& Attachment : Package.GetAttachments()) + { + SaveCbAttachment(Attachment, Writer); + } + Writer.AddNull(); + } + + void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar) + { + CbWriter Writer; + SaveCbPackage(Package, Writer); + Writer.Save(Ar); + } + + bool TryLoadCbPackage(CbPackage& Package, IoBuffer InBuffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + { + MemoryInStream InStream(InBuffer.Data(), InBuffer.Size()); + BinaryReader Reader(InStream); + + return TryLoadCbPackage(Package, Reader, Allocator, Mapper); + } + + bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + { + Package = CbPackage(); + for (;;) + { + CbField ValueField = LoadCompactBinary(Reader, Allocator); + if (!ValueField) + { + return false; + } + if (ValueField.IsNull()) + { + return true; + } + if (ValueField.IsBinary()) + { + const MemoryView View = ValueField.AsBinaryView(); + if (View.GetSize() > 0) + { + SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned(); + CbField HashField = LoadCompactBinary(Reader, Allocator); + const IoHash& Hash = HashField.AsAttachment(); + if (HashField.HasError() || IoHash::HashBuffer(Buffer) != Hash) + { + return false; + } + if (HashField.IsObjectAttachment()) + { + Package.AddAttachment(CbAttachment(CbObject(std::move(Buffer)), Hash)); + } + else + { + Package.AddAttachment(CbAttachment(CompositeBuffer(std::move(Buffer)), Hash)); + } + } + } + else if (ValueField.IsHash()) + { + const IoHash Hash = ValueField.AsHash(); + + ZEN_ASSERT(Mapper); + + Package.AddAttachment(CbAttachment((*Mapper)(Hash), Hash)); + } + else + { + CbObject Object = ValueField.AsObject(); + if (ValueField.HasError()) + { + return false; + } + + if (Object) + { + CbField HashField = LoadCompactBinary(Reader, Allocator); + IoHash ObjectHash = HashField.AsObjectAttachment(); + if (HashField.HasError() || Object.GetHash() != ObjectHash) + { + return false; + } + Package.SetObject(Object, ObjectHash); + } + } + } + } + +} // namespace legacy + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// void @@ -567,8 +793,7 @@ TEST_CASE("usonpackage") CHECK_FALSE(bool(Attachment.AsObject())); CHECK_FALSE(Attachment.IsBinary()); CHECK_FALSE(Attachment.IsObject()); - CHECK(Attachment.GetHash() == IoHash::HashBuffer({})); - TestSaveLoadValidate("Null", Attachment); + CHECK(Attachment.GetHash() == IoHash::Zero); } SUBCASE("Binary Attachment") @@ -596,12 +821,12 @@ TEST_CASE("usonpackage") CHECK_FALSE(Attachment.IsNull()); CHECK(bool(Attachment)); - CHECK(Attachment.AsBinary() == Object.GetBuffer()); + CHECK(Attachment.AsBinary() == SharedBuffer()); CHECK(Attachment.AsObject().Equals(Object)); - CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsBinary()); CHECK(Attachment.IsObject()); CHECK(Attachment.GetHash() == Object.GetHash()); - TestSaveLoadValidate("CompactBinary", Attachment); + TestSaveLoadValidate("Object", Attachment); } SUBCASE("Binary View") @@ -633,7 +858,7 @@ TEST_CASE("usonpackage") CHECK(Attachment.AsBinary() != ObjectView.GetBuffer()); CHECK(Attachment.AsObject().Equals(Object)); - CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsBinary()); CHECK(Attachment.IsObject()); CHECK(Attachment.GetHash() == IoHash(Object.GetHash())); } @@ -677,15 +902,14 @@ TEST_CASE("usonpackage") CbFieldIterator FieldsView = CbFieldIterator::MakeRangeView(CbFieldViewIterator(Fields)); Attachment.TryLoad(FieldsView); + MemoryView View; CHECK_FALSE(Attachment.IsNull()); CHECK(bool(Attachment)); - - CHECK(Attachment.AsBinary().GetView().EqualBytes(Value.GetView())); - CHECK_FALSE(FieldsView.GetBuffer().GetView().Contains(Attachment.AsObject().GetBuffer().GetView())); - CHECK(Attachment.IsBinary()); + CHECK(Attachment.AsBinary().GetView().EqualBytes(MemoryView())); + CHECK_FALSE((!Attachment.AsObject().TryGetSerializedView(View) || FieldsView.GetOuterBuffer().GetView().Contains(View))); + CHECK_FALSE(Attachment.IsBinary()); CHECK(Attachment.IsObject()); - CHECK(Attachment.GetHash() == Value.GetHash()); } @@ -696,7 +920,7 @@ TEST_CASE("usonpackage") CHECK(Attachment.IsNull()); CHECK_FALSE(Attachment.IsBinary()); CHECK_FALSE(Attachment.IsObject()); - CHECK(Attachment.GetHash() == IoHash::HashBuffer(SharedBuffer{})); + CHECK(Attachment.GetHash() == IoHash::Zero); } SUBCASE("Binary Empty") @@ -714,7 +938,7 @@ TEST_CASE("usonpackage") const CbAttachment Attachment(CbObject{}); CHECK_FALSE(Attachment.IsNull()); - CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsBinary()); CHECK(Attachment.IsObject()); CHECK(Attachment.GetHash() == CbObject().GetHash()); } @@ -833,17 +1057,16 @@ TEST_CASE("usonpackage.serialization") CHECK((Object1Attachment && Object1Attachment->AsObject().Equals(Object1))); CHECK((Object2Attachment && Object2Attachment->AsBinary() == Object2.GetBuffer())); - Package.AddAttachment(CbAttachment(SharedBuffer::Clone(Object1.GetView()))); + SharedBuffer Object1ClonedBuffer = SharedBuffer::Clone(Object1.GetOuterBuffer()); + Package.AddAttachment(CbAttachment(Object1ClonedBuffer)); Package.AddAttachment(CbAttachment(CbObject::Clone(Object2))); CHECK(Package.GetAttachments().size() == 2); CHECK(Package.FindAttachment(Object1.GetHash()) == Object1Attachment); CHECK(Package.FindAttachment(Object2.GetHash()) == Object2Attachment); - CHECK((Object1Attachment && Object1Attachment->AsObject().Equals(Object1))); - CHECK((Object1Attachment && Object1Attachment->AsBinary() == Object1.GetBuffer())); + CHECK((Object1Attachment && Object1Attachment->AsBinary() == Object1ClonedBuffer)); CHECK((Object2Attachment && Object2Attachment->AsObject().Equals(Object2))); - CHECK((Object2Attachment && Object2Attachment->AsBinary() == Object2.GetBuffer())); CHECK(std::is_sorted(begin(Package.GetAttachments()), end(Package.GetAttachments()))); } @@ -884,8 +1107,8 @@ TEST_CASE("usonpackage.serialization") const IoHash Level1Hash = Level1.GetHash(); const auto Resolver = [&Level2, &Level2Hash, &Level3, &Level3Hash, &Level4, &Level4Hash](const IoHash& Hash) -> SharedBuffer { - return Hash == Level2Hash ? Level2.GetBuffer() - : Hash == Level3Hash ? Level3.GetBuffer() + return Hash == Level2Hash ? Level2.GetOuterBuffer() + : Hash == Level3Hash ? Level3.GetOuterBuffer() : Hash == Level4Hash ? Level4 : SharedBuffer(); }; @@ -907,8 +1130,9 @@ TEST_CASE("usonpackage.serialization") const CbAttachment* const Level4Attachment = Package.FindAttachment(Level4Hash); CHECK((Level2Attachment && Level2Attachment->AsObject().Equals(Level2))); CHECK((Level3Attachment && Level3Attachment->AsObject().Equals(Level3))); - CHECK((Level4Attachment && Level4Attachment->AsBinary() != Level4 && - Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView()))); + REQUIRE(Level4Attachment); + CHECK(Level4Attachment->AsBinary() != Level4); + CHECK(Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView())); CHECK(std::is_sorted(begin(Package.GetAttachments()), end(Package.GetAttachments()))); @@ -932,15 +1156,15 @@ TEST_CASE("usonpackage.serialization") // Out of Order { - CbWriter Writer; - Writer.AddBinary(Level2.GetBuffer()); - Writer.AddObjectAttachment(Level2Hash); - Writer.AddBinary(Level4); - Writer.AddBinaryAttachment(Level4Hash); + CbWriter Writer; + CbAttachment Attachment2(Level2, Level2Hash); + Attachment2.Save(Writer); + CbAttachment Attachment4(Level4); + Attachment4.Save(Writer); + Writer.AddHash(Level1Hash); Writer.AddObject(Level1); - Writer.AddObjectAttachment(Level1Hash); - Writer.AddBinary(Level3.GetBuffer()); - Writer.AddObjectAttachment(Level3Hash); + CbAttachment Attachment3(Level3, Level3Hash); + Attachment3.Save(Writer); Writer.AddNull(); CbFieldIterator Fields = Writer.Save(); @@ -961,11 +1185,9 @@ TEST_CASE("usonpackage.serialization") const MemoryView FieldsOuterBufferView = Fields.GetOuterBuffer().GetView(); CHECK(Level2Attachment->AsObject().Equals(Level2)); - CHECK(FieldsOuterBufferView.Contains(Level2Attachment->AsBinary().GetView())); CHECK(Level2Attachment->GetHash() == Level2Hash); CHECK(Level3Attachment->AsObject().Equals(Level3)); - CHECK(FieldsOuterBufferView.Contains(Level3Attachment->AsBinary().GetView())); CHECK(Level3Attachment->GetHash() == Level3Hash); CHECK(Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView())); @@ -983,21 +1205,23 @@ TEST_CASE("usonpackage.serialization") Writer.Reset(); FromArchive.Save(Writer); CbFieldIterator Saved = Writer.Save(); - CHECK(Saved.AsObject().Equals(Level1)); + + CHECK(Saved.AsHash() == Level1Hash); ++Saved; - CHECK(Saved.AsObjectAttachment() == Level1Hash); + CHECK(Saved.AsObject().Equals(Level1)); ++Saved; - CHECK(Saved.AsBinaryView().EqualBytes(Level2.GetView())); + CHECK_EQ(Saved.AsObjectAttachment(), Level2Hash); ++Saved; - CHECK(Saved.AsObjectAttachment() == Level2Hash); + CHECK(Saved.AsObject().Equals(Level2)); ++Saved; - CHECK(Saved.AsBinaryView().EqualBytes(Level3.GetView())); + CHECK_EQ(Saved.AsObjectAttachment(), Level3Hash); ++Saved; - CHECK(Saved.AsObjectAttachment() == Level3Hash); + CHECK(Saved.AsObject().Equals(Level3)); ++Saved; - CHECK(Saved.AsBinaryView().EqualBytes(Level4.GetView())); + CHECK_EQ(Saved.AsBinaryAttachment(), Level4Hash); ++Saved; - CHECK(Saved.AsBinaryAttachment() == Level4Hash); + SharedBuffer SavedLevel4Buffer = SharedBuffer::MakeView(Saved.AsBinaryView()); + CHECK(SavedLevel4Buffer.GetView().EqualBytes(Level4.GetView())); ++Saved; CHECK(Saved.IsNull()); ++Saved; diff --git a/zencore/compactbinaryvalidation.cpp b/zencore/compactbinaryvalidation.cpp index 52f625313..316da76a6 100644 --- a/zencore/compactbinaryvalidation.cpp +++ b/zencore/compactbinaryvalidation.cpp @@ -416,92 +416,125 @@ ValidateCbPackageField(MemoryView& View, CbValidateMode Mode, CbValidateError& E static IoHash ValidateCbPackageAttachment(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error) { - const CbObjectView ObjectView = Value.AsObjectView(); - if (Value.HasError()) + if (const CbObjectView ObjectView = Value.AsObjectView(); !Value.HasError()) { - const MemoryView BinaryView = Value.AsBinaryView(); - if (Value.HasError() && EnumHasAnyFlags(Mode, CbValidateMode::Package)) + return CbObject().GetHash(); + } + + if (const IoHash ObjectAttachmentHash = Value.AsObjectAttachment(); !Value.HasError()) + { + if (CbFieldView ObjectField = ValidateCbPackageField(View, Mode, Error)) { - if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + const CbObjectView InnerObjectView = ObjectField.AsObjectView(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && ObjectField.HasError()) { AddError(Error, CbValidateError::InvalidPackageFormat); } + else if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (ObjectAttachmentHash != InnerObjectView.GetHash())) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + return ObjectAttachmentHash; } - else if (BinaryView.GetSize()) + } + else if (const IoHash BinaryAttachmentHash = Value.AsBinaryAttachment(); !Value.HasError()) + { + if (CbFieldView BinaryField = ValidateCbPackageField(View, Mode, Error)) { - if (EnumHasAnyFlags(Mode, CbValidateMode::Package | CbValidateMode::PackageHash)) + const MemoryView BinaryView = BinaryField.AsBinaryView(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && BinaryField.HasError()) { - CompressedBuffer Buffer = CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView)); - if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && Buffer.IsNull()) + AddError(Error, CbValidateError::InvalidPackageFormat); + } + else + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && BinaryView.IsEmpty()) { - AddError(Error, CbValidateError::InvalidPackageFormat); + AddError(Error, CbValidateError::NullPackageAttachment); } - if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && - (IoHash::FromBLAKE3(Buffer.GetRawHash()) != IoHash::HashBuffer(Buffer.DecompressToComposite()))) + if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (BinaryAttachmentHash != IoHash::HashBuffer(BinaryView))) { AddError(Error, CbValidateError::InvalidPackageHash); } - return IoHash::FromBLAKE3(Buffer.GetRawHash()); } + return BinaryAttachmentHash; } } - else + else if (const MemoryView BinaryView = Value.AsBinaryView(); !Value.HasError()) { - if (ObjectView) + if (BinaryView.GetSize() > 0) { - if (CbFieldView HashField = ValidateCbPackageField(View, Mode, Error)) + CompressedBuffer Buffer = CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView)); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && Buffer.IsNull()) { - const IoHash Hash = HashField.AsAttachment(); - if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && HashField.HasError()) - { - AddError(Error, CbValidateError::InvalidPackageFormat); - } - if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (Hash != ObjectView.GetHash())) - { - AddError(Error, CbValidateError::InvalidPackageHash); - } - return Hash; + AddError(Error, CbValidateError::NullPackageAttachment); } + if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && + (IoHash::FromBLAKE3(Buffer.GetRawHash()) != IoHash::HashBuffer(Buffer.DecompressToComposite()))) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + return IoHash::FromBLAKE3(Buffer.GetRawHash()); } else { - return CbObject().GetHash(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::NullPackageAttachment); + } + return IoHash::HashBuffer(MemoryView()); } } - return {}; -} - -static IoHash -ValidateCbPackageObject(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error) -{ - CbObjectView Object = Value.AsObjectView(); - if (Value.HasError()) + else { if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) { AddError(Error, CbValidateError::InvalidPackageFormat); } } - else if (CbFieldView HashField = ValidateCbPackageField(View, Mode, Error)) + + return IoHash(); +} + +static IoHash +ValidateCbPackageObject(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error) +{ + if (IoHash RootObjectHash = Value.AsHash(); !Value.HasError() && !Value.IsAttachment()) { - const IoHash Hash = HashField.AsAttachment(); + CbFieldView RootObjectField = ValidateCbPackageField(View, Mode, Error); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) { - if (!Object) - { - AddError(Error, CbValidateError::NullPackageObject); - } - if (HashField.HasError()) + if (RootObjectField.HasError()) { AddError(Error, CbValidateError::InvalidPackageFormat); } - else if (Hash != Value.GetHash()) + } + + const CbObjectView RootObjectView = RootObjectField.AsObjectView(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + if (!RootObjectView) { - AddError(Error, CbValidateError::InvalidPackageHash); + AddError(Error, CbValidateError::NullPackageObject); } } - return Hash; + + if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (RootObjectHash != RootObjectView.GetHash())) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + + return RootObjectHash; + } + else + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } } + return IoHash(); } @@ -562,24 +595,20 @@ ValidateCompactBinaryPackage(MemoryView View, CbValidateMode Mode) uint32_t ObjectCount = 0; while (CbFieldView Value = ValidateCbPackageField(View, Mode, Error)) { - if (Value.IsBinary()) + if (Value.IsHash() && !Value.IsAttachment()) { - const IoHash Hash = ValidateCbPackageAttachment(Value, View, Mode, Error); - if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + ValidateCbPackageObject(Value, View, Mode, Error); + if (++ObjectCount > 1 && EnumHasAnyFlags(Mode, CbValidateMode::Package)) { - Attachments.push_back(Hash); - if (Value.AsBinaryView().IsEmpty()) - { - AddError(Error, CbValidateError::NullPackageAttachment); - } + AddError(Error, CbValidateError::MultiplePackageObjects); } } - else if (Value.IsObject()) + else if (Value.IsBinary() || Value.IsAttachment() || Value.IsObject()) { - ValidateCbPackageObject(Value, View, Mode, Error); - if (++ObjectCount > 1 && EnumHasAnyFlags(Mode, CbValidateMode::Package)) + const IoHash Hash = ValidateCbPackageAttachment(Value, View, Mode, Error); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) { - AddError(Error, CbValidateError::MultiplePackageObjects); + Attachments.push_back(Hash); } } else if (Value.IsNull()) diff --git a/zencore/except.cpp b/zencore/except.cpp index 00cb826f6..9bd447308 100644 --- a/zencore/except.cpp +++ b/zencore/except.cpp @@ -28,7 +28,13 @@ ThrowLastError(std::string_view Message) std::string GetLastErrorAsString() { - throw std::error_code(::GetLastError(), std::system_category()).message(); + return GetWindowsErrorAsString(::GetLastError()); +} + +std::string +GetWindowsErrorAsString(uint32_t Win32ErrorCode) +{ + return std::error_code(Win32ErrorCode, std::system_category()).message(); } void diff --git a/zencore/httpclient.cpp b/zencore/httpclient.cpp deleted file mode 100644 index 268483403..000000000 --- a/zencore/httpclient.cpp +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zencore/httpclient.h> - -#include <spdlog/spdlog.h> - -#include <doctest/doctest.h> - -namespace zen { - -TEST_CASE("httpclient") -{ - using namespace std::literals; - - SUBCASE("client") {} -} - -void -httpclient_forcelink() -{ -} - -} // namespace zen diff --git a/zencore/httpserver.cpp b/zencore/httpserver.cpp deleted file mode 100644 index e85c5ed2b..000000000 --- a/zencore/httpserver.cpp +++ /dev/null @@ -1,1641 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zencore/httpserver.h> - -#define _WINSOCKAPI_ -#include <zencore/windows.h> -#include "iothreadpool.h" - -#include <atlbase.h> -#include <conio.h> -#include <http.h> -#include <new.h> -#include <zencore/compactbinary.h> -#include <zencore/compactbinarypackage.h> -#include <zencore/iobuffer.h> -#include <zencore/logging.h> -#include <zencore/refcount.h> -#include <zencore/stream.h> -#include <zencore/string.h> -#include <zencore/thread.h> -#include <charconv> -#include <span> -#include <string_view> - -#include <spdlog/spdlog.h> - -#include <doctest/doctest.h> - -#if ZEN_PLATFORM_WINDOWS -# pragma comment(lib, "httpapi.lib") -#endif - -////////////////////////////////////////////////////////////////////////// - -std::wstring -UTF8_to_wstring(const char* in) -{ - std::wstring out; - unsigned int codepoint; - - while (*in != 0) - { - unsigned char ch = static_cast<unsigned char>(*in); - - if (ch <= 0x7f) - codepoint = ch; - else if (ch <= 0xbf) - codepoint = (codepoint << 6) | (ch & 0x3f); - else if (ch <= 0xdf) - codepoint = ch & 0x1f; - else if (ch <= 0xef) - codepoint = ch & 0x0f; - else - codepoint = ch & 0x07; - - ++in; - - if (((*in & 0xc0) != 0x80) && (codepoint <= 0x10ffff)) - { - if (sizeof(wchar_t) > 2) - { - out.append(1, static_cast<wchar_t>(codepoint)); - } - else if (codepoint > 0xffff) - { - out.append(1, static_cast<wchar_t>(0xd800 + (codepoint >> 10))); - out.append(1, static_cast<wchar_t>(0xdc00 + (codepoint & 0x03ff))); - } - else if (codepoint < 0xd800 || codepoint >= 0xe000) - { - out.append(1, static_cast<wchar_t>(codepoint)); - } - } - } - - return out; -} - -////////////////////////////////////////////////////////////////////////// - -const char* -ReasonStringForHttpResultCode(int HttpCode) -{ - switch (HttpCode) - { - // 1xx Informational - - case 100: - return "Continue"; - case 101: - return "Switching Protocols"; - - // 2xx Success - - case 200: - return "OK"; - case 201: - return "Created"; - case 202: - return "Accepted"; - case 204: - return "No Content"; - case 205: - return "Reset Content"; - case 206: - return "Partial Content"; - - // 3xx Redirection - - case 300: - return "Multiple Choices"; - case 301: - return "Moved Permanently"; - case 302: - return "Found"; - case 303: - return "See Other"; - case 304: - return "Not Modified"; - case 305: - return "Use Proxy"; - case 306: - return "Switch Proxy"; - case 307: - return "Temporary Redirect"; - case 308: - return "Permanent Redirect"; - - // 4xx Client errors - - case 400: - return "Bad Request"; - case 401: - return "Unauthorized"; - case 402: - return "Payment Required"; - case 403: - return "Forbidden"; - case 404: - return "Not Found"; - case 405: - return "Method Not Allowed"; - case 406: - return "Not Acceptable"; - case 407: - return "Proxy Authentication Required"; - case 408: - return "Request Timeout"; - case 409: - return "Conflict"; - case 410: - return "Gone"; - case 411: - return "Length Required"; - case 412: - return "Precondition Failed"; - case 413: - return "Payload Too Large"; - case 414: - return "URI Too Long"; - case 415: - return "Unsupported Media Type"; - case 416: - return "Range Not Satisifiable"; - case 417: - return "Expectation Failed"; - case 418: - return "I'm a teapot"; - case 421: - return "Misdirected Request"; - case 422: - return "Unprocessable Entity"; - case 423: - return "Locked"; - case 424: - return "Failed Dependency"; - case 425: - return "Too Early"; - case 426: - return "Upgrade Required"; - case 428: - return "Precondition Required"; - case 429: - return "Too Many Requests"; - case 431: - return "Request Header Fields Too Large"; - - // 5xx Server errors - - case 500: - return "Internal Server Error"; - case 501: - return "Not Implemented"; - case 502: - return "Bad Gateway"; - case 503: - return "Service Unavailable"; - case 504: - return "Gateway Timeout"; - case 505: - return "HTTP Version Not Supported"; - case 506: - return "Variant Also Negotiates"; - case 507: - return "Insufficient Storage"; - case 508: - return "Loop Detected"; - case 510: - return "Not Extended"; - case 511: - return "Network Authentication Required"; - - default: - return "Unknown Result"; - } -} - -namespace zen { - -using namespace std::literals; - -static const uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); -static const uint32_t HashJson = HashStringDjb2("application/json"sv); -static const uint32_t HashYaml = HashStringDjb2("text/yaml"sv); -static const uint32_t HashText = HashStringDjb2("text/plain"sv); -static const uint32_t HashCompactBinary = HashStringDjb2("application/x-ue-cb"sv); -static const uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv); - -HttpContentType -MapContentType(const std::string_view& ContentTypeString) -{ - if (!ContentTypeString.empty()) - { - const uint32_t CtHash = HashStringDjb2(ContentTypeString); - - if (CtHash == HashBinary) - { - return HttpContentType::kBinary; - } - else if (CtHash == HashCompactBinary) - { - return HttpContentType::kCbObject; - } - else if (CtHash == HashCompactBinaryPackage) - { - return HttpContentType::kCbPackage; - } - else if (CtHash == HashJson) - { - return HttpContentType::kJSON; - } - else if (CtHash == HashYaml) - { - return HttpContentType::kYAML; - } - else if (CtHash == HashText) - { - return HttpContentType::kText; - } - } - - return HttpContentType::kUnknownContentType; -} - -////////////////////////////////////////////////////////////////////////// - -HttpServerRequest::HttpServerRequest() -{ -} - -HttpServerRequest::~HttpServerRequest() -{ -} - -struct CbPackageHeader -{ - uint32_t HeaderMagic; - uint32_t AttachmentCount; - uint32_t Reserved1; - uint32_t Reserved2; -}; - -static constinit uint32_t kCbPkgMagic = 0xaa77aacc; - -struct CbAttachmentEntry -{ - uint64_t AttachmentSize; - uint32_t Reserved1; - IoHash AttachmentHash; -}; - -void -HttpServerRequest::WriteResponse(HttpResponse HttpResponseCode, CbPackage Data) -{ - const std::span<const CbAttachment>& Attachments = Data.GetAttachments(); - - std::vector<IoBuffer> ResponseBuffers; - ResponseBuffers.reserve(3 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each - // attachment is likely to consist of several buffers - - uint64_t TotalAttachmentsSize = 0; - - // Fixed size header - - CbPackageHeader Hdr{.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())}; - - ResponseBuffers.push_back(IoBufferBuilder::MakeCloneFromMemory(&Hdr, sizeof Hdr)); - - // Attachment metadata array - - IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)}; - - CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(AttachmentMetadataBuffer.MutableData()); - - ResponseBuffers.push_back(AttachmentMetadataBuffer); // Attachment metadata - - // Root object - - IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer(); - ResponseBuffers.push_back(RootIoBuffer); // Root object - - *AttachmentInfo++ = {.AttachmentSize = RootIoBuffer.Size(), .AttachmentHash = Data.GetObjectHash()}; - - // Attachment payloads - - for (const CbAttachment& Attachment : Attachments) - { - CompressedBuffer AttachmentBuffer = Attachment.AsCompressedBinary(); - CompositeBuffer Compressed = AttachmentBuffer.GetCompressed(); - - *AttachmentInfo++ = {.AttachmentSize = AttachmentBuffer.GetCompressedSize(), - .AttachmentHash = IoHash::FromBLAKE3(AttachmentBuffer.GetRawHash())}; - - for (const SharedBuffer& Segment : Compressed.GetSegments()) - { - ResponseBuffers.push_back(Segment.AsIoBuffer()); - TotalAttachmentsSize += Segment.GetSize(); - } - } - - return WriteResponse(HttpResponseCode, HttpContentType::kCbPackage, ResponseBuffers); -} - -void -HttpServerRequest::WriteResponse(HttpResponse HttpResponseCode, CbObject Data) -{ - SharedBuffer Buf = Data.GetBuffer(); - std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; - return WriteResponse(HttpResponseCode, HttpContentType::kCbObject, Buffers); -} - -void -HttpServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::string_view ResponseString) -{ - return WriteResponse(HttpResponseCode, ContentType, std::u8string_view{(char8_t*)ResponseString.data(), ResponseString.size()}); -} - -void -HttpServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, IoBuffer Blob) -{ - std::array<IoBuffer, 1> Buffers{Blob}; - return WriteResponse(HttpResponseCode, ContentType, Buffers); -} - -HttpServerRequest::QueryParams -HttpServerRequest::GetQueryParams() -{ - QueryParams Params; - - const std::string_view QStr = QueryString(); - - const char* QueryIt = QStr.data(); - const char* QueryEnd = QueryIt + QStr.size(); - - while (QueryIt != QueryEnd) - { - if (*QueryIt == '&') - { - ++QueryIt; - continue; - } - - const std::string_view Query{QueryIt, QueryEnd}; - - size_t DelimIndex = Query.find('&', 0); - - if (DelimIndex == std::string_view::npos) - { - DelimIndex = Query.size(); - } - - std::string_view ThisQuery{QueryIt, DelimIndex}; - - size_t EqIndex = ThisQuery.find('=', 0); - - if (EqIndex != std::string_view::npos) - { - std::string_view Parm{ThisQuery.data(), EqIndex}; - ThisQuery.remove_prefix(EqIndex + 1); - - Params.KvPairs.emplace_back(Parm, ThisQuery); - } - - QueryIt += DelimIndex; - } - - return Params; -} - -CbObject -HttpServerRequest::ReadPayloadObject() -{ - IoBuffer Payload = ReadPayload(); - - if (Payload) - { - return LoadCompactBinaryObject(std::move(Payload)); - } - else - { - return {}; - } -} - -CbPackage -HttpServerRequest::ReadPayloadPackage() -{ - // TODO: this should not read into a contiguous buffer! - - IoBuffer Payload = ReadPayload(); - MemoryInStream InStream(Payload); - BinaryReader Reader(InStream); - - if (!Payload) - { - return {}; - } - - CbPackage Package; - - CbPackageHeader Hdr; - Reader.Read(&Hdr, sizeof Hdr); - - if (Hdr.HeaderMagic != kCbPkgMagic) - { - // report error - return {}; - } - - uint32_t ChunkCount = Hdr.AttachmentCount + 1; - - std::unique_ptr<CbAttachmentEntry[]> AttachmentEntries{new CbAttachmentEntry[ChunkCount]}; - - Reader.Read(AttachmentEntries.get(), sizeof(CbAttachmentEntry) * ChunkCount); - - for (uint32_t i = 0; i < ChunkCount; ++i) - { - const uint64_t AttachmentSize = AttachmentEntries[i].AttachmentSize; - IoBuffer AttachmentBuffer{AttachmentSize}; - Reader.Read(AttachmentBuffer.MutableData(), AttachmentSize); - CompressedBuffer CompBuf(CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer))); - - if (i == 0) - { - Package.SetObject(LoadCompactBinaryObject(CompBuf)); - } - else - { - CbAttachment Attachment(CompBuf); - Package.AddAttachment(Attachment); - } - } - - return Package; -} - -////////////////////////////////////////////////////////////////////////// -// -// http.sys implementation -// - -#if ZEN_PLATFORM_WINDOWS -class HttpSysServer; -class HttpTransaction; - -class HttpSysRequestHandler -{ -public: - HttpSysRequestHandler(HttpTransaction& InRequest) : m_Request(InRequest) {} - virtual ~HttpSysRequestHandler() = default; - - virtual void IssueRequest() = 0; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; - - HttpTransaction& Transaction() { return m_Request; } - -private: - HttpTransaction& m_Request; // Outermost HTTP transaction object -}; - -/** HTTP transaction - - There will be an instance of this per pending and in-flight HTTP transaction - - */ -class HttpTransaction -{ -public: - HttpTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_HttpHandler(&m_InitialHttpHandler) {} - - virtual ~HttpTransaction() {} - - enum class Status - { - kDone, - kRequestPending - }; - - Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); - - static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, - PVOID pContext /* HttpSysServer */, - PVOID pOverlapped, - ULONG IoResult, - ULONG_PTR NumberOfBytesTransferred, - PTP_IO Io) - { - UNREFERENCED_PARAMETER(Io); - UNREFERENCED_PARAMETER(Instance); - UNREFERENCED_PARAMETER(pContext); - - // Note that for a given transaction we may be in this completion function on more - // than one thread at any given moment. This means we need to be careful about what - // happens in here - - HttpTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpTransaction, m_HttpOverlapped); - - if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpTransaction::Status::kDone) - { - delete Transaction; - } - } - - void IssueInitialRequest(); - - PTP_IO Iocp(); - HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } - inline HttpSysServer& Server() { return m_HttpServer; } - - inline PHTTP_REQUEST HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } - -protected: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; - HttpSysRequestHandler* m_HttpHandler{nullptr}; - RwLock m_Lock; - -private: - struct InitialRequestHandler : public HttpSysRequestHandler - { - inline PHTTP_REQUEST HttpRequest() { return (PHTTP_REQUEST)m_RequestBuffer; } - inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } - - InitialRequestHandler(HttpTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} - ~InitialRequestHandler() {} - - virtual void IssueRequest() override; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - - PHTTP_REQUEST m_HttpRequestPtr = (HTTP_REQUEST*)(m_RequestBuffer); - UCHAR m_RequestBuffer[16384 + sizeof(HTTP_REQUEST)]; - } m_InitialHttpHandler{*this}; -}; - -////////////////////////////////////////////////////////////////////////// - -class HttpMessageResponseRequest : public HttpSysRequestHandler -{ -public: - HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode); - HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, const char* Message); - HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, const void* Payload, size_t PayloadSize); - HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs); - ~HttpMessageResponseRequest(); - - virtual void IssueRequest() override; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - - void SuppressResponseBody(); - -private: - std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks; - uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes - - uint16_t m_HttpResponseCode = 0; - uint32_t m_NextDataChunkOffset = 0; // This is used for responses where the number of chunks exceed the maximum number for one API call - uint32_t m_RemainingChunkCount = 0; - bool m_IsInitialResponse = true; - - void Initialize(uint16_t ResponseCode, std::span<IoBuffer> Blobs); - - std::vector<IoBuffer> m_DataBuffers; -}; - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode) : HttpSysRequestHandler(InRequest) -{ - std::array<IoBuffer, 0> buffers; - - Initialize(ResponseCode, buffers); -} - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, const char* Message) -: HttpSysRequestHandler(InRequest) -{ - IoBuffer MessageBuffer(IoBuffer::Wrap, Message, strlen(Message)); - std::array<IoBuffer, 1> buffers({MessageBuffer}); - - Initialize(ResponseCode, buffers); -} - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InRequest, - uint16_t ResponseCode, - const void* Payload, - size_t PayloadSize) -: HttpSysRequestHandler(InRequest) -{ - IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); - std::array<IoBuffer, 1> buffers({MessageBuffer}); - - Initialize(ResponseCode, buffers); -} - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs) -: HttpSysRequestHandler(InRequest) -{ - Initialize(ResponseCode, Blobs); -} - -HttpMessageResponseRequest::~HttpMessageResponseRequest() -{ -} - -void -HttpMessageResponseRequest::Initialize(uint16_t ResponseCode, std::span<IoBuffer> Blobs) -{ - m_HttpResponseCode = ResponseCode; - - const uint32_t ChunkCount = (uint32_t)Blobs.size(); - - m_HttpDataChunks.resize(ChunkCount); - m_DataBuffers.reserve(ChunkCount); - - for (IoBuffer& Buffer : Blobs) - { - m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); - } - - // Initialize the full array up front - - uint64_t LocalDataSize = 0; - - { - PHTTP_DATA_CHUNK ChunkPtr = m_HttpDataChunks.data(); - - for (IoBuffer& Buffer : m_DataBuffers) - { - const ULONG BufferDataSize = (ULONG)Buffer.Size(); - - ZEN_ASSERT(BufferDataSize); - - IoBufferFileReference FileRef; - if (Buffer.GetFileReference(/* out */ FileRef)) - { - ChunkPtr->DataChunkType = HttpDataChunkFromFileHandle; - ChunkPtr->FromFileHandle.FileHandle = FileRef.FileHandle; - ChunkPtr->FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; - ChunkPtr->FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; - } - else - { - ChunkPtr->DataChunkType = HttpDataChunkFromMemory; - ChunkPtr->FromMemory.pBuffer = (void*)Buffer.Data(); - ChunkPtr->FromMemory.BufferLength = BufferDataSize; - } - ++ChunkPtr; - - LocalDataSize += BufferDataSize; - } - } - - m_RemainingChunkCount = ChunkCount; - m_TotalDataSize = LocalDataSize; -} - -void -HttpMessageResponseRequest::SuppressResponseBody() -{ - m_RemainingChunkCount = 0; - m_HttpDataChunks.clear(); - m_DataBuffers.clear(); -} - -HttpSysRequestHandler* -HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - ZEN_UNUSED(NumberOfBytesTransferred); - ZEN_UNUSED(IoResult); - - if (m_RemainingChunkCount == 0) - return nullptr; // All done - - return this; -} - -void -HttpMessageResponseRequest::IssueRequest() -{ - HttpTransaction& Tx = Transaction(); - HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); - PTP_IO const Iocp = Tx.Iocp(); - - StartThreadpoolIo(Iocp); - - // Split payload into batches to play well with the underlying API - - const int MaxChunksPerCall = 9999; - - const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall); - const int ThisRequestChunkOffset = m_NextDataChunkOffset; - - m_RemainingChunkCount -= ThisRequestChunkCount; - m_NextDataChunkOffset += ThisRequestChunkCount; - - ULONG SendFlags = 0; - - if (m_RemainingChunkCount) - { - // We need to make more calls to send the full amount of data - SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; - } - - ULONG SendResult = 0; - - if (m_IsInitialResponse) - { - // Populate response structure - - HTTP_RESPONSE HttpResponse = {}; - - HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); - HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; - - // Content-length header - - char ContentLengthString[32]; - _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); - - PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; - ContentLengthHeader->pRawValue = ContentLengthString; - ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); - - // Content-type header - - PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; - - ContentTypeHeader->pRawValue = "application/octet-stream"; /* TODO! We must respect the content type specified */ - ContentTypeHeader->RawValueLength = (USHORT)strlen(ContentTypeHeader->pRawValue); - - HttpResponse.StatusCode = m_HttpResponseCode; - HttpResponse.pReason = ReasonStringForHttpResultCode(m_HttpResponseCode); - HttpResponse.ReasonLength = (USHORT)strlen(HttpResponse.pReason); - - // Cache policy - - HTTP_CACHE_POLICY CachePolicy; - - CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; - CachePolicy.SecondsToLive = 0; - - // Initial response API call - - SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - &HttpResponse, - &CachePolicy, - NULL, - NULL, - 0, - Tx.Overlapped(), - NULL); - - m_IsInitialResponse = false; - } - else - { - // Subsequent response API calls - - SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - (USHORT)ThisRequestChunkCount, // EntityChunkCount - &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks - NULL, // BytesSent - NULL, // Reserved1 - 0, // Reserved2 - Tx.Overlapped(), // Overlapped - NULL // LogData - ); - } - - if ((SendResult != NO_ERROR) // Synchronous completion, but the completion event will still be posted to IOCP - && (SendResult != ERROR_IO_PENDING) // Asynchronous completion - ) - { - // Some error occurred, no completion will be posted - - CancelThreadpoolIo(Iocp); - - spdlog::error("failed to send HTTP response (error: {}) URL: {}", SendResult, HttpReq->pRawUrl); - - throw HttpServerException("Failed to send HTTP response", SendResult); - } -} - -////////////////////////////////////////////////////////////////////////// - -class HttpSysServer -{ - friend class HttpTransaction; - -public: - HttpSysServer(WinIoThreadPool& InThreadPool); - ~HttpSysServer(); - - void Initialize(const wchar_t* UrlPath); - void Run(bool TestMode); - - void RequestExit() { m_ShutdownEvent.Set(); } - - void StartServer(); - void StopServer(); - - void OnHandlingRequest(); - void IssueNewRequestMaybe(); - - inline bool IsOk() const { return m_IsOk; } - - void AddEndpoint(const char* Endpoint, HttpService& Service); - void RemoveEndpoint(const char* Endpoint, HttpService& Service); - -private: - bool m_IsOk = false; - bool m_IsHttpInitialized = false; - WinIoThreadPool& m_ThreadPool; - - std::wstring m_BaseUri; // http://*:nnnn/ - HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; - HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; - HANDLE m_RequestQueueHandle = 0; - std::atomic_int32_t m_PendingRequests{0}; - int32_t m_MinPendingRequests = 4; - int32_t m_MaxPendingRequests = 32; - Event m_ShutdownEvent; -}; - -HttpSysServer::HttpSysServer(WinIoThreadPool& InThreadPool) : m_ThreadPool(InThreadPool) -{ - ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); - - if (Result != NO_ERROR) - { - return; - } - - m_IsHttpInitialized = true; - m_IsOk = true; -} - -HttpSysServer::~HttpSysServer() -{ - if (m_IsHttpInitialized) - { - HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); - } -} - -void -HttpSysServer::Initialize(const wchar_t* UrlPath) -{ - // check(bIsOk); - - ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); - - if (Result != NO_ERROR) - { - // Flag error - - return; - } - - Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); - - if (Result != NO_ERROR) - { - // Flag error - - return; - } - - m_BaseUri = UrlPath; - - Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, UrlPath, /* #TODO UrlContext */ HTTP_URL_CONTEXT(0), 0); - - if (Result != NO_ERROR) - { - // Flag error - - return; - } - - HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; - - Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, NULL, NULL, 0, &m_RequestQueueHandle); - - if (Result != NO_ERROR) - { - // Flag error! - - return; - } - - HttpBindingInfo.Flags.Present = 1; - HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; - - Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); - - if (Result != NO_ERROR) - { - // Flag error! - - return; - } - - // Create I/O completion port - - m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpTransaction::IoCompletionCallback, this); - - // Check result! -} - -void -HttpSysServer::StartServer() -{ - int RequestCount = 32; - - for (int i = 0; i < RequestCount; ++i) - { - IssueNewRequestMaybe(); - } -} - -void -HttpSysServer::Run(bool TestMode) -{ - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); - } - - do - { - int WaitTimeout = -1; - - if (!TestMode) - { - WaitTimeout = 1000; - } - - if (!TestMode && _kbhit() != 0) - { - char c = (char)_getch(); - - if (c == 27 || c == 'Q' || c == 'q') - { - RequestApplicationExit(0); - } - } - - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -} - -void -HttpSysServer::OnHandlingRequest() -{ - --m_PendingRequests; - - if (m_PendingRequests > m_MinPendingRequests) - { - // We have more than the minimum number of requests pending, just let someone else - // enqueue new requests - return; - } - - IssueNewRequestMaybe(); -} - -void -HttpSysServer::IssueNewRequestMaybe() -{ - if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) - { - return; - } - - std::unique_ptr<HttpTransaction> Request = std::make_unique<HttpTransaction>(*this); - - Request->IssueInitialRequest(); - - // This may end up exceeding the MaxPendingRequests limit, but it's not - // really a problem. I'm doing it this way mostly to avoid dealing with - // exceptions here - ++m_PendingRequests; - - Request.release(); -} - -void -HttpSysServer::StopServer() -{ -} - -void -HttpSysServer::AddEndpoint(const char* UrlPath, HttpService& Service) -{ - if (UrlPath[0] == '/') - { - ++UrlPath; - } - - const std::wstring Path16 = UTF8_to_wstring(UrlPath); - Service.SetUriPrefixLength(Path16.size() + 1 /* leading slash */); - - // Convert to wide string - - std::wstring Url16 = m_BaseUri + Path16; - - ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); - - if (Result != NO_ERROR) - { - spdlog::error("HttpAddUrlToUrlGroup failed with result {}", Result); - - return; - } -} - -void -HttpSysServer::RemoveEndpoint(const char* UrlPath, HttpService& Service) -{ - ZEN_UNUSED(Service); - - if (UrlPath[0] == '/') - { - ++UrlPath; - } - - const std::wstring Path16 = UTF8_to_wstring(UrlPath); - - // Convert to wide string - - std::wstring Url16 = m_BaseUri + Path16; - - ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); - - if (Result != NO_ERROR) - { - spdlog::error("HttpRemoveUrlFromUrlGroup failed with result {}", Result); - } -} - -////////////////////////////////////////////////////////////////////////// - -class HttpSysServerRequest : public HttpServerRequest -{ -public: - HttpSysServerRequest(HttpTransaction& Tx, HttpService& Service) : m_HttpTx(Tx) - { - PHTTP_REQUEST HttpRequestPtr = Tx.HttpRequest(); - - const int PrefixLength = Service.UriPrefixLength(); - const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t); - - if (AbsPathLength >= PrefixLength) - { - // We convert the URI immediately because most of the code involved prefers to deal - // with utf8. This has some performance impact which I'd prefer to avoid but for now - // we just have to live with it - - WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, - m_Uri); - } - else - { - m_Uri.Reset(); - } - - if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) - { - --QueryStringLength; - - WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryString); - } - else - { - m_QueryString.Reset(); - } - - switch (HttpRequestPtr->Verb) - { - case HttpVerbOPTIONS: - m_Verb = HttpVerb::kOptions; - break; - - case HttpVerbGET: - m_Verb = HttpVerb::kGet; - break; - - case HttpVerbHEAD: - m_Verb = HttpVerb::kHead; - break; - - case HttpVerbPOST: - m_Verb = HttpVerb::kPost; - break; - - case HttpVerbPUT: - m_Verb = HttpVerb::kPut; - break; - - case HttpVerbDELETE: - m_Verb = HttpVerb::kDelete; - break; - - case HttpVerbCOPY: - m_Verb = HttpVerb::kCopy; - break; - - default: - // TODO: invalid request? - m_Verb = (HttpVerb)0; - break; - } - - const HTTP_KNOWN_HEADER& clh = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentLength]; - std::string_view cl(clh.pRawValue, clh.RawValueLength); - std::from_chars(cl.data(), cl.data() + cl.size(), m_ContentLength); - - const HTTP_KNOWN_HEADER& CtHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentType]; - m_ContentType = MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); - - const HTTP_KNOWN_HEADER& AcceptHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAccept]; - m_AcceptType = MapContentType({AcceptHdr.pRawValue, AcceptHdr.RawValueLength}); - } - - ~HttpSysServerRequest() {} - - virtual IoBuffer ReadPayload() override - { - // This is presently synchronous for simplicity, but we - // need to implement an asynchronous version also - - HTTP_REQUEST* const HttpReq = m_HttpTx.HttpRequest(); - - IoBuffer PayloadBuffer(m_ContentLength); - - HttpContentType ContentType = RequestContentType(); - PayloadBuffer.SetContentType(ContentType); - - uint64_t BytesToRead = m_ContentLength; - - uint8_t* ReadPointer = reinterpret_cast<uint8_t*>(PayloadBuffer.MutableData()); - - // First deal with any payload which has already been copied - // into our request buffer - - const int EntityChunkCount = HttpReq->EntityChunkCount; - - for (int i = 0; i < EntityChunkCount; ++i) - { - HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; - - ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); - - const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; - - ZEN_ASSERT(BufferLength <= BytesToRead); - - memcpy(ReadPointer, EntityChunk.FromMemory.pBuffer, BufferLength); - - ReadPointer += BufferLength; - BytesToRead -= BufferLength; - } - - // Call http.sys API to receive the remaining data - - static const uint64_t kMaxBytesPerApiCall = 1 * 1024 * 1024; - - while (BytesToRead) - { - ULONG BytesRead = 0; - - const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); - - ULONG ApiResult = HttpReceiveRequestEntityBody(m_HttpTx.RequestQueueHandle(), - HttpReq->RequestId, - 0, /* Flags */ - ReadPointer, - gsl::narrow<ULONG>(BytesToReadThisCall), - &BytesRead, - NULL /* Overlapped */ - ); - - if (ApiResult != NO_ERROR && ApiResult != ERROR_HANDLE_EOF) - { - throw HttpServerException("payload read failed", ApiResult); - } - - BytesToRead -= BytesRead; - ReadPointer += BytesRead; - } - - PayloadBuffer.MakeImmutable(); - - return PayloadBuffer; - } - - virtual void WriteResponse(HttpResponse HttpResponseCode) override - { - ZEN_ASSERT(m_IsHandled == false); - - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode); - - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } - - m_IsHandled = true; - } - - virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override - { - ZEN_ASSERT(m_IsHandled == false); - ZEN_UNUSED(ContentType); - - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, Blobs); - - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } - - m_IsHandled = true; - } - - virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override - { - ZEN_ASSERT(m_IsHandled == false); - ZEN_UNUSED(ContentType); - - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ResponseString.data(), ResponseString.size()); - - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } - - m_IsHandled = true; - } - - HttpTransaction& m_HttpTx; - HttpMessageResponseRequest* m_Response = nullptr; -}; - -////////////////////////////////////////////////////////////////////////// - -PTP_IO -HttpTransaction::Iocp() -{ - return m_HttpServer.m_ThreadPool.Iocp(); -} - -HANDLE -HttpTransaction::RequestQueueHandle() -{ - return m_HttpServer.m_RequestQueueHandle; -} - -void -HttpTransaction::IssueInitialRequest() -{ - m_InitialHttpHandler.IssueRequest(); -} - -HttpTransaction::Status -HttpTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - // We use this to ensure sequential execution of completion handlers - // for any given transaction. - RwLock::ExclusiveLockScope _(m_Lock); - - bool RequestPending = false; - - if (HttpSysRequestHandler* CurrentHandler = m_HttpHandler) - { - const bool IsInitialRequest = (CurrentHandler == &m_InitialHttpHandler); - - if (IsInitialRequest) - { - // Ensure we have a sufficient number of pending requests outstanding - m_HttpServer.OnHandlingRequest(); - } - - m_HttpHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); - - if (m_HttpHandler) - { - try - { - m_HttpHandler->IssueRequest(); - - RequestPending = true; - } - catch (std::exception& Ex) - { - spdlog::error("exception caught from IssueRequest(): {}", Ex.what()); - - // something went wrong, no request is pending - } - } - else - { - if (IsInitialRequest == false) - { - delete CurrentHandler; - } - } - } - - m_HttpServer.IssueNewRequestMaybe(); - - if (RequestPending) - { - return Status::kRequestPending; - } - - return Status::kDone; -} - -////////////////////////////////////////////////////////////////////////// - -void -HttpTransaction::InitialRequestHandler::IssueRequest() -{ - PTP_IO Iocp = Transaction().Iocp(); - - StartThreadpoolIo(Iocp); - - HttpTransaction& Tx = Transaction(); - - HTTP_REQUEST* HttpReq = Tx.HttpRequest(); - - ULONG Result = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), - HTTP_NULL_ID, - HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, - HttpReq, - RequestBufferSize(), - NULL, - Tx.Overlapped()); - - if (Result != ERROR_IO_PENDING && Result != NO_ERROR) - { - CancelThreadpoolIo(Iocp); - - if (Result == ERROR_MORE_DATA) - { - // ProcessReceiveAndPostResponse(pIoRequest, pServerContext->Io, ERROR_MORE_DATA); - } - - // CleanupHttpIoRequest(pIoRequest); - - fprintf(stderr, "HttpReceiveHttpRequest failed, error 0x%lx\n", Result); - - return; - } -} - -HttpSysRequestHandler* -HttpTransaction::InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - ZEN_UNUSED(IoResult); - ZEN_UNUSED(NumberOfBytesTransferred); - - // Route requests - - try - { - if (HttpService* Service = reinterpret_cast<HttpService*>(m_HttpRequestPtr->UrlContext)) - { - HttpSysServerRequest ThisRequest(Transaction(), *Service); - - Service->HandleRequest(ThisRequest); - - if (!ThisRequest.IsHandled()) - { - return new HttpMessageResponseRequest(Transaction(), 404, "Not found"); - } - - if (ThisRequest.m_Response) - { - return ThisRequest.m_Response; - } - } - - // Unable to route - return new HttpMessageResponseRequest(Transaction(), 404, "Item unknown"); - } - catch (std::exception& ex) - { - // TODO provide more meaningful error output - - return new HttpMessageResponseRequest(Transaction(), 500, ex.what()); - } -} -#endif // ZEN_PLATFORM_WINDOWS - -////////////////////////////////////////////////////////////////////////// - -struct HttpServer::Impl : public RefCounted -{ - WinIoThreadPool m_ThreadPool; - HttpSysServer m_HttpServer; - - Impl(int ThreadCount) : m_ThreadPool(ThreadCount), m_HttpServer(m_ThreadPool) {} - - void Initialize(int BasePort) - { - using namespace std::literals; - - WideStringBuilder<64> BaseUri; - BaseUri << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; - - m_HttpServer.Initialize(BaseUri.c_str()); - m_HttpServer.StartServer(); - } - - void Run(bool TestMode) { m_HttpServer.Run(TestMode); } - - void RequestExit() { m_HttpServer.RequestExit(); } - - void Cleanup() { m_HttpServer.StopServer(); } - - void AddEndpoint(const char* Endpoint, HttpService& Service) { m_HttpServer.AddEndpoint(Endpoint, Service); } - - void AddEndpoint([[maybe_unused]] const char* endpoint, [[maybe_unused]] std::function<void(HttpServerRequest&)> handler) - { - ZEN_NOT_IMPLEMENTED(); - } -}; - -HttpServer::HttpServer() -{ - m_Impl = new Impl(32); -} - -HttpServer::~HttpServer() -{ - m_Impl->Cleanup(); -} - -void -HttpServer::AddEndpoint(HttpService& Service) -{ - m_Impl->AddEndpoint(Service.BaseUri(), Service); -} - -void -HttpServer::AddEndpoint(const char* endpoint, std::function<void(HttpServerRequest&)> handler) -{ - m_Impl->AddEndpoint(endpoint, handler); -} - -void -HttpServer::Initialize(int BasePort) -{ - m_Impl->Initialize(BasePort); -} - -void -HttpServer::Run(bool TestMode) -{ - m_Impl->Run(TestMode); -} - -void -HttpServer::RequestExit() -{ - m_Impl->RequestExit(); -} - -////////////////////////////////////////////////////////////////////////// - -HttpServerException::HttpServerException(const char* Message, uint32_t Error) : m_ErrorCode(Error) -{ - using namespace fmt::literals; - - m_Message = "{} (HTTP error {})"_format(Message, m_ErrorCode); -} - -const char* -HttpServerException::what() const noexcept -{ - return m_Message.c_str(); -} - -////////////////////////////////////////////////////////////////////////// - -void -HttpRequestRouter::AddPattern(const char* Id, const char* Regex) -{ - ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end()); - - m_PatternMap.insert({Id, Regex}); -} - -void -HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs) -{ - // Expand patterns - - ExtendableStringBuilder<128> ExpandedRegex; - - size_t RegexLen = strlen(Regex); - - for (size_t i = 0; i < RegexLen;) - { - bool matched = false; - - if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\'))) - { - // Might have a pattern reference - find closing brace - - for (size_t j = i + 1; j < RegexLen; ++j) - { - if (Regex[j] == '}') - { - std::string Pattern(&Regex[i + 1], j - i - 1); - - if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) - { - ExpandedRegex.Append(it->second.c_str()); - } - else - { - // Default to anything goes (or should this just be an error?) - - ExpandedRegex.Append("(.+?)"); - } - - // skip ahead - i = j + 1; - - matched = true; - - break; - } - } - } - - if (!matched) - { - ExpandedRegex.Append(Regex[i++]); - } - } - - m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex); -} - -bool -HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) -{ - const HttpVerb Verb = Request.RequestVerb(); - - std::string_view Uri = Request.RelativeUri(); - HttpRouterRequest RouterRequest(Request); - - for (const auto& Handler : m_Handlers) - { - if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx)) - { - Handler.Handler(RouterRequest); - - return true; // Route matched - } - } - - return false; // No route matched -} - -TEST_CASE("http") -{ - using namespace std::literals; - - SUBCASE("router") - { - HttpRequestRouter r; - r.AddPattern("a", "[[:alpha:]]+"); - r.RegisterRoute( - "{a}", - [&](auto) {}, - HttpVerb::kGet); - - // struct TestHttpServerRequest : public HttpServerRequest - //{ - // TestHttpServerRequest(std::string_view Uri) : m_uri{Uri} {} - //}; - - // TestHttpServerRequest req{}; - // r.HandleRequest(req); - } -} - -void -http_forcelink() -{ -} - -} // namespace zen diff --git a/zencore/include/zencore/compactbinary.h b/zencore/include/zencore/compactbinary.h index b214802bf..4fce129ea 100644 --- a/zencore/include/zencore/compactbinary.h +++ b/zencore/include/zencore/compactbinary.h @@ -1098,15 +1098,15 @@ public: /** Access the field as an object. Defaults to an empty object on error. */ inline CbObject AsObject() &; - - /** Access the field as an object. Defaults to an empty object on error. */ inline CbObject AsObject() &&; /** Access the field as an array. Defaults to an empty array on error. */ inline CbArray AsArray() &; - - /** Access the field as an array. Defaults to an empty array on error. */ inline CbArray AsArray() &&; + + /** Access the field as binary. Returns the provided default on error. */ + inline SharedBuffer AsBinary(const SharedBuffer& Default = SharedBuffer()) &; + inline SharedBuffer AsBinary(const SharedBuffer& Default = SharedBuffer()) &&; }; /** @@ -1268,6 +1268,20 @@ CbField::AsArray() && return IsArray() ? CbArray(AsArrayView(), std::move(*this)) : CbArray(); } +inline SharedBuffer +CbField::AsBinary(const SharedBuffer& Default) & +{ + const MemoryView View = AsBinaryView(); + return !HasError() ? SharedBuffer::MakeView(View, GetOuterBuffer()) : Default; +} + +inline SharedBuffer +CbField::AsBinary(const SharedBuffer& Default) && +{ + const MemoryView View = AsBinaryView(); + return !HasError() ? SharedBuffer::MakeView(View, std::move(*this).GetOuterBuffer()) : Default; +} + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// /** diff --git a/zencore/include/zencore/compactbinarypackage.h b/zencore/include/zencore/compactbinarypackage.h index d60155d1a..e31bc4bfd 100644 --- a/zencore/include/zencore/compactbinarypackage.h +++ b/zencore/include/zencore/compactbinarypackage.h @@ -38,18 +38,27 @@ public: CbAttachment() = default; /** Construct a compact binary attachment. Value is cloned if not owned. */ - inline explicit CbAttachment(const CbObject& Value) : CbAttachment(Value, nullptr) {} + inline explicit CbAttachment(const CbObject& InValue) : CbAttachment(InValue, nullptr) {} /** Construct a compact binary attachment. Value is cloned if not owned. Hash must match Value. */ - inline explicit CbAttachment(const CbObject& Value, const IoHash& Hash) : CbAttachment(Value, &Hash) {} + inline explicit CbAttachment(const CbObject& InValue, const IoHash& Hash) : CbAttachment(InValue, &Hash) {} - /** Construct a binary attachment. Value is cloned if not owned. */ - ZENCORE_API explicit CbAttachment(const SharedBuffer& Value); + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue); - /** Construct a binary attachment. Value is cloned if not owned. Hash must match Value. */ - ZENCORE_API explicit CbAttachment(const SharedBuffer& Value, const IoHash& Hash); + /** Construct a raw binary attachment. Value is cloned if not owned. Hash must match Value. */ + ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue, const IoHash& Hash); - /** Construct a binary attachment. Value is cloned if not owned. */ + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(const CompositeBuffer& InValue); + + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue); + + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue, const IoHash& Hash); + + /** Construct a compressed binary attachment. Value is cloned if not owned. */ ZENCORE_API explicit CbAttachment(const CompressedBuffer& InValue); ZENCORE_API explicit CbAttachment(CompressedBuffer&& InValue); @@ -66,13 +75,19 @@ public: ZENCORE_API [[nodiscard]] SharedBuffer AsBinary() const; /** Access the attachment as compressed binary. Defaults to a null buffer if the attachment is null. */ + ZENCORE_API [[nodiscard]] CompositeBuffer AsCompositeBinary() const; + + /** Access the attachment as compressed binary. Defaults to a null buffer if the attachment is null. */ ZENCORE_API [[nodiscard]] CompressedBuffer AsCompressedBinary() const; /** Access the attachment as compact binary. Defaults to a field iterator with no value on error. */ ZENCORE_API [[nodiscard]] CbObject AsObject() const; - /** Returns true if the attachment is either binary or an object */ - [[nodiscard]] inline bool IsBinary() const { return !IsNull(); } + /** Returns true if the attachment is binary */ + ZENCORE_API [[nodiscard]] bool IsBinary() const; + + /** Returns true if the attachment is compressed binary */ + ZENCORE_API [[nodiscard]] bool IsCompressedBinary() const; /** Returns whether the attachment is an object. */ ZENCORE_API [[nodiscard]] bool IsObject() const; @@ -122,7 +137,19 @@ private: CbObjectValue(CbObject&& InObject, const IoHash& InHash) : Object(std::move(InObject)), Hash(InHash) {} }; - std::variant<CompressedBuffer, CbObjectValue> Value; + struct BinaryValue + { + CompositeBuffer Buffer; + IoHash Hash; + + BinaryValue(const CompositeBuffer& InBuffer) : Buffer(InBuffer.MakeOwned()), Hash(IoHash::HashBuffer(InBuffer)) {} + BinaryValue(const CompositeBuffer& InBuffer, const IoHash& InHash) : Buffer(InBuffer.MakeOwned()), Hash(InHash) {} + BinaryValue(CompositeBuffer&& InBuffer) : Buffer(std::move(InBuffer)), Hash(IoHash::HashBuffer(Buffer)) {} + BinaryValue(CompositeBuffer&& InBuffer, const IoHash& InHash) : Buffer(std::move(InBuffer)), Hash(InHash) {} + BinaryValue(SharedBuffer&& InBuffer, const IoHash& InHash) : Buffer(std::move(InBuffer)), Hash(InHash) {} + }; + + std::variant<nullptr_t, CbObjectValue, BinaryValue, CompressedBuffer> Value; }; /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -290,9 +317,7 @@ public: * The iterator is advanced as object and attachment fields are consumed from it. */ ZENCORE_API bool TryLoad(CbFieldIterator& Fields); - - ZENCORE_API bool TryLoad(IoBuffer& Buffer, BufferAllocator Allocator = UniqueBuffer::Alloc, AttachmentResolver* Mapper = nullptr); - + ZENCORE_API bool TryLoad(IoBuffer Buffer, BufferAllocator Allocator = UniqueBuffer::Alloc, AttachmentResolver* Mapper = nullptr); ZENCORE_API bool TryLoad(BinaryReader& Reader, BufferAllocator Allocator = UniqueBuffer::Alloc, AttachmentResolver* Mapper = nullptr); /** Save the object and attachments into the writer as a stream of compact binary fields. */ @@ -313,6 +338,17 @@ private: IoHash ObjectHash; }; +namespace legacy { + void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer); + void SaveCbPackage(const CbPackage& Package, CbWriter& Writer); + void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar); + bool TryLoadCbPackage(CbPackage& Package, IoBuffer Buffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper = nullptr); + bool TryLoadCbPackage(CbPackage& Package, + BinaryReader& Reader, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper = nullptr); +} // namespace legacy + void usonpackage_forcelink(); // internal } // namespace zen diff --git a/zencore/include/zencore/compactbinaryvalidation.h b/zencore/include/zencore/compactbinaryvalidation.h index 9799c594a..b1fab9572 100644 --- a/zencore/include/zencore/compactbinaryvalidation.h +++ b/zencore/include/zencore/compactbinaryvalidation.h @@ -58,10 +58,13 @@ enum class CbValidateMode : uint32_t Padding = 1 << 3, /** - * Validate that a package or attachment has the expected fields and matches its saved hashes. + * Validate that a package or attachment has the expected fields. */ Package = 1 << 4, + /** + * Validate that a package or attachment matches its saved hashes. + */ PackageHash = 1 << 5, /** Perform all validation described above. */ diff --git a/zencore/include/zencore/except.h b/zencore/include/zencore/except.h index 0ae31dc71..8625f01d0 100644 --- a/zencore/include/zencore/except.h +++ b/zencore/include/zencore/except.h @@ -55,5 +55,18 @@ ThrowSystemException(const char* Message) ZENCORE_API void ThrowLastError(std::string_view Message); ZENCORE_API void ThrowLastError(std::string_view Message, const std::source_location& Location); ZENCORE_API std::string GetLastErrorAsString(); +ZENCORE_API std::string GetWindowsErrorAsString(uint32_t Win32ErrorCode); + +inline std::error_code +MakeWin32ErrorCode(uint32_t Win32ErrorCode) noexcept +{ + return std::error_code(Win32ErrorCode, std::system_category()); +} + +inline std::error_code +MakeErrorCodeFromLastError() noexcept +{ + return std::error_code(::GetLastError(), std::system_category()); +} } // namespace zen diff --git a/zencore/include/zencore/httpclient.h b/zencore/include/zencore/httpclient.h deleted file mode 100644 index 4b30eb09b..000000000 --- a/zencore/include/zencore/httpclient.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include "zencore.h" - -#include <zencore/string.h> -#include <gsl/gsl-lite.hpp> - -namespace zen { - -/** Asynchronous HTTP client implementation for Zen use cases - */ -class HttpClient -{ -public: -private: -}; - -} // namespace zen - -void httpclient_forcelink(); // internal diff --git a/zencore/include/zencore/iobuffer.h b/zencore/include/zencore/iobuffer.h index 121b73adc..034c3566f 100644 --- a/zencore/include/zencore/iobuffer.h +++ b/zencore/include/zencore/iobuffer.h @@ -19,6 +19,7 @@ enum class ZenContentType : uint8_t kCbObject, kCbPackage, kYAML, + kCbPackageOffer, kUnknownContentType }; diff --git a/zencore/include/zencore/logging.h b/zencore/include/zencore/logging.h index 7a08cc48b..a2404a5e9 100644 --- a/zencore/include/zencore/logging.h +++ b/zencore/include/zencore/logging.h @@ -19,5 +19,4 @@ spdlog::logger& Get(std::string_view Name); void InitializeLogging(); void ShutdownLogging(); - } // namespace zen::logging diff --git a/zencore/include/zencore/refcount.h b/zencore/include/zencore/refcount.h index 288b649c6..50bd82f59 100644 --- a/zencore/include/zencore/refcount.h +++ b/zencore/include/zencore/refcount.h @@ -117,6 +117,7 @@ public: [[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; } inline explicit operator bool() const { return m_Ref != nullptr; } inline T* operator->() const { return m_Ref; } + inline T* Get() const { return m_Ref; } inline std::strong_ordering operator<=>(const Ref& Rhs) const = default; diff --git a/zencore/logging.cpp b/zencore/logging.cpp index 89d588650..6441fc3bc 100644 --- a/zencore/logging.cpp +++ b/zencore/logging.cpp @@ -1,3 +1,5 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + #include "zencore/logging.h" #include <spdlog/sinks/stdout_color_sinks.h> @@ -42,12 +44,12 @@ ConsoleLog() return *ConLogger; } -void +void InitializeLogging() { } -void +void ShutdownLogging() { spdlog::drop_all(); diff --git a/zencore/zencore.vcxproj b/zencore/zencore.vcxproj index 4040d5ae1..4f1e63670 100644 --- a/zencore/zencore.vcxproj +++ b/zencore/zencore.vcxproj @@ -122,8 +122,6 @@ <ClInclude Include="include\zencore\compress.h" /> <ClInclude Include="include\zencore\filesystem.h" /> <ClInclude Include="include\zencore\fmtutils.h" /> - <ClInclude Include="include\zencore\httpclient.h" /> - <ClInclude Include="include\zencore\httpserver.h" /> <ClInclude Include="include\zencore\intmath.h" /> <ClInclude Include="include\zencore\iohash.h" /> <ClInclude Include="include\zencore\logging.h" /> @@ -154,7 +152,6 @@ <ClInclude Include="include\zencore\windows.h" /> <ClInclude Include="include\zencore\xxhash.h" /> <ClInclude Include="include\zencore\zencore.h" /> - <ClInclude Include="iothreadpool.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="blake3.cpp" /> @@ -163,11 +160,8 @@ <ClCompile Include="crc32.cpp" /> <ClCompile Include="except.cpp" /> <ClCompile Include="filesystem.cpp" /> - <ClCompile Include="httpclient.cpp" /> - <ClCompile Include="httpserver.cpp" /> <ClCompile Include="intmath.cpp" /> <ClCompile Include="iohash.cpp" /> - <ClCompile Include="iothreadpool.cpp" /> <ClCompile Include="logging.cpp" /> <ClCompile Include="md5.cpp" /> <ClCompile Include="memory.cpp" /> diff --git a/zencore/zencore.vcxproj.filters b/zencore/zencore.vcxproj.filters index 3a291e967..de3d915b8 100644 --- a/zencore/zencore.vcxproj.filters +++ b/zencore/zencore.vcxproj.filters @@ -21,7 +21,6 @@ <ClInclude Include="include\zencore\enumflags.h" /> <ClInclude Include="include\zencore\except.h" /> <ClInclude Include="include\zencore\filesystem.h" /> - <ClInclude Include="include\zencore\httpserver.h" /> <ClInclude Include="include\zencore\refcount.h" /> <ClInclude Include="include\zencore\memory.h" /> <ClInclude Include="include\zencore\windows.h" /> @@ -31,11 +30,9 @@ <ClInclude Include="include\zencore\compactbinarybuilder.h" /> <ClInclude Include="include\zencore\compactbinarypackage.h" /> <ClInclude Include="include\zencore\compactbinaryvalidation.h" /> - <ClInclude Include="include\zencore\httpclient.h" /> <ClInclude Include="include\zencore\md5.h" /> <ClInclude Include="include\zencore\fmtutils.h" /> <ClInclude Include="include\zencore\xxhash.h" /> - <ClInclude Include="iothreadpool.h" /> <ClInclude Include="include\zencore\varint.h" /> <ClInclude Include="include\zencore\endian.h" /> <ClInclude Include="include\zencore\compositebuffer.h" /> @@ -53,7 +50,6 @@ <ClCompile Include="uid.cpp" /> <ClCompile Include="blake3.cpp" /> <ClCompile Include="filesystem.cpp" /> - <ClCompile Include="httpserver.cpp" /> <ClCompile Include="memory.cpp" /> <ClCompile Include="refcount.cpp" /> <ClCompile Include="stats.cpp" /> @@ -68,11 +64,9 @@ <ClCompile Include="compactbinarybuilder.cpp" /> <ClCompile Include="compactbinarypackage.cpp" /> <ClCompile Include="compactbinaryvalidation.cpp" /> - <ClCompile Include="httpclient.cpp" /> <ClCompile Include="md5.cpp" /> <ClCompile Include="except.cpp" /> <ClCompile Include="xxhash.cpp" /> - <ClCompile Include="iothreadpool.cpp" /> <ClCompile Include="compress.cpp" /> <ClCompile Include="compositebuffer.cpp" /> <ClCompile Include="crc32.cpp" /> diff --git a/zenhttp/httpclient.cpp b/zenhttp/httpclient.cpp new file mode 100644 index 000000000..b7df12026 --- /dev/null +++ b/zenhttp/httpclient.cpp @@ -0,0 +1,158 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpclient.h> +#include <zenhttp/httpserver.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/sharedbuffer.h> +#include <zencore/stream.h> + +#include "httpshared.h" + +#include <doctest/doctest.h> + +static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; + +namespace zen { + +using namespace std::literals; + +HttpClient::Response +FromCprResponse(cpr::Response& InResponse) +{ + return {.StatusCode = InResponse.status_code}; +} + +////////////////////////////////////////////////////////////////////////// + +HttpClient::HttpClient(std::string_view BaseUri) : m_BaseUri(BaseUri) +{ +} + +HttpClient::~HttpClient() +{ +} + +HttpClient::Response +HttpClient::TransactPackage(std::string_view Url, CbPackage Package) +{ + cpr::Session Sess; + Sess.SetUrl(m_BaseUri + std::string(Url)); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++HttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + IoHash Hash = Attachment.GetHash(); + + Writer.AddHash(Hash); + } + + Writer.EndArray(); + + MemoryOutStream MemOut; + BinaryWriter MemWriter(MemOut); + Writer.Save(MemWriter); + + Sess.SetHeader( + {{"Content-Type", "application/x-ue-offer"}, {"UE-Session", "123456789012345678901234"}, {"UE-Request", RequestIdString}}); + Sess.SetBody(cpr::Body{(const char*)MemOut.Data(), MemOut.Size()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (FilterResponse.status_code == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); + CbObject ResponseObject = LoadCompactBinaryObject(ResponseBuffer); + + for (auto& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + else + { + // This should be an error -- server asked to have something we can't find + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + Sess.SetHeader( + {{"Content-Type", "application/x-ue-cbpkg"}, {"UE-Session", "123456789012345678901234"}, {"UE-Request", RequestIdString}}); + Sess.SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (!IsHttpSuccessCode(FilterResponse.status_code)) + { + return FromCprResponse(FilterResponse); + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); + + if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) + { + HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + return {.StatusCode = FilterResponse.status_code, .ResponsePayload = ResponseBuffer}; +} + +HttpClient::Response +HttpClient::Delete(std::string_view Url) +{ + ZEN_UNUSED(Url); + return {}; +} + +////////////////////////////////////////////////////////////////////////// + +TEST_CASE("httpclient") +{ + using namespace std::literals; + + SUBCASE("client") {} +} + +void +httpclient_forcelink() +{ +} + +} // namespace zen diff --git a/zenhttp/httpnull.cpp b/zenhttp/httpnull.cpp new file mode 100644 index 000000000..57cba13d3 --- /dev/null +++ b/zenhttp/httpnull.cpp @@ -0,0 +1,67 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpnull.h" + +#include <conio.h> +#include <zencore/logging.h> + +namespace zen { + +HttpNullServer::HttpNullServer() +{ +} + +HttpNullServer::~HttpNullServer() +{ +} + +void +HttpNullServer::RegisterService(HttpService& Service) +{ + ZEN_UNUSED(Service); +} + +void +HttpNullServer::Initialize(int BasePort) +{ + ZEN_UNUSED(BasePort); +} + +void +HttpNullServer::Run(bool TestMode) +{ + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit"); + } + + do + { + int WaitTimeout = -1; + + if (!TestMode) + { + WaitTimeout = 1000; + } + + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +} + +void +HttpNullServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +} // namespace zen diff --git a/zenhttp/httpnull.h b/zenhttp/httpnull.h new file mode 100644 index 000000000..b15b1b123 --- /dev/null +++ b/zenhttp/httpnull.h @@ -0,0 +1,29 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> + +namespace zen { + +/** + * @brief Null implementation of "http" server. Does nothing + */ + +class HttpNullServer : public HttpServer +{ +public: + HttpNullServer(); + ~HttpNullServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual void Initialize(int BasePort) override; + virtual void Run(bool TestMode) override; + virtual void RequestExit() override; + +private: + Event m_ShutdownEvent; +}; + +} // namespace zen diff --git a/zenhttp/httpserver.cpp b/zenhttp/httpserver.cpp new file mode 100644 index 000000000..f97ac0067 --- /dev/null +++ b/zenhttp/httpserver.cpp @@ -0,0 +1,533 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpserver.h> + +#include "httpnull.h" +#include "httpshared.h" +#include "httpsys.h" +#include "httpuws.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/refcount.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/thread.h> + +#include <conio.h> +#include <new.h> +#include <charconv> +#include <span> +#include <string_view> + +#include <doctest/doctest.h> + +namespace zen { + +using namespace std::literals; + +std::string_view +MapContentTypeToString(HttpContentType ContentType) +{ + switch (ContentType) + { + default: + case HttpContentType::kUnknownContentType: + case HttpContentType::kBinary: + return "application/octet-stream"sv; + + case HttpContentType::kText: + return "text/plain"sv; + + case HttpContentType::kJSON: + return "application/json"sv; + + case HttpContentType::kCbObject: + return "application/x-ue-cb"sv; + + case HttpContentType::kCbPackage: + return "application/x-ue-cbpkg"sv; + + case HttpContentType::kCbPackageOffer: + return "application/x-ue-offer"sv; + + case HttpContentType::kYAML: + return "text/yaml"sv; + } +} + +static const uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); +static const uint32_t HashJson = HashStringDjb2("application/json"sv); +static const uint32_t HashYaml = HashStringDjb2("text/yaml"sv); +static const uint32_t HashText = HashStringDjb2("text/plain"sv); +static const uint32_t HashCompactBinary = HashStringDjb2("application/x-ue-cb"sv); +static const uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv); +static const uint32_t HashCompactBinaryPackageOffer = HashStringDjb2("application/x-ue-offer"sv); + +HttpContentType +ParseContentType(const std::string_view& ContentTypeString) +{ + if (!ContentTypeString.empty()) + { + const uint32_t CtHash = HashStringDjb2(ContentTypeString); + + if (CtHash == HashBinary) + { + return HttpContentType::kBinary; + } + else if (CtHash == HashCompactBinary) + { + return HttpContentType::kCbObject; + } + else if (CtHash == HashCompactBinaryPackage) + { + return HttpContentType::kCbPackage; + } + else if (CtHash == HashCompactBinaryPackageOffer) + { + return HttpContentType::kCbPackageOffer; + } + else if (CtHash == HashJson) + { + return HttpContentType::kJSON; + } + else if (CtHash == HashYaml) + { + return HttpContentType::kYAML; + } + else if (CtHash == HashText) + { + return HttpContentType::kText; + } + } + + return HttpContentType::kUnknownContentType; +} + +const char* +ReasonStringForHttpResultCode(int HttpCode) +{ + switch (HttpCode) + { + // 1xx Informational + + case 100: + return "Continue"; + case 101: + return "Switching Protocols"; + + // 2xx Success + + case 200: + return "OK"; + case 201: + return "Created"; + case 202: + return "Accepted"; + case 204: + return "No Content"; + case 205: + return "Reset Content"; + case 206: + return "Partial Content"; + + // 3xx Redirection + + case 300: + return "Multiple Choices"; + case 301: + return "Moved Permanently"; + case 302: + return "Found"; + case 303: + return "See Other"; + case 304: + return "Not Modified"; + case 305: + return "Use Proxy"; + case 306: + return "Switch Proxy"; + case 307: + return "Temporary Redirect"; + case 308: + return "Permanent Redirect"; + + // 4xx Client errors + + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 402: + return "Payment Required"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 405: + return "Method Not Allowed"; + case 406: + return "Not Acceptable"; + case 407: + return "Proxy Authentication Required"; + case 408: + return "Request Timeout"; + case 409: + return "Conflict"; + case 410: + return "Gone"; + case 411: + return "Length Required"; + case 412: + return "Precondition Failed"; + case 413: + return "Payload Too Large"; + case 414: + return "URI Too Long"; + case 415: + return "Unsupported Media Type"; + case 416: + return "Range Not Satisifiable"; + case 417: + return "Expectation Failed"; + case 418: + return "I'm a teapot"; + case 421: + return "Misdirected Request"; + case 422: + return "Unprocessable Entity"; + case 423: + return "Locked"; + case 424: + return "Failed Dependency"; + case 425: + return "Too Early"; + case 426: + return "Upgrade Required"; + case 428: + return "Precondition Required"; + case 429: + return "Too Many Requests"; + case 431: + return "Request Header Fields Too Large"; + + // 5xx Server errors + + case 500: + return "Internal Server Error"; + case 501: + return "Not Implemented"; + case 502: + return "Bad Gateway"; + case 503: + return "Service Unavailable"; + case 504: + return "Gateway Timeout"; + case 505: + return "HTTP Version Not Supported"; + case 506: + return "Variant Also Negotiates"; + case 507: + return "Insufficient Storage"; + case 508: + return "Loop Detected"; + case 510: + return "Not Extended"; + case 511: + return "Network Authentication Required"; + + default: + return "Unknown Result"; + } +} + +////////////////////////////////////////////////////////////////////////// + +Ref<IHttpPackageHandler> +HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) +{ + ZEN_UNUSED(HttpServiceRequest); + + return nullptr; +} + +////////////////////////////////////////////////////////////////////////// + +HttpServerRequest::HttpServerRequest() +{ +} + +HttpServerRequest::~HttpServerRequest() +{ +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbPackage Data) +{ + std::vector<IoBuffer> ResponseBuffers = FormatPackageMessage(Data); + return WriteResponse(ResponseCode, HttpContentType::kCbPackage, ResponseBuffers); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data) +{ + SharedBuffer Buf = Data.GetBuffer(); + std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; + return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString) +{ + return WriteResponse(ResponseCode, ContentType, std::u8string_view{(char8_t*)ResponseString.data(), ResponseString.size()}); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob) +{ + std::array<IoBuffer, 1> Buffers{Blob}; + return WriteResponse(ResponseCode, ContentType, Buffers); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) +{ + std::span<const SharedBuffer> Segments = Payload.GetSegments(); + + std::vector<IoBuffer> Buffers; + + for (auto& Segment : Segments) + { + Buffers.push_back(Segment.AsIoBuffer()); + } + + WriteResponse(ResponseCode, ContentType, Payload); +} + +HttpServerRequest::QueryParams +HttpServerRequest::GetQueryParams() +{ + QueryParams Params; + + const std::string_view QStr = QueryString(); + + const char* QueryIt = QStr.data(); + const char* QueryEnd = QueryIt + QStr.size(); + + while (QueryIt != QueryEnd) + { + if (*QueryIt == '&') + { + ++QueryIt; + continue; + } + + const std::string_view Query{QueryIt, QueryEnd}; + + size_t DelimIndex = Query.find('&', 0); + + if (DelimIndex == std::string_view::npos) + { + DelimIndex = Query.size(); + } + + std::string_view ThisQuery{QueryIt, DelimIndex}; + + size_t EqIndex = ThisQuery.find('=', 0); + + if (EqIndex != std::string_view::npos) + { + std::string_view Parm{ThisQuery.data(), EqIndex}; + ThisQuery.remove_prefix(EqIndex + 1); + + Params.KvPairs.emplace_back(Parm, ThisQuery); + } + + QueryIt += DelimIndex; + } + + return Params; +} + +Oid +HttpServerRequest::SessionId() const +{ + if (m_Flags & kHaveSessionId) + { + return m_SessionId; + } + + m_SessionId = ParseSessionId(); + m_Flags |= kHaveSessionId; + return m_SessionId; +} + +uint32_t +HttpServerRequest::RequestId() const +{ + if (m_Flags & kHaveRequestId) + { + return m_RequestId; + } + + m_RequestId = ParseRequestId(); + m_Flags |= kHaveRequestId; + return m_RequestId; +} + +CbObject +HttpServerRequest::ReadPayloadObject() +{ + if (IoBuffer Payload = ReadPayload()) + { + return LoadCompactBinaryObject(std::move(Payload)); + } + + return {}; +} + +CbPackage +HttpServerRequest::ReadPayloadPackage() +{ + if (IoBuffer Payload = ReadPayload()) + { + return ParsePackageMessage(std::move(Payload)); + } + + return {}; +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpRequestRouter::AddPattern(const char* Id, const char* Regex) +{ + ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end()); + + m_PatternMap.insert({Id, Regex}); +} + +void +HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs) +{ + ExtendableStringBuilder<128> ExpandedRegex; + ProcessRegexSubstitutions(Regex, ExpandedRegex); + + m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex); +} + +void +HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex) +{ + size_t RegexLen = strlen(Regex); + + for (size_t i = 0; i < RegexLen;) + { + bool matched = false; + + if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\'))) + { + // Might have a pattern reference - find closing brace + + for (size_t j = i + 1; j < RegexLen; ++j) + { + if (Regex[j] == '}') + { + std::string Pattern(&Regex[i + 1], j - i - 1); + + if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) + { + OutExpandedRegex.Append(it->second.c_str()); + } + else + { + // Default to anything goes (or should this just be an error?) + + OutExpandedRegex.Append("(.+?)"); + } + + // skip ahead + i = j + 1; + + matched = true; + + break; + } + } + } + + if (!matched) + { + OutExpandedRegex.Append(Regex[i++]); + } + } +} + +bool +HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) +{ + const HttpVerb Verb = Request.RequestVerb(); + + std::string_view Uri = Request.RelativeUri(); + HttpRouterRequest RouterRequest(Request); + + for (const auto& Handler : m_Handlers) + { + if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx)) + { + Handler.Handler(RouterRequest); + + return true; // Route matched + } + } + + return false; // No route matched +} + +////////////////////////////////////////////////////////////////////////// + +Ref<HttpServer> +CreateHttpServer() +{ +#if 0 + return new HttpUwsServer; +#elif ZEN_WITH_HTTPSYS + return new HttpSysServer{std::thread::hardware_concurrency()}; +#else + return new HttpNullServer; +#endif +} + +////////////////////////////////////////////////////////////////////////// + +TEST_CASE("http") +{ + using namespace std::literals; + + SUBCASE("router") + { + HttpRequestRouter r; + r.AddPattern("a", "[[:alpha:]]+"); + r.RegisterRoute( + "{a}", + [&](auto) {}, + HttpVerb::kGet); + + // struct TestHttpServerRequest : public HttpServerRequest + //{ + // TestHttpServerRequest(std::string_view Uri) : m_uri{Uri} {} + //}; + + // TestHttpServerRequest req{}; + // r.HandleRequest(req); + } +} + +void +http_forcelink() +{ +} + +} // namespace zen diff --git a/zenhttp/httpshared.cpp b/zenhttp/httpshared.cpp new file mode 100644 index 000000000..68252a763 --- /dev/null +++ b/zenhttp/httpshared.cpp @@ -0,0 +1,138 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpshared.h" + +#include <zencore/compactbinarypackage.h> +#include <zencore/compositebuffer.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/stream.h> + +#include <span> +#include <vector> + +namespace zen { + +CompositeBuffer +FormatPackageMessageBuffer(const CbPackage& Data) +{ + std::vector<IoBuffer> Message = FormatPackageMessage(Data); + + std::vector<SharedBuffer> Buffers; + + for (IoBuffer& Buf : Message) + { + Buffers.push_back(SharedBuffer(Buf)); + } + + return CompositeBuffer(std::move(Buffers)); +} + +std::vector<IoBuffer> +FormatPackageMessage(const CbPackage& Data) +{ + const std::span<const CbAttachment>& Attachments = Data.GetAttachments(); + + std::vector<IoBuffer> ResponseBuffers; + ResponseBuffers.reserve(3 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each + // attachment is likely to consist of several buffers + + uint64_t TotalAttachmentsSize = 0; + + // Fixed size header + + CbPackageHeader Hdr{.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())}; + + ResponseBuffers.push_back(IoBufferBuilder::MakeCloneFromMemory(&Hdr, sizeof Hdr)); + + // Attachment metadata array + + IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)}; + + CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(AttachmentMetadataBuffer.MutableData()); + + ResponseBuffers.push_back(AttachmentMetadataBuffer); // Attachment metadata + + // Root object + + IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer(); + ResponseBuffers.push_back(RootIoBuffer); // Root object + + *AttachmentInfo++ = {.AttachmentSize = RootIoBuffer.Size(), .AttachmentHash = Data.GetObjectHash()}; + + // Attachment payloads + + for (const CbAttachment& Attachment : Attachments) + { + CompressedBuffer AttachmentBuffer = Attachment.AsCompressedBinary(); + CompositeBuffer Compressed = AttachmentBuffer.GetCompressed(); + + *AttachmentInfo++ = {.AttachmentSize = AttachmentBuffer.GetCompressedSize(), + .AttachmentHash = IoHash::FromBLAKE3(AttachmentBuffer.GetRawHash())}; + + for (const SharedBuffer& Segment : Compressed.GetSegments()) + { + ResponseBuffers.push_back(Segment.AsIoBuffer()); + TotalAttachmentsSize += Segment.GetSize(); + } + } + + return std::move(ResponseBuffers); +} + +CbPackage +ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer) +{ + if (!Payload) + { + return {}; + } + + MemoryInStream InStream(Payload); + BinaryReader Reader(InStream); + + CbPackage Package; + + CbPackageHeader Hdr; + Reader.Read(&Hdr, sizeof Hdr); + + if (Hdr.HeaderMagic != kCbPkgMagic) + { + // report error + return {}; + } + + uint32_t ChunkCount = Hdr.AttachmentCount + 1; + + std::unique_ptr<CbAttachmentEntry[]> AttachmentEntries{new CbAttachmentEntry[ChunkCount]}; + + Reader.Read(AttachmentEntries.get(), sizeof(CbAttachmentEntry) * ChunkCount); + + for (uint32_t i = 0; i < ChunkCount; ++i) + { + const CbAttachmentEntry& Entry = AttachmentEntries[i]; + const uint64_t AttachmentSize = Entry.AttachmentSize; + IoBuffer AttachmentBuffer = CreateBuffer(Entry.AttachmentHash, AttachmentSize); + + ZEN_ASSERT(AttachmentBuffer); + ZEN_ASSERT(AttachmentBuffer.Size() == AttachmentSize); + + Reader.Read(AttachmentBuffer.MutableData(), AttachmentSize); + + CompressedBuffer CompBuf(CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer))); + + if (i == 0) + { + Package.SetObject(LoadCompactBinaryObject(std::move(CompBuf))); + } + else + { + CbAttachment Attachment(std::move(CompBuf)); + Package.AddAttachment(Attachment); + } + } + + return Package; +} + +} // namespace zen
\ No newline at end of file diff --git a/zenhttp/httpshared.h b/zenhttp/httpshared.h new file mode 100644 index 000000000..06fdb104f --- /dev/null +++ b/zenhttp/httpshared.h @@ -0,0 +1,45 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> + +#include <functional> + +namespace zen { + +class IoBuffer; +class CbPackage; +class CompositeBuffer; + +struct CbPackageHeader +{ + uint32_t HeaderMagic; + uint32_t AttachmentCount; + uint32_t Reserved1; + uint32_t Reserved2; +}; + +static_assert(sizeof(CbPackageHeader) == 16); + +static constinit uint32_t kCbPkgMagic = 0xaa77aacc; + +struct CbAttachmentEntry +{ + uint64_t AttachmentSize; + uint32_t Reserved1; + IoHash AttachmentHash; +}; + +static_assert(sizeof(CbAttachmentEntry) == 32); + +std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data); +CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data); +CbPackage ParsePackageMessage( + IoBuffer Payload, + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { + return IoBuffer{Size}; + }); + +} // namespace zen diff --git a/zenhttp/httpsys.cpp b/zenhttp/httpsys.cpp new file mode 100644 index 000000000..9ee004c5c --- /dev/null +++ b/zenhttp/httpsys.cpp @@ -0,0 +1,1431 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpsys.h" + +#include "httpshared.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/except.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> + +#if ZEN_WITH_HTTPSYS + +# include <conio.h> +# include <mstcpip.h> +# pragma comment(lib, "httpapi.lib") + +std::wstring +UTF8_to_wstring(const char* in) +{ + std::wstring out; + unsigned int codepoint; + + while (*in != 0) + { + unsigned char ch = static_cast<unsigned char>(*in); + + if (ch <= 0x7f) + codepoint = ch; + else if (ch <= 0xbf) + codepoint = (codepoint << 6) | (ch & 0x3f); + else if (ch <= 0xdf) + codepoint = ch & 0x1f; + else if (ch <= 0xef) + codepoint = ch & 0x0f; + else + codepoint = ch & 0x07; + + ++in; + + if (((*in & 0xc0) != 0x80) && (codepoint <= 0x10ffff)) + { + if (sizeof(wchar_t) > 2) + { + out.append(1, static_cast<wchar_t>(codepoint)); + } + else if (codepoint > 0xffff) + { + out.append(1, static_cast<wchar_t>(0xd800 + (codepoint >> 10))); + out.append(1, static_cast<wchar_t>(0xdc00 + (codepoint & 0x03ff))); + } + else if (codepoint < 0xd800 || codepoint >= 0xe000) + { + out.append(1, static_cast<wchar_t>(codepoint)); + } + } + } + + return out; +} + +namespace zen { + +using namespace std::literals; + +class HttpSysServer; +class HttpSysTransaction; +class HttpMessageResponseRequest; + +////////////////////////////////////////////////////////////////////////// + +HttpVerb +TranslateHttpVerb(HTTP_VERB ReqVerb) +{ + switch (ReqVerb) + { + case HttpVerbOPTIONS: + return HttpVerb::kOptions; + + case HttpVerbGET: + return HttpVerb::kGet; + + case HttpVerbHEAD: + return HttpVerb::kHead; + + case HttpVerbPOST: + return HttpVerb::kPost; + + case HttpVerbPUT: + return HttpVerb::kPut; + + case HttpVerbDELETE: + return HttpVerb::kDelete; + + case HttpVerbCOPY: + return HttpVerb::kCopy; + + default: + // TODO: invalid request? + return (HttpVerb)0; + } +} + +uint64_t +GetContentLength(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; + std::string_view cl(clh.pRawValue, clh.RawValueLength); + uint64_t ContentLength = 0; + std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); + return ContentLength; +}; + +HttpContentType +GetContentType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +HttpContentType +GetAcceptType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +/** + * @brief Base class for any pending or active HTTP transactions + */ +class HttpSysRequestHandler +{ +public: + explicit HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {} + virtual ~HttpSysRequestHandler() = default; + + virtual void IssueRequest(std::error_code& ErrorCode) = 0; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; + + HttpSysTransaction& Transaction() { return m_Request; } + +private: + HttpSysTransaction& m_Request; // Related HTTP transaction object +}; + +/** + * This is the handler for the initial HTTP I/O request which will receive the headers + * and however much of the remaining payload might fit in the embedded request buffer. + * + * It is also used to receive any entity body data relating to the request + * + */ +struct InitialRequestHandler : public HttpSysRequestHandler +{ + inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } + inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + inline bool IsInitialRequest() const { return m_IsInitialRequest; } + + InitialRequestHandler(HttpSysTransaction& InRequest); + ~InitialRequestHandler(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + + bool m_IsInitialRequest = true; + uint64_t m_CurrentPayloadOffset = 0; + uint64_t m_ContentLength = ~uint64_t(0); + IoBuffer m_PayloadBuffer; + UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)]; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpSysServerRequest : public HttpServerRequest +{ +public: + HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpSysServerRequest() = default; + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + + using HttpServerRequest::WriteResponse; + + HttpSysTransaction& m_HttpTx; + HttpMessageResponseRequest* m_Response = nullptr; // TODO: make this more general + IoBuffer m_PayloadBuffer; +}; + +/** HTTP transaction + + There will be an instance of this per pending and in-flight HTTP transaction + + */ +class HttpSysTransaction final +{ +public: + HttpSysTransaction(HttpSysServer& Server); + virtual ~HttpSysTransaction(); + + enum class Status + { + kDone, + kRequestPending + }; + + Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + + static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io); + + void IssueInitialRequest(std::error_code& ErrorCode); + PTP_IO Iocp(); + HANDLE RequestQueueHandle(); + inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline HttpSysServer& Server() { return m_HttpServer; } + inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } + + HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload); + +private: + OVERLAPPED m_HttpOverlapped{}; + HttpSysServer& m_HttpServer; + + // Tracks which handler is due to handle the next I/O completion event + HttpSysRequestHandler* m_CompletionHandler = nullptr; + RwLock m_CompletionMutex; + InitialRequestHandler m_InitialHttpHandler{*this}; + std::optional<HttpSysServerRequest> m_HandlerRequest; + Ref<IHttpPackageHandler> m_PackageHandler; +}; + +////////////////////////////////////////////////////////////////////////// + +/** + * @brief HTTP request response I/O request handler + * + * Asynchronously streams out a response to an HTTP request via compound + * responses from memory or directly from file + */ + +class HttpMessageResponseRequest : public HttpSysRequestHandler +{ +public: + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> Blobs); + ~HttpMessageResponseRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + void SuppressResponseBody(); // typically used for HEAD requests + +private: + std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks; + uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes + uint16_t m_ResponseCode = 0; + uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists + uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends + bool m_IsInitialResponse = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::vector<IoBuffer> m_DataBuffers; + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); +}; + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) +: HttpSysRequestHandler(InRequest) +{ + std::array<IoBuffer, 0> EmptyBufferList; + + InitializeForPayload(ResponseCode, EmptyBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) +: HttpSysRequestHandler(InRequest) +, m_ContentType(HttpContentType::kText) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> BlobList) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + InitializeForPayload(ResponseCode, BlobList); +} + +HttpMessageResponseRequest::~HttpMessageResponseRequest() +{ +} + +void +HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) +{ + m_ResponseCode = ResponseCode; + + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_HttpDataChunks.reserve(ChunkCount); + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + + // Initialize the full array up front + + uint64_t LocalDataSize = 0; + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // Use direct file transfer + + m_HttpDataChunks.push_back({}); + auto& Chunk = m_HttpDataChunks.back(); + + Chunk.DataChunkType = HttpDataChunkFromFileHandle; + Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; + Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; + Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; + } + else + { + // Send from memory, need to make sure we chunk the buffer up since + // the underlying data structure only accepts 32-bit chunk sizes for + // memory chunks. When this happens the vector will be reallocated, + // which is fine since this will be a pretty rare case and sending + // the data is going to take a lot longer than a memory allocation :) + + const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data()); + + while (BufferDataSize) + { + const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); + + m_HttpDataChunks.push_back({}); + auto& Chunk = m_HttpDataChunks.back(); + + Chunk.DataChunkType = HttpDataChunkFromMemory; + Chunk.FromMemory.pBuffer = (void*)WriteCursor; + Chunk.FromMemory.BufferLength = ThisChunkSize; + + BufferDataSize -= ThisChunkSize; + WriteCursor += ThisChunkSize; + } + } + } + + m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size()); + m_TotalDataSize = LocalDataSize; +} + +void +HttpMessageResponseRequest::SuppressResponseBody() +{ + m_RemainingChunkCount = 0; + m_HttpDataChunks.clear(); + m_DataBuffers.clear(); +} + +HttpSysRequestHandler* +HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + if (IoResult) + { + spdlog::warn("response aborted due to error: '{}'", GetWindowsErrorAsString(IoResult)); + + // if one transmit failed there's really no need to go on + return nullptr; + } + + if (m_RemainingChunkCount == 0) + { + return nullptr; // All done + } + + return this; +} + +void +HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) +{ + HttpSysTransaction& Tx = Transaction(); + HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); + PTP_IO const Iocp = Tx.Iocp(); + + StartThreadpoolIo(Iocp); + + // Split payload into batches to play well with the underlying API + + const int MaxChunksPerCall = 9999; + + const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall); + const int ThisRequestChunkOffset = m_NextDataChunkOffset; + + m_RemainingChunkCount -= ThisRequestChunkCount; + m_NextDataChunkOffset += ThisRequestChunkCount; + + /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA? + + From the docs: + + This flag enables buffering of data in the kernel on a per-response basis. It should + be used by an application doing synchronous I/O, or by a an application doing + asynchronous I/O with no more than one send outstanding at a time. + + Applications using asynchronous I/O which may have more than one send outstanding at + a time should not use this flag. + + When this flag is set, it should be used consistently in calls to the + HttpSendHttpResponse function as well. + */ + + ULONG SendFlags = 0; + + if (m_RemainingChunkCount) + { + // We need to make more calls to send the full amount of data + SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + } + + ULONG SendResult = 0; + + if (m_IsInitialResponse) + { + // Populate response structure + + HTTP_RESPONSE HttpResponse = {}; + + HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); + HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; + + // Content-length header + + char ContentLengthString[32]; + _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); + + PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; + ContentLengthHeader->pRawValue = ContentLengthString; + ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); + + // Content-type header + + PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; + + std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + + ContentTypeHeader->pRawValue = ContentTypeString.data(); + ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + + HttpResponse.StatusCode = m_ResponseCode; + HttpResponse.pReason = ReasonStringForHttpResultCode(m_ResponseCode); + HttpResponse.ReasonLength = (USHORT)strlen(HttpResponse.pReason); + + // Cache policy + + HTTP_CACHE_POLICY CachePolicy; + + CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.SecondsToLive = 0; + + // Initial response API call + + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + &HttpResponse, + &CachePolicy, + NULL, + NULL, + 0, + Tx.Overlapped(), + NULL); + + m_IsInitialResponse = false; + } + else + { + // Subsequent response API calls + + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + (USHORT)ThisRequestChunkCount, // EntityChunkCount + &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); + } + + if ((SendResult != NO_ERROR) // Synchronous completion, but the completion event will still be posted to IOCP + && (SendResult != ERROR_IO_PENDING) // Asynchronous completion + ) + { + // Some error occurred, no completion will be posted + + CancelThreadpoolIo(Iocp); + + spdlog::error("failed to send HTTP response (error: '{}'), request URL: {}"sv, SendResult, HttpReq->pRawUrl); + + ErrorCode = MakeWin32ErrorCode(SendResult); + } + else + { + ErrorCode = {}; + } +} + +/** + _________ + / _____/ ______________ __ ___________ + \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ + / \ ___/| | \/\ /\ ___/| | \/ + /_______ /\___ >__| \_/ \___ >__| + \/ \/ \/ +*/ + +HttpSysServer::HttpSysServer(unsigned int ThreadCount) : m_ThreadPool(ThreadCount) +{ + ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); + + if (Result != NO_ERROR) + { + return; + } + + m_IsHttpInitialized = true; + m_IsOk = true; +} + +HttpSysServer::~HttpSysServer() +{ + if (m_IsHttpInitialized) + { + Cleanup(); + + HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); + } +} + +void +HttpSysServer::Initialize(const wchar_t* UrlPath) +{ + m_IsOk = false; + + ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); + + if (Result != NO_ERROR) + { + spdlog::error("Failed to create server session for '{}': {x}"sv, WideToUtf8(UrlPath), Result); + + return; + } + + Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); + + if (Result != NO_ERROR) + { + spdlog::error("Failed to create URL group for '{}': {x}"sv, WideToUtf8(UrlPath), Result); + + return; + } + + m_BaseUri = UrlPath; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, UrlPath, HTTP_URL_CONTEXT(0), 0); + + if (Result != NO_ERROR) + { + spdlog::error("Failed to add base URL to URL group for '{}': {x}"sv, WideToUtf8(UrlPath), Result); + + return; + } + + HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; + + Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ nullptr, + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); + + if (Result != NO_ERROR) + { + spdlog::error("Failed to create request queue for '{}': {x}"sv, WideToUtf8(UrlPath), Result); + + return; + } + + HttpBindingInfo.Flags.Present = 1; + HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; + + Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); + + if (Result != NO_ERROR) + { + spdlog::error("Failed to set server binding property for '{}': {x}"sv, WideToUtf8(UrlPath), Result); + + return; + } + + // Create I/O completion port + + std::error_code ErrorCode; + m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); + + if (ErrorCode) + { + spdlog::error("Failed to create IOCP for '{}': {}"sv, WideToUtf8(UrlPath), ErrorCode.message()); + } + else + { + m_IsOk = true; + } +} + +void +HttpSysServer::Cleanup() +{ + ++m_IsShuttingDown; + + if (m_RequestQueueHandle) + { + HttpCloseRequestQueue(m_RequestQueueHandle); + m_RequestQueueHandle = nullptr; + } + + if (m_HttpUrlGroupId) + { + HttpCloseUrlGroup(m_HttpUrlGroupId); + m_HttpUrlGroupId = 0; + } + + if (m_HttpSessionId) + { + HttpCloseServerSession(m_HttpSessionId); + m_HttpSessionId = 0; + } +} + +void +HttpSysServer::StartServer() +{ + const int InitialRequestCount = 32; + + for (int i = 0; i < InitialRequestCount; ++i) + { + IssueNewRequestMaybe(); + } +} + +void +HttpSysServer::Run(bool TestMode) +{ + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); + } + + do + { + int WaitTimeout = -1; + + if (!TestMode) + { + WaitTimeout = 1000; + } + + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +} + +void +HttpSysServer::OnHandlingRequest() +{ + if (--m_PendingRequests > m_MinPendingRequests) + { + // We have more than the minimum number of requests pending, just let someone else + // enqueue new requests + return; + } + + IssueNewRequestMaybe(); +} + +void +HttpSysServer::IssueNewRequestMaybe() +{ + if (m_IsShuttingDown.load(std::memory_order::acquire)) + { + return; + } + + if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) + { + return; + } + + std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this); + + std::error_code ErrorCode; + Request->IssueInitialRequest(ErrorCode); + + if (ErrorCode) + { + // No request was actually issued. What is the appropriate response? + + return; + } + + // This may end up exceeding the MaxPendingRequests limit, but it's not + // really a problem. I'm doing it this way mostly to avoid dealing with + // exceptions here + ++m_PendingRequests; + + Request.release(); +} + +void +HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) +{ + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring Path16 = UTF8_to_wstring(UrlPath); + Service.SetUriPrefixLength(Path16.size() + 1 /* leading slash */); + + // Convert to wide string + + std::wstring Url16 = m_BaseUri + Path16; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + spdlog::error("HttpAddUrlToUrlGroup failed with result {}"sv, Result); + + return; + } +} + +void +HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) +{ + ZEN_UNUSED(Service); + + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring Path16 = UTF8_to_wstring(UrlPath); + + // Convert to wide string + + std::wstring Url16 = m_BaseUri + Path16; + + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + + if (Result != NO_ERROR) + { + spdlog::error("HttpRemoveUrlFromUrlGroup failed with result {}"sv, Result); + } +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) +{ +} + +HttpSysTransaction::~HttpSysTransaction() +{ +} + +PTP_IO +HttpSysTransaction::Iocp() +{ + return m_HttpServer.m_ThreadPool.Iocp(); +} + +HANDLE +HttpSysTransaction::RequestQueueHandle() +{ + return m_HttpServer.m_RequestQueueHandle; +} + +void +HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) +{ + m_InitialHttpHandler.IssueRequest(ErrorCode); +} + +void +HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io) +{ + UNREFERENCED_PARAMETER(Io); + UNREFERENCED_PARAMETER(Instance); + UNREFERENCED_PARAMETER(pContext); + + // Note that for a given transaction we may be in this completion function on more + // than one thread at any given moment. This means we need to be careful about what + // happens in here + + HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + + if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) + { + delete Transaction; + } +} + +HttpSysTransaction::Status +HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // We use this to ensure sequential execution of completion handlers + // for any given transaction. It also ensures all member variables are + // in a consistent state for the current thread + + RwLock::ExclusiveLockScope _(m_CompletionMutex); + + bool IsRequestPending = false; + + if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) + { + const bool IsInitialRequest = (CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest(); + + if (IsInitialRequest) + { + // Ensure we have a sufficient number of pending requests outstanding + m_HttpServer.OnHandlingRequest(); + } + + m_CompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); + + if (m_CompletionHandler) + { + try + { + std::error_code ErrorCode; + m_CompletionHandler->IssueRequest(ErrorCode); + + if (ErrorCode) + { + spdlog::error("IssueRequest() failed {}"sv, ErrorCode.message()); + } + else + { + IsRequestPending = true; + } + } + catch (std::exception& Ex) + { + spdlog::error("exception caught from IssueRequest(): {}"sv, Ex.what()); + + // something went wrong, no request is pending + } + } + else + { + if (IsInitialRequest == false) + { + delete CurrentHandler; + } + } + } + + // Ensure new requests are enqueued as necessary + m_HttpServer.IssueNewRequestMaybe(); + + if (IsRequestPending) + { + // There is another request pending on this transaction, so it needs to remain valid + return Status::kRequestPending; + } + + // Transaction done, caller should clean up (delete) this instance + return Status::kDone; +} + +HttpSysServerRequest& +HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) +{ + HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + + if (ThisRequest.RequestVerb() == HttpVerb::kPost) + { + if (ThisRequest.RequestContentType() == HttpContentType::kCbPackageOffer) + { + // The client is presenting us with a package attachments offer, we need + // to filter it down to the list of attachments we need them to send in + // the follow-up request + + m_PackageHandler = Service.HandlePackageRequest(ThisRequest); + + if (m_PackageHandler) + { + CbObject OfferMessage = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> OfferCids; + + for (auto& CidEntry : OfferMessage["offer"]) + { + if (!CidEntry.IsHash()) + { + // Should yield bad request response? + + continue; + } + + OfferCids.push_back(CidEntry.AsHash(IoHash::Zero)); + } + + m_PackageHandler->FilterOffer(OfferCids); + + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + + for (const IoHash& Cid : OfferCids) + { + ResponseWriter.AddHash(Cid); + } + + ResponseWriter.EndArray(); + + // Emit filter response + ThisRequest.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + + return ThisRequest; + } + } + else if (ThisRequest.RequestContentType() == HttpContentType::kCbPackage) + { + // Process chunks in package request + + m_PackageHandler = Service.HandlePackageRequest(ThisRequest); + + // TODO: this should really be done in a streaming fashion, currently this emulates + // the intended flow from an API perspective + + if (m_PackageHandler) + { + m_PackageHandler->OnRequestBegin(); + + auto CreateBuffer = [&](const IoHash& Cid, uint64_t Size) -> IoBuffer { return m_PackageHandler->CreateTarget(Cid, Size); }; + + CbPackage Package = ParsePackageMessage(ThisRequest.ReadPayload(), CreateBuffer); + + m_PackageHandler->OnRequestComplete(); + } + } + } + + // Default request handling + + Service.HandleRequest(ThisRequest); + + return ThisRequest; +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) +: m_HttpTx(Tx) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); + + const int PrefixLength = Service.UriPrefixLength(); + const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t); + + if (AbsPathLength >= PrefixLength) + { + // We convert the URI immediately because most of the code involved prefers to deal + // with utf8. This has some performance impact which I'd prefer to avoid but for now + // we just have to live with it + + WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, + m_UriUtf8); + } + else + { + m_UriUtf8.Reset(); + } + + if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) + { + --QueryStringLength; + + WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryStringUtf8); + } + else + { + m_QueryStringUtf8.Reset(); + } + + m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); + m_ContentLength = GetContentLength(HttpRequestPtr); + m_ContentType = GetContentType(HttpRequestPtr); + m_AcceptType = GetAcceptType(HttpRequestPtr); +} + +Oid +HttpSysServerRequest::ParseSessionId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Session"sv) + { + if (Header.RawValueLength == Oid::StringLength) + { + return Oid::FromHexString({Header.pRawValue, Header.RawValueLength}); + } + } + } + + return {}; +} + +uint32_t +HttpSysServerRequest::ParseRequestId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Request"sv) + { + std::string_view RequestValue{Header.pRawValue, Header.RawValueLength}; + uint32_t RequestId = 0; + std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId); + return RequestId; + } + } + + return 0; +} + +IoBuffer +HttpSysServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(IsHandled() == false); + + m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + + if (SuppressBody()) + { + m_Response->SuppressResponseBody(); + } + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(IsHandled() == false); + + m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + + if (SuppressBody()) + { + m_Response->SuppressResponseBody(); + } + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(IsHandled() == false); + + m_Response = + new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size()); + + if (SuppressBody()) + { + m_Response->SuppressResponseBody(); + } + + SetIsHandled(); +} + +////////////////////////////////////////////////////////////////////////// + +InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) +{ +} + +InitialRequestHandler::~InitialRequestHandler() +{ +} + +void +InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) +{ + HttpSysTransaction& Tx = Transaction(); + PTP_IO Iocp = Tx.Iocp(); + HTTP_REQUEST* HttpReq = Tx.HttpRequest(); + + StartThreadpoolIo(Iocp); + + ULONG HttpApiResult; + + if (IsInitialRequest()) + { + HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), + HTTP_NULL_ID, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + HttpReq, + RequestBufferSize(), + NULL, + Tx.Overlapped()); + } + else + { + // The http.sys team recommends limiting the size to 128KB + static const uint64_t kMaxBytesPerApiCall = 128 * 1024; + + uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; + const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); + void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; + + HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + 0, /* Flags */ + BufferWriteCursor, + gsl::narrow<ULONG>(BytesToReadThisCall), + nullptr, // BytesReturned + Tx.Overlapped()); + } + + if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) + { + CancelThreadpoolIo(Iocp); + + if (HttpApiResult == ERROR_MORE_DATA) + { + // ProcessReceiveAndPostResponse(pIoRequest, pServerContext->Io, ERROR_MORE_DATA); + } + + // CleanupHttpIoRequest(pIoRequest); + + ErrorCode = MakeWin32ErrorCode(HttpApiResult); + + spdlog::error("HttpReceiveHttpRequest failed, error {}", ErrorCode.message()); + + return; + } + + ErrorCode = std::error_code(); +} + +HttpSysRequestHandler* +InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); + + switch (IoResult) + { + case ERROR_OPERATION_ABORTED: + return nullptr; + + case ERROR_MORE_DATA: + // Insufficient buffer space + break; + } + + // Route requests + + try + { + HTTP_REQUEST* HttpReq = HttpRequest(); + +# if 0 + for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + { + auto& ReqInfo = HttpReq->pRequestInfo[i]; + + switch (ReqInfo.InfoType) + { + case HttpRequestInfoTypeRequestTiming: + { + const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + + spdlog::info(""); + } + break; + case HttpRequestInfoTypeAuth: + spdlog::info(""); + break; + case HttpRequestInfoTypeChannelBind: + spdlog::info(""); + break; + case HttpRequestInfoTypeSslProtocol: + spdlog::info(""); + break; + case HttpRequestInfoTypeSslTokenBindingDraft: + spdlog::info(""); + break; + case HttpRequestInfoTypeSslTokenBinding: + spdlog::info(""); + break; + case HttpRequestInfoTypeTcpInfoV0: + { + const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + + spdlog::info(""); + } + break; + case HttpRequestInfoTypeRequestSizing: + { + const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); + spdlog::info(""); + } + break; + case HttpRequestInfoTypeQuicStats: + spdlog::info(""); + break; + case HttpRequestInfoTypeTcpInfoV1: + { + const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + + spdlog::info(""); + } + break; + } + } +# endif + + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) + { + if (m_IsInitialRequest) + { + m_ContentLength = GetContentLength(HttpReq); + const HttpContentType ContentType = GetContentType(HttpReq); + + if (m_ContentLength) + { + // Handle initial chunk read by copying any payload which has already been copied + // into our embedded request buffer + + m_PayloadBuffer = IoBuffer(m_ContentLength); + m_PayloadBuffer.SetContentType(ContentType); + + uint64_t BytesToRead = m_ContentLength; + uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()); + uint8_t* BufferWriteCursor = BufferBase; + + const int EntityChunkCount = HttpReq->EntityChunkCount; + + for (int i = 0; i < EntityChunkCount; ++i) + { + HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; + + ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); + + const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; + + ZEN_ASSERT(BufferLength <= BytesToRead); + + memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); + + BufferWriteCursor += BufferLength; + BytesToRead -= BufferLength; + } + + m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; + } + } + else + { + m_CurrentPayloadOffset += NumberOfBytesTransferred; + } + + if (m_CurrentPayloadOffset == m_ContentLength) + { + m_PayloadBuffer.MakeImmutable(); + + // Body received completely - call request handler + + HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer); + + if (!ThisRequest.IsHandled()) + { + return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); + } + + if (HttpMessageResponseRequest* Response = ThisRequest.m_Response) + { + return Response; + } + } + else + { + // Issue a read request for more body data + return this; + } + } + + // Unable to route + return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); + } + catch (std::exception& ex) + { + // TODO provide more meaningful error output + + return new HttpMessageResponseRequest(Transaction(), 500, ex.what()); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// HttpServer interface implementation +// + +void +HttpSysServer::Initialize(int BasePort) +{ + using namespace std::literals; + + WideStringBuilder<64> BaseUri; + BaseUri << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + Initialize(BaseUri.c_str()); + StartServer(); +} + +void +HttpSysServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} +void +HttpSysServer::RegisterService(HttpService& Service) +{ + RegisterService(Service.BaseUri(), Service); +} + +} // namespace zen +#endif
\ No newline at end of file diff --git a/zenhttp/httpsys.h b/zenhttp/httpsys.h new file mode 100644 index 000000000..6616817ec --- /dev/null +++ b/zenhttp/httpsys.h @@ -0,0 +1,75 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +#ifndef ZEN_WITH_HTTPSYS +# if ZEN_PLATFORM_WINDOWS +# define ZEN_WITH_HTTPSYS 1 +# else +# define ZEN_WITH_HTTPSYS 0 +# endif +#endif + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> +# include "iothreadpool.h" + +# include <atlbase.h> +# include <http.h> + +namespace zen { + +/** + * @brief Windows implementation of HTTP server based on http.sys + * + * This requires elevation to function + */ +class HttpSysServer : public HttpServer +{ + friend class HttpSysTransaction; + +public: + explicit HttpSysServer(unsigned int ThreadCount); + ~HttpSysServer(); + + // HttpServer interface implementation + + virtual void Initialize(int BasePort) override; + virtual void Run(bool TestMode) override; + virtual void RequestExit() override; + virtual void RegisterService(HttpService& Service) override; + +private: + void Initialize(const wchar_t* UrlPath); + void Cleanup(); + + void StartServer(); + void OnHandlingRequest(); + void IssueNewRequestMaybe(); + + inline bool IsOk() const { return m_IsOk; } + + void RegisterService(const char* Endpoint, HttpService& Service); + void UnregisterService(const char* Endpoint, HttpService& Service); + +private: + bool m_IsOk = false; + bool m_IsHttpInitialized = false; + WinIoThreadPool m_ThreadPool; + + std::wstring m_BaseUri; // http://*:nnnn/ + HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; + HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; + HANDLE m_RequestQueueHandle = 0; + std::atomic_int32_t m_PendingRequests{0}; + std::atomic<int32_t> m_IsShuttingDown{0}; + int32_t m_MinPendingRequests = 16; + int32_t m_MaxPendingRequests = 128; + Event m_ShutdownEvent; +}; + +} // namespace zen +#endif diff --git a/zenhttp/httpuws.cpp b/zenhttp/httpuws.cpp new file mode 100644 index 000000000..992809b17 --- /dev/null +++ b/zenhttp/httpuws.cpp @@ -0,0 +1,96 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpuws.h" + +#pragma warning(push) +#pragma warning(disable : 4244 4324 4267 4458 4706) +#include <uwebsockets/App.h> +#pragma warning(pop) + +#include <conio.h> +#include <zencore/logging.h> + +#if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "Iphlpapi.lib") +#endif + +namespace zen { + +HttpUwsServer::HttpUwsServer() +{ +} + +HttpUwsServer::~HttpUwsServer() +{ +} + +void +HttpUwsServer::RegisterService(HttpService& Service) +{ + ZEN_UNUSED(Service); +} + +void +HttpUwsServer::Initialize(int BasePort) +{ + m_BasePort = BasePort; +} + +void +HttpUwsServer::Run(bool TestMode) +{ + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit"); + } + + ::uWS::App() + .get("/*", + [](uWS::HttpResponse<false>* res, uWS::HttpRequest* req) { + res->end("Hello world!"); + ZEN_UNUSED(req); + }) + .post("/*", + [](uWS::HttpResponse<false>* res, uWS::HttpRequest* req) { + res->onData([&](std::string_view Data, bool fin) { + ZEN_UNUSED(Data); + if (fin) + res->end("Hello world!"); + }); + + res->onAborted([&] {}); + ZEN_UNUSED(req); + }) + .listen(m_BasePort, [](auto* listen_socket) { ZEN_UNUSED(listen_socket); }) + .run(); + + do + { + int WaitTimeout = -1; + + if (!TestMode) + { + WaitTimeout = 1000; + } + + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +} + +void +HttpUwsServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +} // namespace zen diff --git a/zenhttp/httpuws.h b/zenhttp/httpuws.h new file mode 100644 index 000000000..ec55ae2fd --- /dev/null +++ b/zenhttp/httpuws.h @@ -0,0 +1,27 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +#include <zencore/thread.h> + +namespace zen { + +class HttpUwsServer : public HttpServer +{ +public: + HttpUwsServer(); + ~HttpUwsServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual void Initialize(int BasePort) override; + virtual void Run(bool TestMode) override; + virtual void RequestExit() override; + +private: + Event m_ShutdownEvent; + int m_BasePort = 0; +}; + +} // namespace zen
\ No newline at end of file diff --git a/zenhttp/include/zenhttp/httpclient.h b/zenhttp/include/zenhttp/httpclient.h new file mode 100644 index 000000000..8975f6fe1 --- /dev/null +++ b/zenhttp/include/zenhttp/httpclient.h @@ -0,0 +1,48 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zencore/iobuffer.h> +#include <zenhttp/httpcommon.h> + +#include <zencore/windows.h> + +// For some reason, these don't seem to stick, so we disable the warnings +//# define _SILENCE_CXX17_C_HEADER_DEPRECATION_WARNING 1 +//# define _SILENCE_ALL_CXX17_DEPRECATION_WARNINGS 1 +#pragma warning(push) +#pragma warning(disable : 4004) +#pragma warning(disable : 4996) +#include <cpr/cpr.h> +#pragma warning(pop) + +namespace zen { + +class CbPackage; + +/** Asynchronous HTTP client implementation for Zen use cases + */ +class HttpClient +{ +public: + HttpClient(std::string_view BaseUri); + ~HttpClient(); + + struct Response + { + int StatusCode = 0; + IoBuffer ResponsePayload; + }; + + [[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package); + [[nodiscard]] Response Delete(std::string_view Url); + +private: + std::string m_BaseUri; +}; + +} // namespace zen + +void httpclient_forcelink(); // internal diff --git a/zencore/include/zencore/httpserver.h b/zenhttp/include/zenhttp/httpcommon.h index 19ac8732e..41ec706f4 100644 --- a/zencore/include/zencore/httpserver.h +++ b/zenhttp/include/zenhttp/httpcommon.h @@ -2,19 +2,11 @@ #pragma once -#include "zencore.h" - -#include <zencore/enumflags.h> #include <zencore/iobuffer.h> -#include <zencore/refcount.h> -#include <zencore/string.h> -#include <functional> +#include <string_view> + #include <gsl/gsl-lite.hpp> -#include <list> -#include <regex> -#include <span> -#include <unordered_map> namespace zen { @@ -25,7 +17,17 @@ class CbObject; class CbPackage; class StringBuilderBase; -enum class HttpVerb +std::string_view MapContentTypeToString(HttpContentType ContentType); +HttpContentType ParseContentType(const std::string_view& ContentTypeString); +const char* ReasonStringForHttpResultCode(int HttpCode); + +[[nodiscard]] inline bool +IsHttpSuccessCode(int HttpCode) +{ + return (HttpCode >= 200) && (HttpCode < 300); +} + +enum class HttpVerb : uint8_t { kGet = 1 << 0, kPut = 1 << 1, @@ -38,7 +40,7 @@ enum class HttpVerb gsl_DEFINE_ENUM_BITMASK_OPERATORS(HttpVerb); -enum class HttpResponse +enum class HttpResponseCode { // 1xx - Informational @@ -165,237 +167,4 @@ enum class HttpResponse NetworkAuthenticationRequired = 511, //!< Indicates that the client needs to authenticate to gain network access. }; -/** HTTP Server Request - */ -class HttpServerRequest -{ -public: - HttpServerRequest(); - ~HttpServerRequest(); - - // Synchronous operations - - [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix - [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } - inline bool IsHandled() const { return m_IsHandled; } - - struct QueryParams - { - std::vector<std::pair<std::string_view, std::string_view>> KvPairs; - - std::string_view GetValue(std::string_view ParamName) - { - for (const auto& Kv : KvPairs) - { - const std::string_view& Key = Kv.first; - - if (Key.size() == ParamName.size()) - { - if (0 == _strnicmp(Key.data(), ParamName.data(), Key.size())) - { - return Kv.second; - } - } - } - - return std::string_view(); - } - }; - - QueryParams GetQueryParams(); - - inline HttpVerb RequestVerb() const { return m_Verb; } - inline HttpContentType RequestContentType() { return m_ContentType; } - inline HttpContentType AcceptContentType() { return m_AcceptType; } - - const char* HeaderAccept() const; - const char* HeaderAcceptEncoding() const; - const char* HeaderContentType() const; - const char* HeaderContentEncoding() const; - inline uint64_t HeaderContentLength() const { return m_ContentLength; } - - void SetSuppressResponseBody() { m_SuppressBody = true; } - - // Asynchronous operations - - /** Read POST/PUT payload - - This will return a null buffer if the contents are not fully available yet, and the handler should - at that point return - another completion request will be issued once the contents have been received - fully. - - NOTE: in practice, via the http.sys implementation this always operates synchronously. This should - be updated to provide fully asynchronous operation for better scalability on shared instances - */ - virtual IoBuffer ReadPayload() = 0; - - ZENCORE_API CbObject ReadPayloadObject(); - ZENCORE_API CbPackage ReadPayloadPackage(); - - /** Respond with payload - - Note that this is destructive in the sense that the IoBuffer instances referred to by Blobs will be - moved into our response handler array where they are kept alive, in order to reduce ref-counting storms - */ - virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) = 0; - virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, IoBuffer Blob); - virtual void WriteResponse(HttpResponse HttpResponseCode) = 0; - - virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0; - - void WriteResponse(HttpResponse HttpResponseCode, CbObject Data); - void WriteResponse(HttpResponse HttpResponseCode, CbPackage Package); - void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::string_view ResponseString); - -protected: - bool m_IsHandled = false; - bool m_SuppressBody = false; - HttpVerb m_Verb = HttpVerb::kGet; - uint64_t m_ContentLength = ~0ull; - HttpContentType m_ContentType = HttpContentType::kBinary; - HttpContentType m_AcceptType = HttpContentType::kUnknownContentType; - ExtendableStringBuilder<256> m_Uri; - ExtendableStringBuilder<256> m_QueryString; -}; - -class HttpServerException : public std::exception -{ -public: - HttpServerException(const char* Message, uint32_t Error); - - virtual const char* what() const noexcept override; - -private: - uint32_t m_ErrorCode; - std::string m_Message; -}; - -/** - * Base class for implementing an HTTP "service" - * - * A service exposes one or more endpoints with a certain URI prefix - * - */ - -class HttpService -{ -public: - HttpService() = default; - virtual ~HttpService() = default; - - virtual const char* BaseUri() const = 0; - virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; - - // Internals - - inline void SetUriPrefixLength(size_t PrefixLength) { m_UriPrefixLength = (int)PrefixLength; } - inline int UriPrefixLength() const { return m_UriPrefixLength; } - -private: - int m_UriPrefixLength = 0; -}; - -/** HTTP server - * - * Implements the main event loop to service HTTP requests, and handles routing - * requests to the appropriate endpoint handler as registered via AddEndpoint - */ -class HttpServer -{ -public: - HttpServer(); - ~HttpServer(); - - void AddEndpoint(const char* endpoint, std::function<void(HttpServerRequest&)> handler); - void AddEndpoint(HttpService& Service); - - void Initialize(int BasePort); - void Run(bool TestMode); - void RequestExit(); - -private: - struct Impl; - - RefPtr<Impl> m_Impl; -}; - -////////////////////////////////////////////////////////////////////////// - -class HttpRouterRequest -{ -public: - HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} - - ZENCORE_API std::string GetCapture(uint32_t Index) const; - inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } - -private: - using MatchResults_t = std::match_results<std::string_view::const_iterator>; - - HttpServerRequest& m_HttpRequest; - MatchResults_t m_Match; - - friend class HttpRequestRouter; -}; - -inline std::string -HttpRouterRequest::GetCapture(uint32_t Index) const -{ - ZEN_ASSERT(Index < m_Match.size()); - - return m_Match[Index]; -} - -////////////////////////////////////////////////////////////////////////// - -/** HTTP request router helper - * - * This helper class allows a service implementer to register one or more - * endpoints using pattern matching (currently using regex matching) - * - */ - -class HttpRequestRouter -{ -public: - typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t; - - void AddPattern(const char* Id, const char* Regex); - void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs); - bool HandleRequest(zen::HttpServerRequest& Request); - -private: - struct HandlerEntry - { - HandlerEntry(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) - : RegEx(Regex, std::regex::icase | std::regex::ECMAScript) - , Verbs(SupportedVerbs) - , Handler(std::move(Handler)) - , Pattern(Pattern) - { - } - - ~HandlerEntry() = default; - - std::regex RegEx; - HttpVerb Verbs; - HandlerFunc_t Handler; - const char* Pattern; - }; - - std::list<HandlerEntry> m_Handlers; - std::unordered_map<std::string, std::string> m_PatternMap; -}; - -////////////////////////////////////////////////////////////////////////// -// -// HTTP Client -// - -class HttpClient -{ -}; - } // namespace zen - -void http_forcelink(); // internal diff --git a/zenhttp/include/zenhttp/httpserver.h b/zenhttp/include/zenhttp/httpserver.h new file mode 100644 index 000000000..de097ceb3 --- /dev/null +++ b/zenhttp/include/zenhttp/httpserver.h @@ -0,0 +1,273 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zenhttp/httpcommon.h> + +#include <zencore/enumflags.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/refcount.h> +#include <zencore/string.h> +#include <zencore/uid.h> + +#include <functional> +#include <gsl/gsl-lite.hpp> +#include <list> +#include <regex> +#include <span> +#include <unordered_map> + +namespace zen { + +/** HTTP Server Request + */ +class HttpServerRequest +{ +public: + HttpServerRequest(); + ~HttpServerRequest(); + + // Synchronous operations + + [[nodiscard]] inline std::string_view RelativeUri() const { return m_UriUtf8; } // Returns URI without service prefix + [[nodiscard]] inline std::string_view QueryString() const { return m_QueryStringUtf8; } + + struct QueryParams + { + std::vector<std::pair<std::string_view, std::string_view>> KvPairs; + + std::string_view GetValue(std::string_view ParamName) + { + for (const auto& Kv : KvPairs) + { + const std::string_view& Key = Kv.first; + + if (Key.size() == ParamName.size()) + { + if (0 == _strnicmp(Key.data(), ParamName.data(), Key.size())) + { + return Kv.second; + } + } + } + + return std::string_view(); + } + }; + + QueryParams GetQueryParams(); + + inline HttpVerb RequestVerb() const { return m_Verb; } + inline HttpContentType RequestContentType() { return m_ContentType; } + inline HttpContentType AcceptContentType() { return m_AcceptType; } + + inline uint64_t ContentLength() const { return m_ContentLength; } + Oid SessionId() const; + uint32_t RequestId() const; + + inline bool IsHandled() const { return !!(m_Flags & kIsHandled); } + inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); } + inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; } + + /** Read POST/PUT payload for request body, which is always available without delay + */ + virtual IoBuffer ReadPayload() = 0; + + ZENCORE_API CbObject ReadPayloadObject(); + ZENCORE_API CbPackage ReadPayloadPackage(); + + /** Respond with payload + + No data will have been sent when any of these functions return. Instead, the response will be transmitted + asynchronously, after returning from a request handler function. + + Note that this is destructive in the sense that the IoBuffer instances referred to by Blobs will be + moved into our response handler array where they are kept alive, in order to reduce ref-counting storms + */ + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) = 0; + virtual void WriteResponse(HttpResponseCode ResponseCode) = 0; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload); + + void WriteResponse(HttpResponseCode ResponseCode, CbObject Data); + void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package); + void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString); + void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob); + +protected: + enum + { + kIsHandled = 1 << 0, + kSuppressBody = 1 << 1, + kHaveRequestId = 1 << 2, + kHaveSessionId = 1 << 3, + }; + + mutable uint32_t m_Flags = 0; + HttpVerb m_Verb = HttpVerb::kGet; + HttpContentType m_ContentType = HttpContentType::kBinary; + HttpContentType m_AcceptType = HttpContentType::kUnknownContentType; + uint64_t m_ContentLength = ~0ull; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; + mutable uint32_t m_RequestId = ~uint32_t(0); + mutable Oid m_SessionId = Oid::Zero; + + inline void SetIsHandled() { m_Flags |= kIsHandled; } + + virtual Oid ParseSessionId() const = 0; + virtual uint32_t ParseRequestId() const = 0; +}; + +class IHttpPackageHandler : public RefCounted +{ +public: + virtual void FilterOffer(std::vector<IoHash>& OfferCids) = 0; + virtual void OnRequestBegin() = 0; + virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) = 0; + virtual void OnRequestComplete() = 0; +}; + +/** + * Base class for implementing an HTTP "service" + * + * A service exposes one or more endpoints with a certain URI prefix + * + */ + +class HttpService +{ +public: + HttpService() = default; + virtual ~HttpService() = default; + + virtual const char* BaseUri() const = 0; + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + + // Internals + + inline void SetUriPrefixLength(size_t PrefixLength) { m_UriPrefixLength = (int)PrefixLength; } + inline int UriPrefixLength() const { return m_UriPrefixLength; } + +private: + int m_UriPrefixLength = 0; +}; + +/** HTTP server + * + * Implements the main event loop to service HTTP requests, and handles routing + * requests to the appropriate handler as registered via RegisterService + */ +class HttpServer : public RefCounted +{ +public: + virtual void RegisterService(HttpService& Service) = 0; + virtual void Initialize(int BasePort) = 0; + virtual void Run(bool TestMode) = 0; + virtual void RequestExit() = 0; +}; + +Ref<HttpServer> CreateHttpServer(); + +////////////////////////////////////////////////////////////////////////// + +class HttpRouterRequest +{ +public: + HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} + + ZENCORE_API std::string GetCapture(uint32_t Index) const; + inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } + +private: + using MatchResults_t = std::match_results<std::string_view::const_iterator>; + + HttpServerRequest& m_HttpRequest; + MatchResults_t m_Match; + + friend class HttpRequestRouter; +}; + +inline std::string +HttpRouterRequest::GetCapture(uint32_t Index) const +{ + ZEN_ASSERT(Index < m_Match.size()); + + return m_Match[Index]; +} + +/** HTTP request router helper + * + * This helper class allows a service implementer to register one or more + * endpoints using pattern matching (currently using regex matching) + * + * This is intended to be initialized once only, there is no thread + * safety so you can absolutely not add or remove endpoints once the handler + * goes live + */ + +class HttpRequestRouter +{ +public: + typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t; + + /** + * @brief Add pattern which can be referenced by name, commonly used for URL components + * @param Id String used to identify patterns for replacement + * @param Regex String which will replace the Id string in any registered URL paths + */ + void AddPattern(const char* Id, const char* Regex); + + /** + * @brief Register a an endpoint handler for the given route + * @param Regex Regular expression used to match the handler to a request. This may + * contain pattern aliases registered via AddPattern + * @param HandlerFunc Handler function to call for any matching request + * @param SupportedVerbs Supported HTTP verbs for this handler + */ + void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs); + + void ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex); + + /** + * @brief HTTP request handling function - this should be called to route the + * request to a registered handler + * @param Request Request to route to a handler + * @return Function returns true if the request was routed successfully + */ + bool HandleRequest(zen::HttpServerRequest& Request); + +private: + struct HandlerEntry + { + HandlerEntry(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) + : RegEx(Regex, std::regex::icase | std::regex::ECMAScript) + , Verbs(SupportedVerbs) + , Handler(std::move(Handler)) + , Pattern(Pattern) + { + } + + ~HandlerEntry() = default; + + std::regex RegEx; + HttpVerb Verbs; + HandlerFunc_t Handler; + const char* Pattern; + + private: + HandlerEntry& operator=(const HandlerEntry&) = delete; + HandlerEntry(const HandlerEntry&) = delete; + }; + + std::list<HandlerEntry> m_Handlers; + std::unordered_map<std::string, std::string> m_PatternMap; +}; + +} // namespace zen + +void http_forcelink(); // internal diff --git a/zenhttp/include/zenhttp/zenhttp.h b/zenhttp/include/zenhttp/zenhttp.h new file mode 100644 index 000000000..c6ec92e7c --- /dev/null +++ b/zenhttp/include/zenhttp/zenhttp.h @@ -0,0 +1,7 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#define ZENHTTP_API // Placeholder to allow DLL configs in the future diff --git a/zencore/iothreadpool.cpp b/zenhttp/iothreadpool.cpp index 4ed81d7a2..4f1a6642b 100644 --- a/zencore/iothreadpool.cpp +++ b/zenhttp/iothreadpool.cpp @@ -2,6 +2,8 @@ #include "iothreadpool.h" +#include <zencore/except.h> + namespace zen { WinIoThreadPool::WinIoThreadPool(int InThreadCount) @@ -28,9 +30,14 @@ WinIoThreadPool::~WinIoThreadPool() } void -WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context) +WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) { m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment); + + if (!m_ThreadPoolIo) + { + ErrorCode = MakeErrorCodeFromLastError(); + } } } // namespace zen diff --git a/zencore/iothreadpool.h b/zenhttp/iothreadpool.h index f64868540..4418b940b 100644 --- a/zencore/iothreadpool.h +++ b/zenhttp/iothreadpool.h @@ -4,6 +4,8 @@ #include <zencore/windows.h> +#include <system_error> + namespace zen { ////////////////////////////////////////////////////////////////////////// @@ -18,7 +20,7 @@ public: WinIoThreadPool(int InThreadCount); ~WinIoThreadPool(); - void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context); + void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode); inline PTP_IO Iocp() const { return m_ThreadPoolIo; } private: diff --git a/zenhttp/xmake.lua b/zenhttp/xmake.lua new file mode 100644 index 000000000..65d5f08ea --- /dev/null +++ b/zenhttp/xmake.lua @@ -0,0 +1,7 @@ +target('zenhttp') + set_kind("static") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore") + add_packages("vcpkg::gsl-lite", "vcpkg::uwebsockets", "vcpkg::usockets", "vcpkg::libuv") + add_options("httpsys")
\ No newline at end of file diff --git a/zenhttp/zenhttp.vcxproj b/zenhttp/zenhttp.vcxproj new file mode 100644 index 000000000..3536d1929 --- /dev/null +++ b/zenhttp/zenhttp.vcxproj @@ -0,0 +1,126 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup Label="ProjectConfigurations"> + <ProjectConfiguration Include="Debug|x64"> + <Configuration>Debug</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|x64"> + <Configuration>Release</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + </ItemGroup> + <PropertyGroup Label="Globals"> + <VCProjectVersion>16.0</VCProjectVersion> + <Keyword>Win32Proj</Keyword> + <ProjectGuid>{8eeb3be5-7001-46bf-aafd-edb7558ac012}</ProjectGuid> + <RootNamespace>zenhttp</RootNamespace> + <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v142</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v142</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" /> + <ImportGroup Label="ExtensionSettings"> + </ImportGroup> + <ImportGroup Label="Shared"> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\zen_base_debug.props" /> + <Import Project="..\zenfs_common.props" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\zen_base_release.props" /> + <Import Project="..\zenfs_common.props" /> + </ImportGroup> + <PropertyGroup Label="UserMacros" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <LinkIncremental>true</LinkIncremental> + <PublicIncludeDirectories>$(ProjectDir)include;$(PublicIncludeDirectories)</PublicIncludeDirectories> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <LinkIncremental>false</LinkIncremental> + <PublicIncludeDirectories>$(ProjectDir)include;$(PublicIncludeDirectories)</PublicIncludeDirectories> + </PropertyGroup> + <PropertyGroup Label="Vcpkg"> + <VcpkgEnableManifest>true</VcpkgEnableManifest> + </PropertyGroup> + <PropertyGroup Label="Vcpkg" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <VcpkgUseStatic>true</VcpkgUseStatic> + </PropertyGroup> + <PropertyGroup Label="Vcpkg" Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <VcpkgUseStatic>true</VcpkgUseStatic> + </PropertyGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <ClCompile> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <AdditionalIncludeDirectories>.\include</AdditionalIncludeDirectories> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <ClCompile> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <AdditionalIncludeDirectories>.\include</AdditionalIncludeDirectories> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemGroup> + <ClCompile Include="httpclient.cpp" /> + <ClCompile Include="httpnull.cpp" /> + <ClCompile Include="httpserver.cpp" /> + <ClCompile Include="httpshared.cpp" /> + <ClCompile Include="httpsys.cpp" /> + <ClCompile Include="httpuws.cpp" /> + <ClCompile Include="iothreadpool.cpp" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="httpnull.h" /> + <ClInclude Include="httpshared.h" /> + <ClInclude Include="httpsys.h" /> + <ClInclude Include="httpuws.h" /> + <ClInclude Include="include\zenhttp\httpclient.h" /> + <ClInclude Include="include\zenhttp\httpcommon.h" /> + <ClInclude Include="include\zenhttp\httpserver.h" /> + <ClInclude Include="include\zenhttp\zenhttp.h" /> + <ClInclude Include="iothreadpool.h" /> + </ItemGroup> + <ItemGroup> + <ProjectReference Include="..\zencore\zencore.vcxproj"> + <Project>{d75bf9ab-c61e-4fff-ad59-1563430f05e2}</Project> + </ProjectReference> + </ItemGroup> + <ItemGroup> + <None Include="xmake.lua" /> + </ItemGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> + <ImportGroup Label="ExtensionTargets"> + </ImportGroup> +</Project>
\ No newline at end of file diff --git a/zenhttp/zenhttp.vcxproj.filters b/zenhttp/zenhttp.vcxproj.filters new file mode 100644 index 000000000..da292c18f --- /dev/null +++ b/zenhttp/zenhttp.vcxproj.filters @@ -0,0 +1,26 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup> + <ClCompile Include="httpclient.cpp" /> + <ClCompile Include="httpserver.cpp" /> + <ClCompile Include="httpsys.cpp" /> + <ClCompile Include="iothreadpool.cpp" /> + <ClCompile Include="httpnull.cpp" /> + <ClCompile Include="httpuws.cpp" /> + <ClCompile Include="httpshared.cpp" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="include\zenhttp\httpclient.h" /> + <ClInclude Include="include\zenhttp\httpserver.h" /> + <ClInclude Include="httpsys.h" /> + <ClInclude Include="iothreadpool.h" /> + <ClInclude Include="include\zenhttp\zenhttp.h" /> + <ClInclude Include="httpnull.h" /> + <ClInclude Include="httpuws.h" /> + <ClInclude Include="httpshared.h" /> + <ClInclude Include="include\zenhttp\httpcommon.h" /> + </ItemGroup> + <ItemGroup> + <None Include="xmake.lua" /> + </ItemGroup> +</Project>
\ No newline at end of file diff --git a/zenserver-test/xmake.lua b/zenserver-test/xmake.lua index 10e287c67..0e30da12e 100644 --- a/zenserver-test/xmake.lua +++ b/zenserver-test/xmake.lua @@ -1,5 +1,5 @@ target("zenserver-test") set_kind("binary") add_files("*.cpp") - add_deps("zencore", "zenutil") + add_deps("zencore", "zenutil", "zenhttp") add_packages("vcpkg::http-parser", "vcpkg::mimalloc") diff --git a/zenserver-test/zenserver-test.cpp b/zenserver-test/zenserver-test.cpp index 455ab2495..efcbf5da8 100644 --- a/zenserver-test/zenserver-test.cpp +++ b/zenserver-test/zenserver-test.cpp @@ -5,6 +5,7 @@ #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> #include <zencore/compactbinarypackage.h> +#include <zencore/compress.h> #include <zencore/except.h> #include <zencore/filesystem.h> #include <zencore/fmtutils.h> @@ -12,6 +13,7 @@ #include <zencore/string.h> #include <zencore/thread.h> #include <zencore/timer.h> +#include <zenhttp/httpclient.h> #include <zenserverprocess.h> #include <mimalloc.h> @@ -1380,4 +1382,112 @@ TEST_CASE("mesh.basic") } } +class ZenServerTestHelper +{ +public: + ZenServerTestHelper(std::string_view HelperId, int ServerCount) : m_HelperId{HelperId}, m_ServerCount{ServerCount} {} + ~ZenServerTestHelper() {} + + void SpawnServers() + { + SpawnServers([](ZenServerInstance&) {}); + } + + void SpawnServers(auto&& Callback) + { + spdlog::info("{}: spawning {} server instances", m_HelperId, m_ServerCount); + + m_Instances.resize(m_ServerCount); + + for (int i = 0; i < m_ServerCount; ++i) + { + auto& Instance = m_Instances[i]; + + Instance = std::make_unique<ZenServerInstance>(TestEnv); + Instance->SetTestDir(TestEnv.CreateNewTestDir()); + + Callback(*Instance); + + Instance->SpawnServer(13337 + i); + } + + for (int i = 0; i < m_ServerCount; ++i) + { + auto& Instance = m_Instances[i]; + + Instance->WaitUntilReady(); + } + } + + ZenServerInstance& GetInstance(int Index) { return *m_Instances[Index]; } + +private: + std::string m_HelperId; + int m_ServerCount = 0; + std::vector<std::unique_ptr<ZenServerInstance> > m_Instances; +}; + +TEST_CASE("http.basics") +{ + using namespace std::literals; + + ZenServerTestHelper Servers{"http.basics"sv, 1}; + Servers.SpawnServers(); + + ZenServerInstance& Instance = Servers.GetInstance(0); + const std::string BaseUri = Instance.GetBaseUri(); + + { + cpr::Response r = cpr::Get(cpr::Url{"{}/testing/hello"_format(BaseUri)}); + CHECK_EQ(r.status_code, 200); + } + + { + cpr::Response r = cpr::Post(cpr::Url{"{}/testing/hello"_format(BaseUri)}); + CHECK_EQ(r.status_code, 404); + } + + { + cpr::Response r = cpr::Post(cpr::Url{"{}/testing/echo"_format(BaseUri)}, cpr::Body{"yoyoyoyo"}); + CHECK_EQ(r.status_code, 200); + CHECK_EQ(r.text, "yoyoyoyo"); + } +} + +TEST_CASE("http.package") +{ + using namespace std::literals; + + ZenServerTestHelper Servers{"http.package"sv, 1}; + Servers.SpawnServers(); + + ZenServerInstance& Instance = Servers.GetInstance(0); + const std::string BaseUri = Instance.GetBaseUri(); + + static const uint8_t Data1[] = {0, 1, 2, 3}; + static const uint8_t Data2[] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + + zen::CbAttachment Attach1{zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone({Data1, 4}), + zen::OodleCompressor::NotSet, + zen::OodleCompressionLevel::None)}; + zen::CbAttachment Attach2{zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone({Data2, 8}), + zen::OodleCompressor::NotSet, + zen::OodleCompressionLevel::None)}; + + zen::CbObjectWriter Writer; + + Writer.AddAttachment("attach1", Attach1); + Writer.AddAttachment("attach2", Attach2); + + zen::CbObject CoreObject = Writer.Save(); + + zen::CbPackage TestPackage; + TestPackage.SetObject(CoreObject); + TestPackage.AddAttachment(Attach1); + TestPackage.AddAttachment(Attach2); + + zen::HttpClient TestClient(BaseUri); + zen::HttpClient::Response Response = TestClient.TransactPackage("/testing/package"sv, TestPackage); +} + #endif diff --git a/zenserver-test/zenserver-test.vcxproj b/zenserver-test/zenserver-test.vcxproj index 54027cba3..a39fce7ec 100644 --- a/zenserver-test/zenserver-test.vcxproj +++ b/zenserver-test/zenserver-test.vcxproj @@ -97,6 +97,9 @@ <ProjectReference Include="..\zencore\zencore.vcxproj"> <Project>{d75bf9ab-c61e-4fff-ad59-1563430f05e2}</Project> </ProjectReference> + <ProjectReference Include="..\zenhttp\zenhttp.vcxproj"> + <Project>{8eeb3be5-7001-46bf-aafd-edb7558ac012}</Project> + </ProjectReference> <ProjectReference Include="..\zenutil\zenutil.vcxproj"> <Project>{77f8315d-b21d-4db0-9a6f-2d3359f88a70}</Project> </ProjectReference> diff --git a/zenserver/admin/admin.h b/zenserver/admin/admin.h index 3bb8a9158..f90ad4537 100644 --- a/zenserver/admin/admin.h +++ b/zenserver/admin/admin.h @@ -2,7 +2,7 @@ #pragma once -#include <zencore/httpserver.h> +#include <zenhttp/httpserver.h> class HttpAdminService : public zen::HttpService { diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp index 0e235a9be..b2f8d191c 100644 --- a/zenserver/cache/structuredcache.cpp +++ b/zenserver/cache/structuredcache.cpp @@ -4,8 +4,8 @@ #include <zencore/compactbinaryvalidation.h> #include <zencore/compress.h> #include <zencore/fmtutils.h> -#include <zencore/httpserver.h> #include <zencore/timer.h> +#include <zenhttp/httpserver.h> #include "structuredcache.h" #include "structuredcachestore.h" @@ -25,27 +25,6 @@ namespace zen { using namespace std::literals; -zen::HttpContentType -MapToHttpContentType(zen::ZenContentType Type) -{ - switch (Type) - { - default: - case zen::ZenContentType::kBinary: - return zen::HttpContentType::kBinary; - case zen::ZenContentType::kCbObject: - return zen::HttpContentType::kCbObject; - case zen::ZenContentType::kCbPackage: - return zen::HttpContentType::kCbPackage; - case zen::ZenContentType::kText: - return zen::HttpContentType::kText; - case zen::ZenContentType::kJSON: - return zen::HttpContentType::kJSON; - case zen::ZenContentType::kYAML: - return zen::HttpContentType::kYAML; - } -}; - ////////////////////////////////////////////////////////////////////////// HttpStructuredCacheService::HttpStructuredCacheService(::ZenCacheStore& InCacheStore, @@ -93,7 +72,7 @@ HttpStructuredCacheService::HandleRequest(zen::HttpServerRequest& Request) return HandleCacheBucketRequest(Request, Key); } - return Request.WriteResponse(zen::HttpResponse::BadRequest); // invalid URL + return Request.WriteResponse(zen::HttpResponseCode::BadRequest); // invalid URL } if (Ref.PayloadId == IoHash::Zero) @@ -128,11 +107,11 @@ HttpStructuredCacheService::HandleCacheBucketRequest(zen::HttpServerRequest& Req if (m_CacheStore.DropBucket(Bucket)) { - return Request.WriteResponse(zen::HttpResponse::OK); + return Request.WriteResponse(zen::HttpResponseCode::OK); } else { - return Request.WriteResponse(zen::HttpResponse::NotFound); + return Request.WriteResponse(zen::HttpResponseCode::NotFound); } break; } @@ -209,7 +188,7 @@ HttpStructuredCacheService::HandleCacheRecordRequest(zen::HttpServerRequest& Req { m_Log.debug("MISS - '{}/{}'", Ref.BucketSegment, Ref.HashKey); - return Request.WriteResponse(zen::HttpResponse::NotFound); + return Request.WriteResponse(zen::HttpResponseCode::NotFound); } if (Verb == kHead) @@ -224,7 +203,7 @@ HttpStructuredCacheService::HandleCacheRecordRequest(zen::HttpServerRequest& Req Value.Value.GetContentType(), InUpstreamCache ? "upstream" : "local"); - return Request.WriteResponse(zen::HttpResponse::OK, MapToHttpContentType(Value.Value.GetContentType()), Value.Value); + return Request.WriteResponse(zen::HttpResponseCode::OK, Value.Value.GetContentType(), Value.Value); } break; @@ -234,7 +213,7 @@ HttpStructuredCacheService::HandleCacheRecordRequest(zen::HttpServerRequest& Req if (!Body || Body.Size() == 0) { - return Request.WriteResponse(zen::HttpResponse::BadRequest); + return Request.WriteResponse(zen::HttpResponseCode::BadRequest); } const HttpContentType ContentType = Request.RequestContentType(); @@ -253,7 +232,7 @@ HttpStructuredCacheService::HandleCacheRecordRequest(zen::HttpServerRequest& Req break; default: - return Request.WriteResponse(zen::HttpResponse::BadRequest); + return Request.WriteResponse(zen::HttpResponseCode::BadRequest); } if (!IsCompactBinary) @@ -272,7 +251,7 @@ HttpStructuredCacheService::HandleCacheRecordRequest(zen::HttpServerRequest& Req {.Type = ZenContentType::kBinary, .CacheKey = {Ref.BucketSegment, Ref.HashKey}}); } - return Request.WriteResponse(zen::HttpResponse::Created); + return Request.WriteResponse(zen::HttpResponseCode::Created); } // Validate payload before accessing it @@ -284,7 +263,9 @@ HttpStructuredCacheService::HandleCacheRecordRequest(zen::HttpServerRequest& Req m_Log.warn("Payload for key '{}/{}' ({} bytes) failed validation", Ref.BucketSegment, Ref.HashKey, Body.Size()); // TODO: add details in response, kText || kCbObject? - return Request.WriteResponse(HttpResponse::BadRequest, HttpContentType::kText, "Compact binary validation failed"sv); + return Request.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Compact binary validation failed"sv); } // Extract referenced payload hashes @@ -336,7 +317,7 @@ HttpStructuredCacheService::HandleCacheRecordRequest(zen::HttpServerRequest& Req .PayloadIds = std::move(References)}); } - return Request.WriteResponse(zen::HttpResponse::Created); + return Request.WriteResponse(zen::HttpResponseCode::Created); } else { @@ -351,7 +332,7 @@ HttpStructuredCacheService::HandleCacheRecordRequest(zen::HttpServerRequest& Req Response.EndArray(); // Return Created | BadRequest? - return Request.WriteResponse(zen::HttpResponse::Created, Response.Save()); + return Request.WriteResponse(zen::HttpResponseCode::Created, Response.Save()); } } break; @@ -407,7 +388,7 @@ HttpStructuredCacheService::HandleCachePayloadRequest(zen::HttpServerRequest& Re if (!Payload) { m_Log.debug("MISS - '{}/{}/{}'", Ref.BucketSegment, Ref.HashKey, Ref.PayloadId); - return Request.WriteResponse(zen::HttpResponse::NotFound); + return Request.WriteResponse(zen::HttpResponseCode::NotFound); } m_Log.debug("HIT - '{}/{}/{}' ({} bytes, {}) ({})", @@ -423,7 +404,7 @@ HttpStructuredCacheService::HandleCachePayloadRequest(zen::HttpServerRequest& Re Request.SetSuppressResponseBody(); } - return Request.WriteResponse(zen::HttpResponse::OK, zen::HttpContentType::kBinary, Payload); + return Request.WriteResponse(zen::HttpResponseCode::OK, zen::HttpContentType::kBinary, Payload); } break; @@ -433,7 +414,9 @@ HttpStructuredCacheService::HandleCachePayloadRequest(zen::HttpServerRequest& Re { if (Body.Size() == 0) { - return Request.WriteResponse(zen::HttpResponse::BadRequest, HttpContentType::kText, "Empty payload not permitted"); + return Request.WriteResponse(zen::HttpResponseCode::BadRequest, + HttpContentType::kText, + "Empty payload not permitted"); } zen::IoHash ChunkHash = zen::IoHash::HashBuffer(Body); @@ -443,7 +426,7 @@ HttpStructuredCacheService::HandleCachePayloadRequest(zen::HttpServerRequest& Re if (!Compressed) { // All attachment payloads need to be in compressed buffer format - return Request.WriteResponse(zen::HttpResponse::BadRequest, + return Request.WriteResponse(zen::HttpResponseCode::BadRequest, HttpContentType::kText, "Attachments must be compressed"); } @@ -452,7 +435,7 @@ HttpStructuredCacheService::HandleCachePayloadRequest(zen::HttpServerRequest& Re if (IoHash::FromBLAKE3(Compressed.GetRawHash()) != Ref.PayloadId) { // the URL specified content id and content hashes don't match! - return Request.WriteResponse(HttpResponse::BadRequest); + return Request.WriteResponse(HttpResponseCode::BadRequest); } zen::CasStore::InsertResult Result = m_CasStore.InsertChunk(Body, ChunkHash); @@ -469,11 +452,11 @@ HttpStructuredCacheService::HandleCachePayloadRequest(zen::HttpServerRequest& Re if (Result.New) { - return Request.WriteResponse(zen::HttpResponse::Created); + return Request.WriteResponse(zen::HttpResponseCode::Created); } else { - return Request.WriteResponse(zen::HttpResponse::OK); + return Request.WriteResponse(zen::HttpResponseCode::OK); } } } diff --git a/zenserver/cache/structuredcache.h b/zenserver/cache/structuredcache.h index b90301d84..d4bb94c52 100644 --- a/zenserver/cache/structuredcache.h +++ b/zenserver/cache/structuredcache.h @@ -2,7 +2,7 @@ #pragma once -#include <zencore/httpserver.h> +#include <zenhttp/httpserver.h> #include <spdlog/spdlog.h> #include <memory> diff --git a/zenserver/casstore.cpp b/zenserver/casstore.cpp index 1d147024a..6f1e4873b 100644 --- a/zenserver/casstore.cpp +++ b/zenserver/casstore.cpp @@ -23,7 +23,7 @@ HttpCasService::HttpCasService(CasStore& Store) : m_CasStore(Store) if ((EntryCount * sizeof(IoHash)) != Payload.Size()) { - return ServerRequest.WriteResponse(HttpResponse::BadRequest); + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest); } const IoHash* Hashes = reinterpret_cast<const IoHash*>(Payload.Data()); @@ -55,7 +55,7 @@ HttpCasService::HttpCasService(CasStore& Store) : m_CasStore(Store) Values[0] = IoBufferBuilder::MakeCloneFromMemory(HeaderStream.Data(), HeaderStream.Size()); - ServerRequest.WriteResponse(HttpResponse::OK, HttpContentType::kBinary, Values); + ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Values); }, HttpVerb::kPost); @@ -74,10 +74,10 @@ HttpCasService::HttpCasService(CasStore& Store) : m_CasStore(Store) { if (IoBuffer Value = m_CasStore.FindChunk(Hash)) { - return ServerRequest.WriteResponse(HttpResponse::OK, HttpContentType::kBinary, Value); + return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value); } - return ServerRequest.WriteResponse(HttpResponse::NotFound); + return ServerRequest.WriteResponse(HttpResponseCode::NotFound); } break; @@ -89,12 +89,12 @@ HttpCasService::HttpCasService(CasStore& Store) : m_CasStore(Store) // URI hash must match content hash if (PayloadHash != Hash) { - return ServerRequest.WriteResponse(HttpResponse::BadRequest); + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest); } m_CasStore.InsertChunk(Payload, PayloadHash); - return ServerRequest.WriteResponse(HttpResponse::OK); + return ServerRequest.WriteResponse(HttpResponseCode::OK); } break; } @@ -129,11 +129,11 @@ HttpCasService::HandleRequest(zen::HttpServerRequest& Request) if (InsertResult.New) { - return Request.WriteResponse(HttpResponse::Created); + return Request.WriteResponse(HttpResponseCode::Created); } else { - return Request.WriteResponse(HttpResponse::OK); + return Request.WriteResponse(HttpResponseCode::OK); } } break; diff --git a/zenserver/casstore.h b/zenserver/casstore.h index 7166f796e..4ca6908b5 100644 --- a/zenserver/casstore.h +++ b/zenserver/casstore.h @@ -2,7 +2,7 @@ #pragma once -#include <zencore/httpserver.h> +#include <zenhttp/httpserver.h> #include <zenstore/cas.h> namespace zen { diff --git a/zenserver/compute/apply.cpp b/zenserver/compute/apply.cpp index 939ac3362..94dedf087 100644 --- a/zenserver/compute/apply.cpp +++ b/zenserver/compute/apply.cpp @@ -351,12 +351,12 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end()) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } else { const WorkerDesc& Desc = It->second; - return HttpReq.WriteResponse(HttpResponse::OK, Desc.Descriptor); + return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor); } } break; @@ -378,6 +378,10 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, ChunkSet.AddChunk(Hash); }); + // Note that we store executables uncompressed to make it + // more straightforward and efficient to materialize them, hence + // the CAS lookup here instead of CID for the input payloads + m_CasStore.FilterChunks(ChunkSet); if (ChunkSet.IsEmpty()) @@ -388,7 +392,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, spdlog::debug("worker {}: all attachments already available", WorkerId); - return HttpReq.WriteResponse(HttpResponse::NoContent); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); } else { @@ -406,7 +410,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, spdlog::debug("worker {}: need {} attachments", WorkerId, ChunkSet.GetChunkSet().size()); - return HttpReq.WriteResponse(HttpResponse::NotFound, ResponseWriter.Save()); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); } } break; @@ -426,21 +430,25 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, for (const CbAttachment& Attachment : Attachments) { - ZEN_ASSERT(Attachment.IsBinary()); + ZEN_ASSERT(Attachment.IsCompressedBinary()); - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + SharedBuffer Decompressed = DataView.Decompress(); + const uint64_t DecompressedSize = DataView.GetRawSize(); - TotalAttachmentBytes += DataView.GetCompressedSize(); + TotalAttachmentBytes += DecompressedSize; ++AttachmentCount; - IoBuffer Payload = DataView.GetCompressed().Flatten().AsIoBuffer(); + // Note that we store executables uncompressed to make it + // more straightforward and efficient to materialize them - CasStore::InsertResult InsertResult = m_CasStore.InsertChunk(Payload, DataHash); + const CasStore::InsertResult InsertResult = + m_CasStore.InsertChunk(Decompressed.AsIoBuffer(), IoHash::FromBLAKE3(DataView.GetRawHash())); if (InsertResult.New) { - TotalNewBytes += Payload.Size(); + TotalNewBytes += DecompressedSize; ++NewAttachmentCount; } } @@ -456,7 +464,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, m_WorkerMap.insert_or_assign(WorkerId, WorkerDesc{.Descriptor = Obj}); - return HttpReq.WriteResponse(HttpResponse::NoContent); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); } break; } @@ -495,7 +503,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end()) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } else { @@ -526,7 +534,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, RequestObject.IterateAttachments([&](CbFieldView Field) { const IoHash FileHash = Field.AsHash(); - if (!m_CasStore.FindChunk(FileHash)) + if (!m_CidStore.ContainsChunk(FileHash)) { NeedList.push_back(FileHash); } @@ -538,7 +546,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, CbPackage Output = ExecAction(Worker, RequestObject); - return HttpReq.WriteResponse(HttpResponse::OK, Output); + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); } CbObjectWriter Cbo; @@ -552,7 +560,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, Cbo.EndArray(); CbObject Response = Cbo.Save(); - return HttpReq.WriteResponse(HttpResponse::NotFound, Response); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); } break; @@ -570,19 +578,21 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, for (const CbAttachment& Attachment : Attachments) { - ZEN_ASSERT(Attachment.IsBinary()); + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); - const IoHash DataHash = Attachment.GetHash(); - SharedBuffer DataView = Attachment.AsBinary(); + const uint64_t CompressedSize = DataView.GetCompressedSize(); - TotalAttachmentBytes += DataView.GetSize(); + TotalAttachmentBytes += CompressedSize; ++AttachmentCount; - CasStore::InsertResult InsertResult = m_CasStore.InsertChunk(DataView.AsIoBuffer(), DataHash); + const CasStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView); if (InsertResult.New) { - TotalNewBytes += DataView.GetSize(); + TotalNewBytes += CompressedSize; ++NewAttachmentCount; } } @@ -595,7 +605,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, CbPackage Output = ExecAction(Worker, ActionObj); - return HttpReq.WriteResponse(HttpResponse::OK, Output); + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); } break; } @@ -659,7 +669,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, if (!AllOk) { // TODO: Could report all the missing pieces in the response here - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } std::string Executable8{RequestObject["cmd"].AsString()}; @@ -681,7 +691,7 @@ HttpFunctionService::HttpFunctionService(CasStore& Store, CidStore& InCidStore, Response << "exitcode" << Job.ExitCode(); - return HttpReq.WriteResponse(HttpResponse::OK, Response.Save()); + return HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save()); } break; } @@ -780,18 +790,16 @@ HttpFunctionService::ExecAction(const WorkerDesc& Worker, CbObject Action) // Manifest inputs in sandbox Action.IterateAttachments([&](CbFieldView Field) { - const IoHash Hash = Field.AsHash(); - std::filesystem::path FilePath{SandboxPath / "Inputs" / Hash.ToHexString()}; - IoBuffer DataBuffer = m_CasStore.FindChunk(Hash); + const IoHash Cid = Field.AsHash(); + std::filesystem::path FilePath{SandboxPath / "Inputs" / Cid.ToHexString()}; + IoBuffer DataBuffer = m_CidStore.FindChunkByCid(Cid); if (!DataBuffer) { throw std::exception("Chunk missing" /* ADD CONTEXT */); } - CompressedBuffer Buffer = CompressedBuffer::Compress(SharedBuffer(std::move(DataBuffer))); - - zen::WriteFile(FilePath, Buffer.GetCompressed().Flatten().AsIoBuffer()); + zen::WriteFile(FilePath, DataBuffer); }); // Set up environment variables @@ -884,7 +892,7 @@ HttpFunctionService::ExecAction(const WorkerDesc& Worker, CbObject Action) ZEN_ASSERT(OutputData.Data.size() == 1); - CbAttachment Attachment(SharedBuffer(ChunkData.Data[0]), Hash); + CbAttachment Attachment(CompressedBuffer::FromCompressed(SharedBuffer(ChunkData.Data[0]))); OutputPackage.AddAttachment(Attachment); }); diff --git a/zenserver/compute/apply.h b/zenserver/compute/apply.h index 20a58507b..695dc2e6e 100644 --- a/zenserver/compute/apply.h +++ b/zenserver/compute/apply.h @@ -3,9 +3,9 @@ #pragma once #include <zencore/compactbinary.h> -#include <zencore/httpserver.h> #include <zencore/iohash.h> #include <zencore/logging.h> +#include <zenhttp/httpserver.h> #include <filesystem> #include <unordered_map> diff --git a/zenserver/diag/diagsvcs.h b/zenserver/diag/diagsvcs.h index 84f8d22ee..51ee98f67 100644 --- a/zenserver/diag/diagsvcs.h +++ b/zenserver/diag/diagsvcs.h @@ -2,8 +2,8 @@ #pragma once -#include <zencore/httpserver.h> #include <zencore/iobuffer.h> +#include <zenhttp/httpserver.h> ////////////////////////////////////////////////////////////////////////// @@ -25,17 +25,17 @@ public: if (Uri == "hello"sv) { - Request.WriteResponse(zen::HttpResponse::OK, zen::HttpContentType::kText, u8"hello world!"sv); + Request.WriteResponse(zen::HttpResponseCode::OK, zen::HttpContentType::kText, u8"hello world!"sv); // OutputLogMessageInternal(&LogPoint, 0, 0); } else if (Uri == "1K"sv) { - Request.WriteResponse(zen::HttpResponse::OK, zen::HttpContentType::kBinary, m_1k); + Request.WriteResponse(zen::HttpResponseCode::OK, zen::HttpContentType::kBinary, m_1k); } else if (Uri == "1M"sv) { - Request.WriteResponse(zen::HttpResponse::OK, zen::HttpContentType::kBinary, m_1m); + Request.WriteResponse(zen::HttpResponseCode::OK, zen::HttpContentType::kBinary, m_1m); } else if (Uri == "1M_1k"sv) { @@ -47,7 +47,7 @@ public: Buffers.push_back(m_1k); } - Request.WriteResponse(zen::HttpResponse::OK, zen::HttpContentType::kBinary, Buffers); + Request.WriteResponse(zen::HttpResponseCode::OK, zen::HttpContentType::kBinary, Buffers); } else if (Uri == "1G"sv) { @@ -59,7 +59,7 @@ public: Buffers.push_back(m_1m); } - Request.WriteResponse(zen::HttpResponse::OK, zen::HttpContentType::kBinary, Buffers); + Request.WriteResponse(zen::HttpResponseCode::OK, zen::HttpContentType::kBinary, Buffers); } else if (Uri == "1G_1k"sv) { @@ -71,7 +71,7 @@ public: Buffers.push_back(m_1k); } - Request.WriteResponse(zen::HttpResponse::OK, zen::HttpContentType::kBinary, Buffers); + Request.WriteResponse(zen::HttpResponseCode::OK, zen::HttpContentType::kBinary, Buffers); } } @@ -95,7 +95,7 @@ public: switch (Request.RequestVerb()) { case zen::HttpVerb::kGet: - return Request.WriteResponse(zen::HttpResponse::OK, zen::HttpContentType::kText, u8"OK!"sv); + return Request.WriteResponse(zen::HttpResponseCode::OK, zen::HttpContentType::kText, u8"OK!"sv); } } diff --git a/zenserver/diag/logging.cpp b/zenserver/diag/logging.cpp index 796a15d01..5782ce582 100644 --- a/zenserver/diag/logging.cpp +++ b/zenserver/diag/logging.cpp @@ -195,13 +195,13 @@ InitializeLogging(const ZenServerOptions& GlobalOptions) std::filesystem::path LogPath = GlobalOptions.DataDir / "logs/zenserver.log"; - bool IsAsync = true; + bool IsAsync = true; spdlog::level::level_enum LogLevel = spdlog::level::info; if (GlobalOptions.IsDebug) { LogLevel = spdlog::level::debug; - IsAsync = false; + IsAsync = false; } if (IsAsync) diff --git a/zenserver/projectstore.cpp b/zenserver/projectstore.cpp index 006796b28..2bbc1dce3 100644 --- a/zenserver/projectstore.cpp +++ b/zenserver/projectstore.cpp @@ -542,8 +542,30 @@ ProjectStore::Oplog::AppendNewOplogEntry(CbPackage OpPackage) for (const auto& Attach : Attachments) { - IoBuffer AttachmentData = Attach.AsBinary().AsIoBuffer(); - CasStore::InsertResult Result = m_CasStore.InsertChunk(AttachmentData, Attach.GetHash()); + IoBuffer AttachmentData; + + if (Attach.IsBinary()) + { + AttachmentData = Attach.AsBinary().AsIoBuffer(); + } + else if (Attach.IsCompressedBinary()) + { + ZEN_NOT_IMPLEMENTED("Compressed binary attachments are currently not supported for oplogs"); + + AttachmentData = Attach.AsCompressedBinary().GetCompressed().Flatten().AsIoBuffer(); + } + else if (Attach.IsObject()) + { + AttachmentData = Attach.AsObject().GetBuffer().AsIoBuffer(); + } + else + { + ZEN_NOT_IMPLEMENTED("Unknown attachment type"); + } + + ZEN_ASSERT(AttachmentData); + + CasStore::InsertResult Result = m_CasStore.InsertChunk(AttachmentData, Attach.GetHash()); const uint64_t AttachmentSize = AttachmentData.Size(); @@ -909,7 +931,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (FoundLog == nullptr) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } // Parse Request @@ -940,7 +962,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (Payload.Size() <= sizeof(RequestHeader)) { - HttpReq.WriteResponse(HttpResponse::BadRequest); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); } RequestHeader RequestHdr; @@ -948,7 +970,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (RequestHdr.Magic != RequestHeader::kMagic) { - HttpReq.WriteResponse(HttpResponse::BadRequest); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); } std::vector<RequestChunkEntry> RequestedChunks; @@ -1004,9 +1026,9 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) ResponsePtr += sizeof(ResponseHdr); for (uint32_t ChunkIndex = 0; ChunkIndex < RequestHdr.ChunkCount; ++ChunkIndex) { - //const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex]; - const IoBuffer& FoundChunk(OutBlobs[ChunkIndex + 1]); - ResponseChunkEntry ResponseChunk; + // const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex]; + const IoBuffer& FoundChunk(OutBlobs[ChunkIndex + 1]); + ResponseChunkEntry ResponseChunk; ResponseChunk.CorrelationId = ChunkIndex; if (FoundChunk) { @@ -1020,7 +1042,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) ResponsePtr += sizeof(ResponseChunk); } - return HttpReq.WriteResponse(HttpResponse::OK, HttpContentType::kBinary, OutBlobs); + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, OutBlobs); }, HttpVerb::kPost); @@ -1038,7 +1060,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (FoundLog == nullptr) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } ProjectStore::Oplog& Log = *FoundLog; @@ -1063,7 +1085,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) Response.EndArray(); - return HttpReq.WriteResponse(HttpResponse::OK, Response.Save()); + return HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save()); }, HttpVerb::kGet); @@ -1080,7 +1102,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (FoundLog == nullptr) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } ProjectStore::Oplog& Log = *FoundLog; @@ -1093,10 +1115,10 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) { CbObjectWriter Response; Response << "size" << Value.Size(); - return HttpReq.WriteResponse(HttpResponse::OK, Response.Save()); + return HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save()); } - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); }, HttpVerb::kGet); @@ -1124,7 +1146,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) } else { - return HttpReq.WriteResponse(HttpResponse::BadRequest); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } } @@ -1137,7 +1159,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) } else { - return HttpReq.WriteResponse(HttpResponse::BadRequest); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } } @@ -1147,7 +1169,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (FoundLog == nullptr) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } ProjectStore::Oplog& Log = *FoundLog; @@ -1162,7 +1184,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) case HttpVerb::kGet: if (!Value) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } if (Verb == HttpVerb::kHead) @@ -1185,10 +1207,10 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) // Send only a subset of data IoBuffer InnerValue(Value, Offset, Size); - return HttpReq.WriteResponse(HttpResponse::OK, HttpContentType::kBinary, InnerValue); + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, InnerValue); } - return HttpReq.WriteResponse(HttpResponse::OK, HttpContentType::kBinary, Value); + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value); } }, HttpVerb::kGet | HttpVerb::kHead); @@ -1217,7 +1239,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) } else { - return HttpReq.WriteResponse(HttpResponse::BadRequest); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } } @@ -1230,7 +1252,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) } else { - return HttpReq.WriteResponse(HttpResponse::BadRequest); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } } @@ -1240,7 +1262,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) IoBuffer Value = m_CasStore.FindChunk(Hash); if (!Value) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } if (IsOffset) @@ -1258,10 +1280,10 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) // Send only a subset of data IoBuffer InnerValue(Value, Offset, Size); - return HttpReq.WriteResponse(HttpResponse::OK, HttpContentType::kBinary, InnerValue); + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, InnerValue); } - return HttpReq.WriteResponse(HttpResponse::OK, HttpContentType::kBinary, Value); + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value); }, HttpVerb::kGet); @@ -1277,7 +1299,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (FoundLog == nullptr) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } // This operation takes a list of referenced hashes and decides which @@ -1312,7 +1334,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) Cbo.EndArray(); CbObject Response = Cbo.Save(); - return HttpReq.WriteResponse(HttpResponse::OK, Response); + return HttpReq.WriteResponse(HttpResponseCode::OK, Response); }, HttpVerb::kPost); @@ -1340,7 +1362,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (FoundLog == nullptr) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } ProjectStore::Oplog& Log = *FoundLog; @@ -1387,23 +1409,26 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) }; CbPackage Package; - if (!Package.TryLoad(Payload, &UniqueBuffer::Alloc, &Resolver)) + + if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver)) { - return HttpReq.WriteResponse(HttpResponse::BadRequest, HttpContentType::kText, "Invalid package"); + m_Log.error("Received malformed package!"); + + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package"); } if (!IsValid) { // TODO: emit diagnostics identifying missing chunks - return HttpReq.WriteResponse(HttpResponse::NotFound, HttpContentType::kText, "Missing chunk reference"); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Missing chunk reference"); } CbObject Core = Package.GetObject(); if (!Core["key"sv]) { - return HttpReq.WriteResponse(HttpResponse::BadRequest, HttpContentType::kText, "No oplog entry key specified"); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "No oplog entry key specified"); } // Write core to oplog @@ -1412,12 +1437,12 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (OpLsn == ProjectStore::Oplog::kInvalidOp) { - return HttpReq.WriteResponse(HttpResponse::BadRequest); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } m_Log.info("new op #{:4} - {}/{} ({:>6}) {}", OpLsn, ProjectId, OplogId, NiceBytes(Payload.Size()), Core["key"sv].AsString()); - HttpReq.WriteResponse(HttpResponse::Created); + HttpReq.WriteResponse(HttpResponseCode::Created); }, HttpVerb::kPost); @@ -1428,7 +1453,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) // TODO: look up op and respond with the payload! - HttpReq.WriteResponse(HttpResponse::Accepted, HttpContentType::kText, u8"yeee"sv); + HttpReq.WriteResponse(HttpResponseCode::Accepted, HttpContentType::kText, u8"yeee"sv); }, HttpVerb::kGet); @@ -1444,7 +1469,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (!ProjectIt) { - return Req.ServerRequest().WriteResponse(HttpResponse::NotFound, + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "project {} not found"_format(ProjectId)); } @@ -1459,7 +1484,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (!OplogIt) { - return Req.ServerRequest().WriteResponse(HttpResponse::NotFound, + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "oplog {} not found in project {}"_format(OplogId, ProjectId)); } @@ -1469,7 +1494,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) CbObjectWriter Cb; Cb << "id"sv << Log.OplogId() << "project"sv << Prj.Identifier << "tempdir"sv << Log.TempDir(); - Req.ServerRequest().WriteResponse(HttpResponse::OK, Cb.Save()); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cb.Save()); } break; @@ -1482,18 +1507,18 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (!Prj.NewOplog(OplogId)) { // TODO: indicate why the operation failed! - return Req.ServerRequest().WriteResponse(HttpResponse::InternalServerError); + return Req.ServerRequest().WriteResponse(HttpResponseCode::InternalServerError); } m_Log.info("established oplog {} / {}", ProjectId, OplogId); - return Req.ServerRequest().WriteResponse(HttpResponse::Created); + return Req.ServerRequest().WriteResponse(HttpResponseCode::Created); } // I guess this should ultimately be used to execute RPCs but for now, it // does absolutely nothing - return Req.ServerRequest().WriteResponse(HttpResponse::BadRequest); + return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); } break; @@ -1503,7 +1528,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) ProjectIt->DeleteOplog(OplogId); - return Req.ServerRequest().WriteResponse(HttpResponse::OK); + return Req.ServerRequest().WriteResponse(HttpResponseCode::OK); } break; } @@ -1522,7 +1547,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (FoundLog == nullptr) { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } CbObjectWriter Response; @@ -1542,7 +1567,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) } else { - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } } else @@ -1555,7 +1580,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) } } - return HttpReq.WriteResponse(HttpResponse::OK, Response.Save()); + return HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save()); }, HttpVerb::kGet); @@ -1585,7 +1610,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) EngineRoot, ProjectRoot); - Req.ServerRequest().WriteResponse(HttpResponse::Created); + Req.ServerRequest().WriteResponse(HttpResponseCode::Created); } break; @@ -1595,7 +1620,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (!ProjectIt) { - return Req.ServerRequest().WriteResponse(HttpResponse::NotFound, + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "project {} not found"_format(ProjectId)); } @@ -1609,7 +1634,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) Prj.IterateOplogs([&](const ProjectStore::Oplog& I) { Response << "id"sv << I.OplogId(); }); Response.EndArray(); // oplogs - Req.ServerRequest().WriteResponse(HttpResponse::OK, Response.Save()); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Response.Save()); } break; @@ -1619,7 +1644,7 @@ HttpProjectService::HttpProjectService(CasStore& Store, ProjectStore* Projects) if (!ProjectIt) { - return Req.ServerRequest().WriteResponse(HttpResponse::NotFound, + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "project {} not found"_format(ProjectId)); } diff --git a/zenserver/projectstore.h b/zenserver/projectstore.h index 33bacfdaa..8fe189ab9 100644 --- a/zenserver/projectstore.h +++ b/zenserver/projectstore.h @@ -2,9 +2,9 @@ #pragma once -#include <zencore/httpserver.h> #include <zencore/uid.h> #include <zencore/xxhash.h> +#include <zenhttp/httpserver.h> #include <zenstore/cas.h> #include <zenstore/caslog.h> diff --git a/zenserver/sos/sos.h b/zenserver/sos/sos.h index 283735cbd..da9064262 100644 --- a/zenserver/sos/sos.h +++ b/zenserver/sos/sos.h @@ -2,7 +2,7 @@ #pragma once -#include <zencore/httpserver.h> +#include <zenhttp/httpserver.h> #include <spdlog/spdlog.h> diff --git a/zenserver/testing/httptest.cpp b/zenserver/testing/httptest.cpp new file mode 100644 index 000000000..c4fd6003c --- /dev/null +++ b/zenserver/testing/httptest.cpp @@ -0,0 +1,106 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httptest.h" + +#include <zencore/compactbinarypackage.h> + +namespace zen { + +HttpTestingService::HttpTestingService() +{ + m_Router.RegisterRoute( + "hello", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [this](HttpRouterRequest& Req) { + IoBuffer Body = Req.ServerRequest().ReadPayload(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Body); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "package", + [this](HttpRouterRequest& Req) { + CbPackage Pkg = Req.ServerRequest().ReadPayloadPackage(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Pkg); + }, + HttpVerb::kPost); +} + +HttpTestingService::~HttpTestingService() +{ +} + +const char* +HttpTestingService::BaseUri() const +{ + return "/testing/"; +} + +void +HttpTestingService::HandleRequest(HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +Ref<IHttpPackageHandler> +HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) +{ + RwLock::ExclusiveLockScope _(m_RwLock); + + const uint32_t RequestId = HttpServiceRequest.RequestId(); + + if (auto It = m_HandlerMap.find(RequestId); It != m_HandlerMap.end()) + { + Ref<HttpTestingService::PackageHandler> Handler = std::move(It->second); + + m_HandlerMap.erase(It); + + return Handler.Get(); + } + + auto InsertResult = m_HandlerMap.insert({RequestId, nullptr}); + + _.ReleaseNow(); + + return (InsertResult.first->second = new PackageHandler(*this, RequestId)).Get(); +} + +////////////////////////////////////////////////////////////////////////// + +HttpTestingService::PackageHandler::PackageHandler(HttpTestingService& Svc, uint32_t RequestId) : m_Svc(Svc), m_RequestId(RequestId) +{ +} + +HttpTestingService::PackageHandler::~PackageHandler() +{ +} + +void +HttpTestingService::PackageHandler::FilterOffer(std::vector<IoHash>& OfferCids) +{ + ZEN_UNUSED(OfferCids); + // No-op + return; +} +void +HttpTestingService::PackageHandler::OnRequestBegin() +{ +} + +void +HttpTestingService::PackageHandler::OnRequestComplete() +{ +} + +IoBuffer +HttpTestingService::PackageHandler::CreateTarget(const IoHash& Cid, uint64_t StorageSize) +{ + ZEN_UNUSED(Cid); + return IoBuffer{StorageSize}; +} + +} // namespace zen diff --git a/zenserver/testing/httptest.h b/zenserver/testing/httptest.h new file mode 100644 index 000000000..5809d4e2e --- /dev/null +++ b/zenserver/testing/httptest.h @@ -0,0 +1,47 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +#include <spdlog/spdlog.h> + +namespace zen { + +/** + * Test service to facilitate testing the HTTP framework and client interactions + */ +class HttpTestingService : public HttpService +{ +public: + HttpTestingService(); + ~HttpTestingService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest) override; + + class PackageHandler : public IHttpPackageHandler + { + public: + PackageHandler(HttpTestingService& Svc, uint32_t RequestId); + ~PackageHandler(); + + virtual void FilterOffer(std::vector<IoHash>& OfferCids) override; + virtual void OnRequestBegin() override; + virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) override; + virtual void OnRequestComplete() override; + + private: + HttpTestingService& m_Svc; + uint32_t m_RequestId; + }; + +private: + HttpRequestRouter m_Router; + + RwLock m_RwLock; + std::unordered_map<uint32_t, Ref<PackageHandler>> m_HandlerMap; +}; + +} // namespace zen diff --git a/zenserver/testing/launch.cpp b/zenserver/testing/launch.cpp index d06fae3e2..b031193d5 100644 --- a/zenserver/testing/launch.cpp +++ b/zenserver/testing/launch.cpp @@ -409,7 +409,7 @@ HttpLaunchService::HttpLaunchService(CasStore& Store, const std::filesystem::pat Cbo.EndArray(); CbObject Response = Cbo.Save(); - return HttpReq.WriteResponse(HttpResponse::OK, Response); + return HttpReq.WriteResponse(HttpResponseCode::OK, Response); } break; } @@ -470,7 +470,7 @@ HttpLaunchService::HttpLaunchService(CasStore& Store, const std::filesystem::pat if (!AllOk) { // TODO: Could report all the missing pieces in the response here - return HttpReq.WriteResponse(HttpResponse::NotFound); + return HttpReq.WriteResponse(HttpResponseCode::NotFound); } std::string Executable8{RequestObject["cmd"].AsString()}; @@ -492,7 +492,7 @@ HttpLaunchService::HttpLaunchService(CasStore& Store, const std::filesystem::pat Response << "exitcode" << Job.ExitCode(); - return HttpReq.WriteResponse(HttpResponse::OK, Response.Save()); + return HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save()); } break; } diff --git a/zenserver/testing/launch.h b/zenserver/testing/launch.h index 00f322624..a6eb137d2 100644 --- a/zenserver/testing/launch.h +++ b/zenserver/testing/launch.h @@ -2,7 +2,7 @@ #pragma once -#include <zencore/httpserver.h> +#include <zenhttp/httpserver.h> #include <spdlog/spdlog.h> #include <filesystem> diff --git a/zenserver/upstream/jupiter.cpp b/zenserver/upstream/jupiter.cpp index 2dd51fe6c..ba6300c65 100644 --- a/zenserver/upstream/jupiter.cpp +++ b/zenserver/upstream/jupiter.cpp @@ -81,7 +81,7 @@ CloudCacheSession::GetDerivedData(std::string_view BucketId, std::string_view Ke cpr::Response Response = Session.Get(); m_Log.debug("GET {}", Response); - + const bool Success = Response.status_code == 200; const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); diff --git a/zenserver/upstream/jupiter.h b/zenserver/upstream/jupiter.h index efe9d07ba..9f36704fa 100644 --- a/zenserver/upstream/jupiter.h +++ b/zenserver/upstream/jupiter.h @@ -2,9 +2,9 @@ #pragma once -#include <zencore/httpserver.h> #include <zencore/refcount.h> #include <zencore/thread.h> +#include <zenhttp/httpserver.h> #include <spdlog/spdlog.h> diff --git a/zenserver/vfs.cpp b/zenserver/vfs.cpp index 16b23513f..18d8f1842 100644 --- a/zenserver/vfs.cpp +++ b/zenserver/vfs.cpp @@ -2,7 +2,7 @@ #include "vfs.h" -#if WITH_VFS +#if ZEN_WITH_VFS # include <zencore/except.h> # include <zencore/filesystem.h> # include <zencore/snapshot_manifest.h> diff --git a/zenserver/vfs.h b/zenserver/vfs.h index f8fea6e12..0d2ca6062 100644 --- a/zenserver/vfs.h +++ b/zenserver/vfs.h @@ -2,7 +2,11 @@ #pragma once -#if WITH_VFS +#ifndef ZEN_WITH_VFS +# define ZEN_WITH_VFS 0 +#endif + +#if ZEN_WITH_VFS # include <memory> namespace zen { diff --git a/zenserver/xmake.lua b/zenserver/xmake.lua index cbe021f90..bb70846fa 100644 --- a/zenserver/xmake.lua +++ b/zenserver/xmake.lua @@ -1,7 +1,7 @@ target("zenserver") set_kind("binary") add_files("**.cpp") - add_deps("zencore", "zenstore", "zenutil") + add_deps("zencore", "zenhttp", "zenstore", "zenutil") add_includedirs(".") set_symbols("debug") @@ -23,7 +23,8 @@ target("zenserver") "vcpkg::sol2", "vcpkg::lua", "vcpkg::asio", - "vcpkg::json11" + "vcpkg::json11", + "vcpkg::uwebsockets", "vcpkg::usockets", "vcpkg::libuv" ) add_packages( diff --git a/zenserver/zenserver.cpp b/zenserver/zenserver.cpp index c857d4c71..c9f74daa4 100644 --- a/zenserver/zenserver.cpp +++ b/zenserver/zenserver.cpp @@ -2,7 +2,6 @@ #include <zencore/filesystem.h> #include <zencore/fmtutils.h> -#include <zencore/httpserver.h> #include <zencore/iobuffer.h> #include <zencore/refcount.h> #include <zencore/scopeguard.h> @@ -10,6 +9,7 @@ #include <zencore/thread.h> #include <zencore/timer.h> #include <zencore/windows.h> +#include <zenhttp/httpserver.h> #include <zenserverprocess.h> #include <zenstore/cas.h> #include <zenstore/cidstore.h> @@ -72,6 +72,7 @@ #include "diag/diagsvcs.h" #include "experimental/usnjournal.h" #include "projectstore.h" +#include "testing/httptest.h" #include "testing/launch.h" #include "upstream/jupiter.h" #include "upstream/upstreamcache.h" @@ -205,7 +206,7 @@ public: else { UpstreamCache.reset(); - spdlog::info("upstream cache NOT active"); + spdlog::info("NOT using upstream cache"); } } @@ -221,32 +222,40 @@ public: { StartMesh(BasePort); } + else + { + spdlog::info("NOT starting mesh"); + } + + m_Http = zen::CreateHttpServer(); + m_Http->Initialize(BasePort); + m_Http->RegisterService(m_HealthService); + + m_Http->RegisterService(m_TestService); // NOTE: this is intentionally not limited to test mode as it's useful for diagnostics + m_Http->RegisterService(m_TestingService); - m_Http.Initialize(BasePort); - m_Http.AddEndpoint(m_HealthService); - m_Http.AddEndpoint(m_TestService); - m_Http.AddEndpoint(m_AdminService); + m_Http->RegisterService(m_AdminService); if (m_HttpProjectService) { - m_Http.AddEndpoint(*m_HttpProjectService); + m_Http->RegisterService(*m_HttpProjectService); } - m_Http.AddEndpoint(m_CasService); + m_Http->RegisterService(m_CasService); if (m_StructuredCacheService) { - m_Http.AddEndpoint(*m_StructuredCacheService); + m_Http->RegisterService(*m_StructuredCacheService); } if (m_HttpLaunchService) { - m_Http.AddEndpoint(*m_HttpLaunchService); + m_Http->RegisterService(*m_HttpLaunchService); } if (m_HttpFunctionService) { - m_Http.AddEndpoint(*m_HttpFunctionService); + m_Http->RegisterService(*m_HttpFunctionService); } } @@ -284,7 +293,7 @@ public: __debugbreak(); } - m_Http.Run(m_TestMode); + m_Http->Run(m_TestMode); spdlog::info(ZEN_APP_NAME " exiting"); @@ -296,7 +305,7 @@ public: void RequestExit(int ExitCode) { RequestApplicationExit(ExitCode); - m_Http.RequestExit(); + m_Http->RequestExit(); } void Cleanup() { spdlog::info(ZEN_APP_NAME " cleaning up"); } @@ -358,13 +367,14 @@ private: zen::ProcessHandle m_Process; zen::NamedMutex m_ServerMutex; - zen::HttpServer m_Http; + zen::Ref<zen::HttpServer> m_Http; std::unique_ptr<zen::CasStore> m_CasStore{zen::CreateCasStore()}; std::unique_ptr<zen::CidStore> m_CidStore; std::unique_ptr<ZenCacheStore> m_CacheStore; zen::CasGc m_Gc{*m_CasStore}; zen::CasScrubber m_Scrubber{*m_CasStore}; HttpTestService m_TestService; + zen::HttpTestingService m_TestingService; zen::HttpCasService m_CasService{*m_CasStore}; zen::RefPtr<zen::ProjectStore> m_ProjectStore; zen::Ref<zen::LocalProjectService> m_LocalProjectService; diff --git a/zenserver/zenserver.vcxproj b/zenserver/zenserver.vcxproj index e4b40e13e..aa9d538a5 100644 --- a/zenserver/zenserver.vcxproj +++ b/zenserver/zenserver.vcxproj @@ -110,6 +110,7 @@ <ClInclude Include="config.h" /> <ClInclude Include="diag\logging.h" /> <ClInclude Include="sos\sos.h" /> + <ClInclude Include="testing\httptest.h" /> <ClInclude Include="upstream\jupiter.h" /> <ClInclude Include="projectstore.h" /> <ClInclude Include="cache\cacheagent.h" /> @@ -132,6 +133,7 @@ <ClCompile Include="projectstore.cpp" /> <ClCompile Include="cache\cacheagent.cpp" /> <ClCompile Include="sos\sos.cpp" /> + <ClCompile Include="testing\httptest.cpp" /> <ClCompile Include="upstream\jupiter.cpp" /> <ClCompile Include="testing\launch.cpp" /> <ClCompile Include="cache\cachestore.cpp" /> @@ -146,6 +148,9 @@ <ProjectReference Include="..\zencore\zencore.vcxproj"> <Project>{d75bf9ab-c61e-4fff-ad59-1563430f05e2}</Project> </ProjectReference> + <ProjectReference Include="..\zenhttp\zenhttp.vcxproj"> + <Project>{8eeb3be5-7001-46bf-aafd-edb7558ac012}</Project> + </ProjectReference> <ProjectReference Include="..\zenstore\zenstore.vcxproj"> <Project>{26cbbaeb-14c1-4efc-877d-80f48215651c}</Project> </ProjectReference> @@ -153,6 +158,9 @@ <Project>{77f8315d-b21d-4db0-9a6f-2d3359f88a70}</Project> </ProjectReference> </ItemGroup> + <ItemGroup> + <None Include="xmake.lua" /> + </ItemGroup> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> <ImportGroup Label="ExtensionTargets"> </ImportGroup> diff --git a/zenserver/zenserver.vcxproj.filters b/zenserver/zenserver.vcxproj.filters index ca16caf77..a86a6d96d 100644 --- a/zenserver/zenserver.vcxproj.filters +++ b/zenserver/zenserver.vcxproj.filters @@ -38,6 +38,7 @@ <ClInclude Include="upstream\upstreamcache.h"> <Filter>upstream</Filter> </ClInclude> + <ClInclude Include="testing\httptest.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="zenserver.cpp" /> @@ -71,6 +72,7 @@ <ClCompile Include="upstream\upstreamcache.cpp"> <Filter>upstream</Filter> </ClCompile> + <ClCompile Include="testing\httptest.cpp" /> </ItemGroup> <ItemGroup> <Filter Include="cache"> @@ -89,4 +91,7 @@ <UniqueIdentifier>{303c28c2-3607-4ef4-89bd-e3618fe37e74}</UniqueIdentifier> </Filter> </ItemGroup> + <ItemGroup> + <None Include="xmake.lua" /> + </ItemGroup> </Project>
\ No newline at end of file diff --git a/zenstore/cidstore.cpp b/zenstore/cidstore.cpp index 4e5188f1c..391520599 100644 --- a/zenstore/cidstore.cpp +++ b/zenstore/cidstore.cpp @@ -2,6 +2,7 @@ #include "zenstore/cidstore.h" +#include <zencore/compress.h> #include <zencore/filesystem.h> #include <zenstore/CAS.h> #include <zenstore/caslog.h> @@ -27,10 +28,25 @@ struct CidStore::CidState RwLock m_Lock; tsl::robin_map<IoHash, IoHash> m_CidMap; + CasStore::InsertResult AddChunk(CompressedBuffer& ChunkData) + { + IoBuffer Payload = ChunkData.GetCompressed().Flatten().AsIoBuffer(); + IoHash CompressedHash = IoHash::HashBuffer(Payload.Data(), Payload.Size()); + + CasStore::InsertResult Result = m_CasStore.InsertChunk(Payload, CompressedHash); + AddCompressedCid(IoHash::FromBLAKE3(ChunkData.GetRawHash()), CompressedHash); + + return Result; + } + void AddCompressedCid(const IoHash& DecompressedId, const IoHash& Compressed) { RwLock::ExclusiveLockScope _(m_Lock); m_CidMap.insert_or_assign(DecompressedId, Compressed); + // TODO: it's pretty wasteful to log even idempotent updates + // however we can't simply use the boolean returned by insert_or_assign + // since there's not a 1:1 mapping between compressed and uncompressed + // so if we want a last-write-wins policy then we have to log each update m_LogFile.Append({.Uncompressed = DecompressedId, .Compressed = Compressed}); } @@ -87,6 +103,12 @@ CidStore::~CidStore() { } +CasStore::InsertResult +CidStore::AddChunk(CompressedBuffer& ChunkData) +{ + return m_Impl->AddChunk(ChunkData); +} + void CidStore::AddCompressedCid(const IoHash& DecompressedId, const IoHash& Compressed) { diff --git a/zenstore/include/zenstore/cidstore.h b/zenstore/include/zenstore/cidstore.h index 2c2b395a5..62d642ad1 100644 --- a/zenstore/include/zenstore/cidstore.h +++ b/zenstore/include/zenstore/cidstore.h @@ -4,6 +4,7 @@ #include <tsl/robin_map.h> #include <zencore/iohash.h> +#include <zenstore/CAS.h> namespace std::filesystem { class path; @@ -12,6 +13,7 @@ class path; namespace zen { class CasStore; +class CompressedBuffer; class IoBuffer; /** Content Store @@ -29,10 +31,13 @@ public: CidStore(CasStore& InCasStore, const std::filesystem::path& RootDir); ~CidStore(); - void AddCompressedCid(const IoHash& DecompressedId, const IoHash& Compressed); - IoBuffer FindChunkByCid(const IoHash& DecompressedId); - bool ContainsChunk(const IoHash& DecompressedId); - void Flush(); + CasStore::InsertResult AddChunk(CompressedBuffer& ChunkData); + void AddCompressedCid(const IoHash& DecompressedId, const IoHash& Compressed); + IoBuffer FindChunkByCid(const IoHash& DecompressedId); + bool ContainsChunk(const IoHash& DecompressedId); + void Flush(); + + // TODO: add batch filter support private: struct CidState; diff --git a/zenutil/include/zenserverprocess.h b/zenutil/include/zenserverprocess.h index d0093537e..f7d911a87 100644 --- a/zenutil/include/zenserverprocess.h +++ b/zenutil/include/zenserverprocess.h @@ -55,6 +55,8 @@ struct ZenServerInstance void AttachToRunningServer(int BasePort = 0); + std::string GetBaseUri() const; + private: ZenServerEnvironment& m_Env; zen::ProcessHandle m_Process; @@ -63,6 +65,7 @@ private: bool m_Terminate = false; std::filesystem::path m_TestDir; bool m_MeshEnabled = false; + int m_BasePort = 0; void CreateShutdownEvent(int BasePort); }; diff --git a/zenutil/zenserverprocess.cpp b/zenutil/zenserverprocess.cpp index 4e45ddfae..093f18f6a 100644 --- a/zenutil/zenserverprocess.cpp +++ b/zenutil/zenserverprocess.cpp @@ -403,6 +403,7 @@ ZenServerInstance::SpawnServer(int BasePort) if (BasePort) { CommandLine << " --port " << BasePort; + m_BasePort = BasePort; } if (!m_TestDir.empty()) @@ -418,7 +419,7 @@ ZenServerInstance::SpawnServer(int BasePort) std::filesystem::path CurrentDirectory = std::filesystem::current_path(); - spdlog::debug("Spawning server"); + spdlog::debug("Spawning server '{}'", LogId); PROCESS_INFORMATION ProcessInfo{}; STARTUPINFO StartupInfo{.cb = sizeof(STARTUPINFO)}; @@ -492,7 +493,7 @@ ZenServerInstance::SpawnServer(int BasePort) } } - spdlog::debug("Server spawned OK"); + spdlog::debug("Server '{}' spawned OK", LogId); if (IsTest) { @@ -558,3 +559,13 @@ ZenServerInstance::WaitUntilReady(int Timeout) { return m_ReadyEvent.Wait(Timeout); } + +std::string +ZenServerInstance::GetBaseUri() const +{ + ZEN_ASSERT(m_BasePort); + + using namespace fmt::literals; + + return "http://localhost:{}"_format(m_BasePort); +} |