diff options
| -rw-r--r-- | zencore/compactbinary.cpp | 2 | ||||
| -rw-r--r-- | zencore/compactbinarypackage.cpp | 418 | ||||
| -rw-r--r-- | zencore/compactbinaryvalidation.cpp | 143 | ||||
| -rw-r--r-- | zencore/httpserver.cpp | 459 | ||||
| -rw-r--r-- | zencore/include/zencore/compactbinary.h | 22 | ||||
| -rw-r--r-- | zencore/include/zencore/compactbinarypackage.h | 47 | ||||
| -rw-r--r-- | zencore/include/zencore/compactbinaryvalidation.h | 5 | ||||
| -rw-r--r-- | zencore/include/zencore/httpserver.h | 55 | ||||
| -rw-r--r-- | zenserver-test/zenserver-test.cpp | 93 | ||||
| -rw-r--r-- | zenserver/testing/httptest.cpp | 48 | ||||
| -rw-r--r-- | zenserver/testing/httptest.h | 34 | ||||
| -rw-r--r-- | zenserver/zenserver.cpp | 16 | ||||
| -rw-r--r-- | zenserver/zenserver.vcxproj | 2 | ||||
| -rw-r--r-- | zenserver/zenserver.vcxproj.filters | 2 | ||||
| -rw-r--r-- | zenutil/include/zenserverprocess.h | 3 | ||||
| -rw-r--r-- | zenutil/zenserverprocess.cpp | 15 |
16 files changed, 932 insertions, 432 deletions
diff --git a/zencore/compactbinary.cpp b/zencore/compactbinary.cpp index 5fe7f272d..b508d8fe8 100644 --- a/zencore/compactbinary.cpp +++ b/zencore/compactbinary.cpp @@ -12,7 +12,7 @@ namespace zen { -const int DaysToMonth[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365}; +const int DaysToMonth[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365}; bool IsLeapYear(int Year) diff --git a/zencore/compactbinarypackage.cpp b/zencore/compactbinarypackage.cpp index 7880164f9..a345f2b1b 100644 --- a/zencore/compactbinarypackage.cpp +++ b/zencore/compactbinarypackage.cpp @@ -16,22 +16,47 @@ CbAttachment::CbAttachment(const CompressedBuffer& InValue) : CbAttachment(InVal { } -CbAttachment::CbAttachment(CompressedBuffer&& InValue) +CbAttachment::CbAttachment(const SharedBuffer& InValue) : CbAttachment(CompositeBuffer(InValue)) { - Value.emplace<CompressedBuffer>(std::move(InValue).MakeOwned()); } -CbAttachment::CbAttachment(const SharedBuffer& InValue) +CbAttachment::CbAttachment(const SharedBuffer& InValue, [[maybe_unused]] const IoHash& InHash) : CbAttachment(InValue.IsNull() ? CompressedBuffer() : CompressedBuffer::Compress(InValue, OodleCompressor::NotSet, OodleCompressionLevel::None)) { + // This could be more efficient, and should at the very least try to validate the hash } -CbAttachment::CbAttachment(const SharedBuffer& InValue, [[maybe_unused]] const IoHash& InHash) -: CbAttachment(InValue.IsNull() ? CompressedBuffer() - : CompressedBuffer::Compress(InValue, OodleCompressor::NotSet, OodleCompressionLevel::None)) +CbAttachment::CbAttachment(const CompositeBuffer& InValue) : Value{std::in_place_type<BinaryValue>, InValue} { - // This could be more efficient, and should at the very least try to validate the hash + if (std::get<BinaryValue>(Value).Buffer.IsNull()) + { + Value.emplace<nullptr_t>(); + } +} + +CbAttachment::CbAttachment(CompositeBuffer&& InValue) : Value{std::in_place_type<BinaryValue>, InValue} +{ + if (std::get<BinaryValue>(Value).Buffer.IsNull()) + { + Value.emplace<nullptr_t>(); + } +} + +CbAttachment::CbAttachment(CompositeBuffer&& InValue, const IoHash& Hash) : Value{std::in_place_type<BinaryValue>, InValue, Hash} +{ + if (std::get<BinaryValue>(Value).Buffer.IsNull()) + { + Value.emplace<nullptr_t>(); + } +} + +CbAttachment::CbAttachment(CompressedBuffer&& InValue) : Value(std::in_place_type<CompressedBuffer>, InValue) +{ + if (std::get<CompressedBuffer>(Value).IsNull()) + { + Value.emplace<nullptr_t>(); + } } CbAttachment::CbAttachment(const CbObject& InValue, const IoHash* const InHash) @@ -70,114 +95,139 @@ CbAttachment::TryLoad(IoBuffer& InBuffer, BufferAllocator Allocator) bool CbAttachment::TryLoad(CbFieldIterator& Fields) { - const CbObjectView ObjectView = Fields.AsObjectView(); - if (Fields.HasError()) + if (const CbObjectView ObjectView = Fields.AsObjectView(); !Fields.HasError()) + { + // Is a null object or object not prefixed with a precomputed hash value + Value.emplace<CbObjectValue>(CbObject(ObjectView, Fields.GetOuterBuffer()), ObjectView.GetHash()); + ++Fields; + } + else if (const IoHash ObjectAttachmentHash = Fields.AsObjectAttachment(); !Fields.HasError()) + { + // Is an object + ++Fields; + const CbObjectView InnerObjectView = Fields.AsObjectView(); + if (Fields.HasError()) + { + return false; + } + Value.emplace<CbObjectValue>(CbObject(InnerObjectView, Fields.GetOuterBuffer()), ObjectAttachmentHash); + ++Fields; + } + else if (const IoHash BinaryAttachmentHash = Fields.AsBinaryAttachment(); !Fields.HasError()) + { + // Is an uncompressed binary blob + ++Fields; + MemoryView BinaryView = Fields.AsBinaryView(); + if (Fields.HasError()) + { + return false; + } + Value.emplace<BinaryValue>(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer()), BinaryAttachmentHash); + ++Fields; + } + else if (MemoryView BinaryView = Fields.AsBinaryView(); !Fields.HasError()) { - // Is a buffer - const MemoryView BinaryView = Fields.AsBinaryView(); if (BinaryView.GetSize() > 0) { + // Is a compressed binary blob Value.emplace<CompressedBuffer>( CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer())).MakeOwned()); - ++Fields; } else { + // Is an uncompressed empty binary blob + Value.emplace<BinaryValue>(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer()), IoHash::HashBuffer(nullptr, 0)); ++Fields; - Value.emplace<CompressedBuffer>(); } } else { - // It's an object - ++Fields; - IoHash Hash; - if (ObjectView) - { - Hash = Fields.AsObjectAttachment(); - ++Fields; - } - else - { - Hash = IoHash::HashBuffer(MemoryView{}); - } - Value.emplace<CbObjectValue>(CbObject(ObjectView, Fields->GetOuterBuffer()), Hash); + return false; } return true; } -bool -CbAttachment::TryLoad(BinaryReader& Reader, BufferAllocator Allocator) +static bool +TryLoad_ArchiveFieldIntoAttachment(CbAttachment& TargetAttachment, CbField&& Field, BinaryReader& Reader, BufferAllocator Allocator) { - CbField Field = LoadCompactBinary(Reader, Allocator); - const CbObjectView ObjectView = Field.AsObjectView(); - - if (Field.HasError()) + if (const CbObjectView ObjectView = Field.AsObjectView(); !Field.HasError()) { - // It's a buffer - const MemoryView BinaryView = Field.AsBinaryView(); - if (BinaryView.GetSize() > 0) + // Is a null object or object not prefixed with a precomputed hash value + TargetAttachment = CbAttachment(CbObject(ObjectView, std::move(Field)), ObjectView.GetHash()); + } + else if (const IoHash ObjectAttachmentHash = Field.AsObjectAttachment(); !Field.HasError()) + { + // Is an object + Field = LoadCompactBinary(Reader, Allocator); + if (!Field.IsObject()) { - Value.emplace<CompressedBuffer>( - CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView, Field.GetOuterBuffer())).MakeOwned()); + return false; } - else + TargetAttachment = CbAttachment(std::move(Field).AsObject(), ObjectAttachmentHash); + } + else if (const IoHash BinaryAttachmentHash = Field.AsBinaryAttachment(); !Field.HasError()) + { + // Is an uncompressed binary blob + Field = LoadCompactBinary(Reader, Allocator); + SharedBuffer Buffer = Field.AsBinary(); + if (Field.HasError()) { - Value.emplace<CompressedBuffer>(); + return false; } + TargetAttachment = CbAttachment(CompositeBuffer(Buffer), BinaryAttachmentHash); } - else + else if (SharedBuffer Buffer = Field.AsBinary(); !Field.HasError()) { - // It's an object - IoHash Hash; - if (ObjectView) + if (Buffer.GetSize() > 0) { - std::vector<uint8_t> HashBuffer; - CbField HashField = LoadCompactBinary(Reader, [&HashBuffer](uint64_t Size) -> UniqueBuffer { - HashBuffer.resize(Size); - return UniqueBuffer::MakeMutableView(HashBuffer.data(), Size); - }); - Hash = HashField.AsAttachment(); - if (HashField.HasError() || ObjectView.GetHash() != Hash) - { - // Error - return false; - } + // Is a compressed binary blob + TargetAttachment = CbAttachment(CompressedBuffer::FromCompressed(std::move(Buffer))); } else { - Hash = IoHash::HashBuffer(MemoryView()); + // Is an uncompressed empty binary blob + TargetAttachment = CbAttachment(CompositeBuffer(Buffer), IoHash::HashBuffer(nullptr, 0)); } - Value.emplace<CbObjectValue>(CbObject(ObjectView, Field.GetOuterBuffer()), Hash); + } + else + { + return false; } return true; } +bool +CbAttachment::TryLoad(BinaryReader& Reader, BufferAllocator Allocator) +{ + CbField Field = LoadCompactBinary(Reader, Allocator); + return TryLoad_ArchiveFieldIntoAttachment(*this, std::move(Field), Reader, Allocator); +} + void CbAttachment::Save(CbWriter& Writer) const { - if (const CbObjectValue* ObjectValue = std::get_if<CbObjectValue>(&Value)) + if (const CbObjectValue* ObjValue = std::get_if<CbObjectValue>(&Value)) { - Writer.AddObject(ObjectValue->Object); - if (ObjectValue->Object) + if (ObjValue->Object) { - Writer.AddObjectAttachment(ObjectValue->Hash); + Writer.AddObjectAttachment(ObjValue->Hash); } + Writer.AddObject(ObjValue->Object); } - else + else if (const BinaryValue* BinValue = std::get_if<BinaryValue>(&Value)) { - const CompressedBuffer& BufferValue = std::get<CompressedBuffer>(Value); - if (BufferValue.GetRawSize()) - { - Writer.AddBinary(BufferValue.GetCompressed()); - } - else // Null + if (BinValue->Buffer.GetSize() > 0) { - Writer.AddBinary(MemoryView()); + Writer.AddBinaryAttachment(BinValue->Hash); } + Writer.AddBinary(BinValue->Buffer); + } + else if (const CompressedBuffer* BufferValue = std::get_if<CompressedBuffer>(&Value)) + { + Writer.AddBinary(BufferValue->GetCompressed()); } } @@ -192,14 +242,19 @@ CbAttachment::Save(BinaryWriter& Writer) const bool CbAttachment::IsNull() const { - if (const CompressedBuffer* Buffer = std::get_if<CompressedBuffer>(&Value)) - { - return Buffer->IsNull(); - } - else - { - return false; - } + return std::holds_alternative<nullptr_t>(Value); +} + +bool +CbAttachment::IsBinary() const +{ + return std::holds_alternative<BinaryValue>(Value); +} + +bool +CbAttachment::IsCompressedBinary() const +{ + return std::holds_alternative<CompressedBuffer>(Value); } bool @@ -213,25 +268,42 @@ CbAttachment::GetHash() const { if (const CompressedBuffer* Buffer = std::get_if<CompressedBuffer>(&Value)) { - return Buffer->IsNull() ? IoHash::HashBuffer(MemoryView()) : IoHash::FromBLAKE3(Buffer->GetRawHash()); + return IoHash::FromBLAKE3(Buffer->GetRawHash()); } - else + + if (const BinaryValue* BinValue = std::get_if<BinaryValue>(&Value)) + { + return BinValue->Hash; + } + + if (const CbObjectValue* ObjectValue = std::get_if<CbObjectValue>(&Value)) { - return std::get<CbObjectValue>(Value).Hash; + return ObjectValue->Hash; } + + return IoHash::Zero; } -SharedBuffer -CbAttachment::AsBinary() const +CompositeBuffer +CbAttachment::AsCompositeBinary() const { - if (const CompressedBuffer* Buffer = std::get_if<CompressedBuffer>(&Value)) + if (const BinaryValue* BinValue = std::get_if<BinaryValue>(&Value)) { - return Buffer->Decompress(); + return BinValue->Buffer; } - else + + return CompositeBuffer::Null; +} + +SharedBuffer +CbAttachment::AsBinary() const +{ + if (const BinaryValue* BinValue = std::get_if<BinaryValue>(&Value)) { - return std::get<CbObjectValue>(Value).Object.GetBuffer(); + return BinValue->Buffer.Flatten(); } + + return {}; } CompressedBuffer @@ -241,12 +313,8 @@ CbAttachment::AsCompressedBinary() const { return *Buffer; } - else - { - return CompressedBuffer::Compress(std::get<CbObjectValue>(Value).Object.GetBuffer(), - OodleCompressor::NotSet, - OodleCompressionLevel::None); - } + + return CompressedBuffer::Null; } /** Access the attachment as compact binary. Defaults to a field iterator with no value on error. */ @@ -257,10 +325,8 @@ CbAttachment::AsObject() const { return ObjectValue->Object; } - else - { - return {}; - } + + return {}; } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -301,10 +367,7 @@ CbPackage::AddAttachment(const CbAttachment& Attachment, AttachmentResolver* Res if (It != Attachments.end() && *It == Attachment) { CbAttachment& Existing = *It; - if (Attachment.IsObject() && !Existing.IsObject()) - { - Existing = CbAttachment(CbObject(Existing.AsBinary()), Existing.GetHash()); - } + Existing = Attachment; } else { @@ -358,7 +421,7 @@ CbPackage::GatherAttachments(const CbObject& Value, AttachmentResolver Resolver) } else { - AddAttachment(CbAttachment(std::move(Buffer), Hash)); + AddAttachment(CbAttachment(std::move(Buffer))); } } }); @@ -377,6 +440,7 @@ bool CbPackage::TryLoad(CbFieldIterator& Fields) { *this = CbPackage(); + while (Fields) { if (Fields.IsNull()) @@ -384,43 +448,76 @@ CbPackage::TryLoad(CbFieldIterator& Fields) ++Fields; break; } - else if (Fields.IsBinary()) - { - CbAttachment Attachment; - Attachment.TryLoad(Fields); - AddAttachment(Attachment); - } - else + else if (IoHash Hash = Fields.AsHash(); !Fields.HasError() && !Fields.IsAttachment()) { - Object = Fields.AsObject(); - if (Fields->HasError()) + ++Fields; + CbObjectView ObjectView = Fields.AsObjectView(); + if (Fields.HasError() || Hash != ObjectView.GetHash()) { return false; } + Object = CbObject(ObjectView, Fields.GetOuterBuffer()); Object.MakeOwned(); + ObjectHash = Hash; ++Fields; - if (Object.CreateIterator()) - { - ObjectHash = Fields.AsObjectAttachment(); - if (Fields.HasError()) - { - return false; - } - ++Fields; - } - else + } + else + { + CbAttachment Attachment; + if (!Attachment.TryLoad(Fields)) { - Object.Reset(); + return false; } + AddAttachment(Attachment); } } - return true; } bool CbPackage::TryLoad(BinaryReader& Reader, BufferAllocator Allocator, AttachmentResolver* Mapper) { + // TODO: this needs to re-grow the ability to accept a reference to an attachment which is + // not embedded + + ZEN_UNUSED(Mapper); + +#if 1 + *this = CbPackage(); + for (;;) + { + CbField Field = LoadCompactBinary(Reader, Allocator); + if (!Field) + { + return false; + } + + if (Field.IsNull()) + { + return true; + } + else if (IoHash Hash = Field.AsHash(); !Field.HasError() && !Field.IsAttachment()) + { + Field = LoadCompactBinary(Reader, Allocator); + CbObjectView ObjectView = Field.AsObjectView(); + if (Field.HasError() || Hash != ObjectView.GetHash()) + { + return false; + } + Object = CbObject(ObjectView, Field.GetOuterBuffer()); + ObjectHash = Hash; + } + else + { + CbAttachment Attachment; + if (!TryLoad_ArchiveFieldIntoAttachment(Attachment, std::move(Field), Reader, Allocator)) + { + return false; + } + AddAttachment(Attachment); + } + } +#else uint8_t StackBuffer[64]; const auto StackAllocator = [&Allocator, &StackBuffer](uint64_t Size) -> UniqueBuffer { if (Size <= sizeof(StackBuffer)) @@ -494,6 +591,7 @@ CbPackage::TryLoad(BinaryReader& Reader, BufferAllocator Allocator, AttachmentRe } } } +#endif } void @@ -501,8 +599,8 @@ CbPackage::Save(CbWriter& Writer) const { if (Object) { + Writer.AddHash(ObjectHash); Writer.AddObject(Object); - Writer.AddObjectAttachment(ObjectHash); } for (const CbAttachment& Attachment : Attachments) { @@ -567,8 +665,7 @@ TEST_CASE("usonpackage") CHECK_FALSE(bool(Attachment.AsObject())); CHECK_FALSE(Attachment.IsBinary()); CHECK_FALSE(Attachment.IsObject()); - CHECK(Attachment.GetHash() == IoHash::HashBuffer({})); - TestSaveLoadValidate("Null", Attachment); + CHECK(Attachment.GetHash() == IoHash::Zero); } SUBCASE("Binary Attachment") @@ -596,12 +693,12 @@ TEST_CASE("usonpackage") CHECK_FALSE(Attachment.IsNull()); CHECK(bool(Attachment)); - CHECK(Attachment.AsBinary() == Object.GetBuffer()); + CHECK(Attachment.AsBinary() == SharedBuffer()); CHECK(Attachment.AsObject().Equals(Object)); - CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsBinary()); CHECK(Attachment.IsObject()); CHECK(Attachment.GetHash() == Object.GetHash()); - TestSaveLoadValidate("CompactBinary", Attachment); + TestSaveLoadValidate("Object", Attachment); } SUBCASE("Binary View") @@ -633,7 +730,7 @@ TEST_CASE("usonpackage") CHECK(Attachment.AsBinary() != ObjectView.GetBuffer()); CHECK(Attachment.AsObject().Equals(Object)); - CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsBinary()); CHECK(Attachment.IsObject()); CHECK(Attachment.GetHash() == IoHash(Object.GetHash())); } @@ -677,15 +774,14 @@ TEST_CASE("usonpackage") CbFieldIterator FieldsView = CbFieldIterator::MakeRangeView(CbFieldViewIterator(Fields)); Attachment.TryLoad(FieldsView); + MemoryView View; CHECK_FALSE(Attachment.IsNull()); CHECK(bool(Attachment)); - - CHECK(Attachment.AsBinary().GetView().EqualBytes(Value.GetView())); - CHECK_FALSE(FieldsView.GetBuffer().GetView().Contains(Attachment.AsObject().GetBuffer().GetView())); - CHECK(Attachment.IsBinary()); + CHECK(Attachment.AsBinary().GetView().EqualBytes(MemoryView())); + CHECK_FALSE((!Attachment.AsObject().TryGetSerializedView(View) || FieldsView.GetOuterBuffer().GetView().Contains(View))); + CHECK_FALSE(Attachment.IsBinary()); CHECK(Attachment.IsObject()); - CHECK(Attachment.GetHash() == Value.GetHash()); } @@ -696,7 +792,7 @@ TEST_CASE("usonpackage") CHECK(Attachment.IsNull()); CHECK_FALSE(Attachment.IsBinary()); CHECK_FALSE(Attachment.IsObject()); - CHECK(Attachment.GetHash() == IoHash::HashBuffer(SharedBuffer{})); + CHECK(Attachment.GetHash() == IoHash::Zero); } SUBCASE("Binary Empty") @@ -714,7 +810,7 @@ TEST_CASE("usonpackage") const CbAttachment Attachment(CbObject{}); CHECK_FALSE(Attachment.IsNull()); - CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsBinary()); CHECK(Attachment.IsObject()); CHECK(Attachment.GetHash() == CbObject().GetHash()); } @@ -833,17 +929,16 @@ TEST_CASE("usonpackage.serialization") CHECK((Object1Attachment && Object1Attachment->AsObject().Equals(Object1))); CHECK((Object2Attachment && Object2Attachment->AsBinary() == Object2.GetBuffer())); - Package.AddAttachment(CbAttachment(SharedBuffer::Clone(Object1.GetView()))); + SharedBuffer Object1ClonedBuffer = SharedBuffer::Clone(Object1.GetOuterBuffer()); + Package.AddAttachment(CbAttachment(Object1ClonedBuffer)); Package.AddAttachment(CbAttachment(CbObject::Clone(Object2))); CHECK(Package.GetAttachments().size() == 2); CHECK(Package.FindAttachment(Object1.GetHash()) == Object1Attachment); CHECK(Package.FindAttachment(Object2.GetHash()) == Object2Attachment); - CHECK((Object1Attachment && Object1Attachment->AsObject().Equals(Object1))); - CHECK((Object1Attachment && Object1Attachment->AsBinary() == Object1.GetBuffer())); + CHECK((Object1Attachment && Object1Attachment->AsBinary() == Object1ClonedBuffer)); CHECK((Object2Attachment && Object2Attachment->AsObject().Equals(Object2))); - CHECK((Object2Attachment && Object2Attachment->AsBinary() == Object2.GetBuffer())); CHECK(std::is_sorted(begin(Package.GetAttachments()), end(Package.GetAttachments()))); } @@ -884,8 +979,8 @@ TEST_CASE("usonpackage.serialization") const IoHash Level1Hash = Level1.GetHash(); const auto Resolver = [&Level2, &Level2Hash, &Level3, &Level3Hash, &Level4, &Level4Hash](const IoHash& Hash) -> SharedBuffer { - return Hash == Level2Hash ? Level2.GetBuffer() - : Hash == Level3Hash ? Level3.GetBuffer() + return Hash == Level2Hash ? Level2.GetOuterBuffer() + : Hash == Level3Hash ? Level3.GetOuterBuffer() : Hash == Level4Hash ? Level4 : SharedBuffer(); }; @@ -907,8 +1002,9 @@ TEST_CASE("usonpackage.serialization") const CbAttachment* const Level4Attachment = Package.FindAttachment(Level4Hash); CHECK((Level2Attachment && Level2Attachment->AsObject().Equals(Level2))); CHECK((Level3Attachment && Level3Attachment->AsObject().Equals(Level3))); - CHECK((Level4Attachment && Level4Attachment->AsBinary() != Level4 && - Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView()))); + REQUIRE(Level4Attachment); + CHECK(Level4Attachment->AsBinary() != Level4); + CHECK(Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView())); CHECK(std::is_sorted(begin(Package.GetAttachments()), end(Package.GetAttachments()))); @@ -932,15 +1028,15 @@ TEST_CASE("usonpackage.serialization") // Out of Order { - CbWriter Writer; - Writer.AddBinary(Level2.GetBuffer()); - Writer.AddObjectAttachment(Level2Hash); - Writer.AddBinary(Level4); - Writer.AddBinaryAttachment(Level4Hash); + CbWriter Writer; + CbAttachment Attachment2(Level2, Level2Hash); + Attachment2.Save(Writer); + CbAttachment Attachment4(Level4); + Attachment4.Save(Writer); + Writer.AddHash(Level1Hash); Writer.AddObject(Level1); - Writer.AddObjectAttachment(Level1Hash); - Writer.AddBinary(Level3.GetBuffer()); - Writer.AddObjectAttachment(Level3Hash); + CbAttachment Attachment3(Level3, Level3Hash); + Attachment3.Save(Writer); Writer.AddNull(); CbFieldIterator Fields = Writer.Save(); @@ -961,11 +1057,9 @@ TEST_CASE("usonpackage.serialization") const MemoryView FieldsOuterBufferView = Fields.GetOuterBuffer().GetView(); CHECK(Level2Attachment->AsObject().Equals(Level2)); - CHECK(FieldsOuterBufferView.Contains(Level2Attachment->AsBinary().GetView())); CHECK(Level2Attachment->GetHash() == Level2Hash); CHECK(Level3Attachment->AsObject().Equals(Level3)); - CHECK(FieldsOuterBufferView.Contains(Level3Attachment->AsBinary().GetView())); CHECK(Level3Attachment->GetHash() == Level3Hash); CHECK(Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView())); @@ -983,21 +1077,23 @@ TEST_CASE("usonpackage.serialization") Writer.Reset(); FromArchive.Save(Writer); CbFieldIterator Saved = Writer.Save(); - CHECK(Saved.AsObject().Equals(Level1)); + + CHECK(Saved.AsHash() == Level1Hash); ++Saved; - CHECK(Saved.AsObjectAttachment() == Level1Hash); + CHECK(Saved.AsObject().Equals(Level1)); ++Saved; - CHECK(Saved.AsBinaryView().EqualBytes(Level2.GetView())); + CHECK_EQ(Saved.AsObjectAttachment(), Level2Hash); ++Saved; - CHECK(Saved.AsObjectAttachment() == Level2Hash); + CHECK(Saved.AsObject().Equals(Level2)); ++Saved; - CHECK(Saved.AsBinaryView().EqualBytes(Level3.GetView())); + CHECK_EQ(Saved.AsObjectAttachment(), Level3Hash); ++Saved; - CHECK(Saved.AsObjectAttachment() == Level3Hash); + CHECK(Saved.AsObject().Equals(Level3)); ++Saved; - CHECK(Saved.AsBinaryView().EqualBytes(Level4.GetView())); + CHECK_EQ(Saved.AsBinaryAttachment(), Level4Hash); ++Saved; - CHECK(Saved.AsBinaryAttachment() == Level4Hash); + SharedBuffer SavedLevel4Buffer = SharedBuffer::MakeView(Saved.AsBinaryView()); + CHECK(SavedLevel4Buffer.GetView().EqualBytes(Level4.GetView())); ++Saved; CHECK(Saved.IsNull()); ++Saved; diff --git a/zencore/compactbinaryvalidation.cpp b/zencore/compactbinaryvalidation.cpp index 52f625313..316da76a6 100644 --- a/zencore/compactbinaryvalidation.cpp +++ b/zencore/compactbinaryvalidation.cpp @@ -416,92 +416,125 @@ ValidateCbPackageField(MemoryView& View, CbValidateMode Mode, CbValidateError& E static IoHash ValidateCbPackageAttachment(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error) { - const CbObjectView ObjectView = Value.AsObjectView(); - if (Value.HasError()) + if (const CbObjectView ObjectView = Value.AsObjectView(); !Value.HasError()) { - const MemoryView BinaryView = Value.AsBinaryView(); - if (Value.HasError() && EnumHasAnyFlags(Mode, CbValidateMode::Package)) + return CbObject().GetHash(); + } + + if (const IoHash ObjectAttachmentHash = Value.AsObjectAttachment(); !Value.HasError()) + { + if (CbFieldView ObjectField = ValidateCbPackageField(View, Mode, Error)) { - if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + const CbObjectView InnerObjectView = ObjectField.AsObjectView(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && ObjectField.HasError()) { AddError(Error, CbValidateError::InvalidPackageFormat); } + else if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (ObjectAttachmentHash != InnerObjectView.GetHash())) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + return ObjectAttachmentHash; } - else if (BinaryView.GetSize()) + } + else if (const IoHash BinaryAttachmentHash = Value.AsBinaryAttachment(); !Value.HasError()) + { + if (CbFieldView BinaryField = ValidateCbPackageField(View, Mode, Error)) { - if (EnumHasAnyFlags(Mode, CbValidateMode::Package | CbValidateMode::PackageHash)) + const MemoryView BinaryView = BinaryField.AsBinaryView(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && BinaryField.HasError()) { - CompressedBuffer Buffer = CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView)); - if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && Buffer.IsNull()) + AddError(Error, CbValidateError::InvalidPackageFormat); + } + else + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && BinaryView.IsEmpty()) { - AddError(Error, CbValidateError::InvalidPackageFormat); + AddError(Error, CbValidateError::NullPackageAttachment); } - if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && - (IoHash::FromBLAKE3(Buffer.GetRawHash()) != IoHash::HashBuffer(Buffer.DecompressToComposite()))) + if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (BinaryAttachmentHash != IoHash::HashBuffer(BinaryView))) { AddError(Error, CbValidateError::InvalidPackageHash); } - return IoHash::FromBLAKE3(Buffer.GetRawHash()); } + return BinaryAttachmentHash; } } - else + else if (const MemoryView BinaryView = Value.AsBinaryView(); !Value.HasError()) { - if (ObjectView) + if (BinaryView.GetSize() > 0) { - if (CbFieldView HashField = ValidateCbPackageField(View, Mode, Error)) + CompressedBuffer Buffer = CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView)); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && Buffer.IsNull()) { - const IoHash Hash = HashField.AsAttachment(); - if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && HashField.HasError()) - { - AddError(Error, CbValidateError::InvalidPackageFormat); - } - if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (Hash != ObjectView.GetHash())) - { - AddError(Error, CbValidateError::InvalidPackageHash); - } - return Hash; + AddError(Error, CbValidateError::NullPackageAttachment); } + if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && + (IoHash::FromBLAKE3(Buffer.GetRawHash()) != IoHash::HashBuffer(Buffer.DecompressToComposite()))) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + return IoHash::FromBLAKE3(Buffer.GetRawHash()); } else { - return CbObject().GetHash(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::NullPackageAttachment); + } + return IoHash::HashBuffer(MemoryView()); } } - return {}; -} - -static IoHash -ValidateCbPackageObject(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error) -{ - CbObjectView Object = Value.AsObjectView(); - if (Value.HasError()) + else { if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) { AddError(Error, CbValidateError::InvalidPackageFormat); } } - else if (CbFieldView HashField = ValidateCbPackageField(View, Mode, Error)) + + return IoHash(); +} + +static IoHash +ValidateCbPackageObject(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error) +{ + if (IoHash RootObjectHash = Value.AsHash(); !Value.HasError() && !Value.IsAttachment()) { - const IoHash Hash = HashField.AsAttachment(); + CbFieldView RootObjectField = ValidateCbPackageField(View, Mode, Error); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) { - if (!Object) - { - AddError(Error, CbValidateError::NullPackageObject); - } - if (HashField.HasError()) + if (RootObjectField.HasError()) { AddError(Error, CbValidateError::InvalidPackageFormat); } - else if (Hash != Value.GetHash()) + } + + const CbObjectView RootObjectView = RootObjectField.AsObjectView(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + if (!RootObjectView) { - AddError(Error, CbValidateError::InvalidPackageHash); + AddError(Error, CbValidateError::NullPackageObject); } } - return Hash; + + if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (RootObjectHash != RootObjectView.GetHash())) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + + return RootObjectHash; + } + else + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } } + return IoHash(); } @@ -562,24 +595,20 @@ ValidateCompactBinaryPackage(MemoryView View, CbValidateMode Mode) uint32_t ObjectCount = 0; while (CbFieldView Value = ValidateCbPackageField(View, Mode, Error)) { - if (Value.IsBinary()) + if (Value.IsHash() && !Value.IsAttachment()) { - const IoHash Hash = ValidateCbPackageAttachment(Value, View, Mode, Error); - if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + ValidateCbPackageObject(Value, View, Mode, Error); + if (++ObjectCount > 1 && EnumHasAnyFlags(Mode, CbValidateMode::Package)) { - Attachments.push_back(Hash); - if (Value.AsBinaryView().IsEmpty()) - { - AddError(Error, CbValidateError::NullPackageAttachment); - } + AddError(Error, CbValidateError::MultiplePackageObjects); } } - else if (Value.IsObject()) + else if (Value.IsBinary() || Value.IsAttachment() || Value.IsObject()) { - ValidateCbPackageObject(Value, View, Mode, Error); - if (++ObjectCount > 1 && EnumHasAnyFlags(Mode, CbValidateMode::Package)) + const IoHash Hash = ValidateCbPackageAttachment(Value, View, Mode, Error); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) { - AddError(Error, CbValidateError::MultiplePackageObjects); + Attachments.push_back(Hash); } } else if (Value.IsNull()) diff --git a/zencore/httpserver.cpp b/zencore/httpserver.cpp index d7e6a875f..a9096b99b 100644 --- a/zencore/httpserver.cpp +++ b/zencore/httpserver.cpp @@ -481,21 +481,55 @@ HttpServerRequest::ReadPayloadPackage() #if ZEN_PLATFORM_WINDOWS class HttpSysServer; -class HttpTransaction; +class HttpSysTransaction; +class HttpMessageResponseRequest; class HttpSysRequestHandler { public: - HttpSysRequestHandler(HttpTransaction& InRequest) : m_Request(InRequest) {} + HttpSysRequestHandler(HttpSysTransaction& InRequest) : m_Request(InRequest) {} virtual ~HttpSysRequestHandler() = default; virtual void IssueRequest() = 0; virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; - HttpTransaction& Transaction() { return m_Request; } + HttpSysTransaction& Transaction() { return m_Request; } private: - HttpTransaction& m_Request; // Outermost HTTP transaction object + HttpSysTransaction& m_Request; // Outermost HTTP transaction object +}; + +struct InitialRequestHandler : public HttpSysRequestHandler +{ + inline PHTTP_REQUEST HttpRequest() { return (PHTTP_REQUEST)m_RequestBuffer; } + inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + + InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} + ~InitialRequestHandler() {} + + virtual void IssueRequest() override; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + + PHTTP_REQUEST m_HttpRequestPtr = (HTTP_REQUEST*)(m_RequestBuffer); + UCHAR m_RequestBuffer[16384 + sizeof(HTTP_REQUEST)]; +}; + +class HttpSysServerRequest : public HttpServerRequest +{ +public: + HttpSysServerRequest() = default; + HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service); + ~HttpSysServerRequest() = default; + + virtual void ReadPayload(std::function<void(HttpServerRequest&, IoBuffer)>&& CompletionHandler) override; + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponse HttpResponseCode) override; + virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + + bool m_IsInitialized = false; + HttpSysTransaction& m_HttpTx; + HttpMessageResponseRequest* m_Response = nullptr; // TODO: make this more general }; /** HTTP transaction @@ -503,12 +537,12 @@ private: There will be an instance of this per pending and in-flight HTTP transaction */ -class HttpTransaction +class HttpSysTransaction { public: - HttpTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_HttpHandler(&m_InitialHttpHandler) {} + HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_HttpHandler(&m_InitialHttpHandler) {} - virtual ~HttpTransaction() {} + virtual ~HttpSysTransaction() {} enum class Status { @@ -533,16 +567,15 @@ public: // than one thread at any given moment. This means we need to be careful about what // happens in here - HttpTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpTransaction, m_HttpOverlapped); + HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); - if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpTransaction::Status::kDone) + if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) { delete Transaction; } } - void IssueInitialRequest(); - + void IssueInitialRequest(); PTP_IO Iocp(); HANDLE RequestQueueHandle(); inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } @@ -557,31 +590,41 @@ protected: RwLock m_Lock; private: - struct InitialRequestHandler : public HttpSysRequestHandler - { - inline PHTTP_REQUEST HttpRequest() { return (PHTTP_REQUEST)m_RequestBuffer; } - inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + InitialRequestHandler m_InitialHttpHandler{*this}; +}; - InitialRequestHandler(HttpTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} - ~InitialRequestHandler() {} +////////////////////////////////////////////////////////////////////////// - virtual void IssueRequest() override; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; +class HttpPayloadReadRequest : public HttpSysRequestHandler +{ +public: + HttpPayloadReadRequest(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) {} - PHTTP_REQUEST m_HttpRequestPtr = (HTTP_REQUEST*)(m_RequestBuffer); - UCHAR m_RequestBuffer[16384 + sizeof(HTTP_REQUEST)]; - } m_InitialHttpHandler{*this}; + virtual void IssueRequest() override; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; }; +void +HttpPayloadReadRequest::IssueRequest() +{ +} + +HttpSysRequestHandler* +HttpPayloadReadRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(IoResult, NumberOfBytesTransferred); + return nullptr; +} + ////////////////////////////////////////////////////////////////////////// class HttpMessageResponseRequest : public HttpSysRequestHandler { public: - HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode); - HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, const char* Message); - HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, const void* Payload, size_t PayloadSize); - HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const char* Message); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const void* Payload, size_t PayloadSize); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs); ~HttpMessageResponseRequest(); virtual void IssueRequest() override; @@ -603,14 +646,15 @@ private: std::vector<IoBuffer> m_DataBuffers; }; -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode) : HttpSysRequestHandler(InRequest) +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) +: HttpSysRequestHandler(InRequest) { std::array<IoBuffer, 0> buffers; Initialize(ResponseCode, buffers); } -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, const char* Message) +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, const char* Message) : HttpSysRequestHandler(InRequest) { IoBuffer MessageBuffer(IoBuffer::Wrap, Message, strlen(Message)); @@ -619,10 +663,10 @@ HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InReques Initialize(ResponseCode, buffers); } -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InRequest, - uint16_t ResponseCode, - const void* Payload, - size_t PayloadSize) +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + const void* Payload, + size_t PayloadSize) : HttpSysRequestHandler(InRequest) { IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); @@ -631,7 +675,7 @@ HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InReques Initialize(ResponseCode, buffers); } -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs) +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::span<IoBuffer> Blobs) : HttpSysRequestHandler(InRequest) { Initialize(ResponseCode, Blobs); @@ -708,7 +752,9 @@ HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfB ZEN_UNUSED(IoResult); if (m_RemainingChunkCount == 0) + { return nullptr; // All done + } return this; } @@ -716,7 +762,7 @@ HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfB void HttpMessageResponseRequest::IssueRequest() { - HttpTransaction& Tx = Transaction(); + HttpSysTransaction& Tx = Transaction(); HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); PTP_IO const Iocp = Tx.Iocp(); @@ -828,7 +874,7 @@ HttpMessageResponseRequest::IssueRequest() class HttpSysServer { - friend class HttpTransaction; + friend class HttpSysTransaction; public: HttpSysServer(WinIoThreadPool& InThreadPool); @@ -945,7 +991,7 @@ HttpSysServer::Initialize(const wchar_t* UrlPath) // Create I/O completion port - m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpTransaction::IoCompletionCallback, this); + m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, this); // Check result! } @@ -1015,7 +1061,7 @@ HttpSysServer::IssueNewRequestMaybe() return; } - std::unique_ptr<HttpTransaction> Request = std::make_unique<HttpTransaction>(*this); + std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this); Request->IssueInitialRequest(); @@ -1083,227 +1129,235 @@ HttpSysServer::RemoveEndpoint(const char* UrlPath, HttpService& Service) ////////////////////////////////////////////////////////////////////////// -class HttpSysServerRequest : public HttpServerRequest +HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service) : m_IsInitialized(true), m_HttpTx(Tx) { -public: - HttpSysServerRequest(HttpTransaction& Tx, HttpService& Service) : m_HttpTx(Tx) - { - PHTTP_REQUEST HttpRequestPtr = Tx.HttpRequest(); + PHTTP_REQUEST HttpRequestPtr = Tx.HttpRequest(); - const int PrefixLength = Service.UriPrefixLength(); - const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t); + const int PrefixLength = Service.UriPrefixLength(); + const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(char16_t); - if (AbsPathLength >= PrefixLength) - { - // We convert the URI immediately because most of the code involved prefers to deal - // with utf8. This has some performance impact which I'd prefer to avoid but for now - // we just have to live with it + if (AbsPathLength >= PrefixLength) + { + // We convert the URI immediately because most of the code involved prefers to deal + // with utf8. This has some performance impact which I'd prefer to avoid but for now + // we just have to live with it - WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, - m_Uri); - } - else - { - m_Uri.Reset(); - } + WideToUtf8({(char16_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, + m_Uri); + } + else + { + m_Uri.Reset(); + } - if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) - { - --QueryStringLength; + if (auto QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) + { + --QueryStringLength; - WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryString); - } - else - { - m_QueryString.Reset(); - } + WideToUtf8({(char16_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(char16_t)}, m_QueryString); + } + else + { + m_QueryString.Reset(); + } - switch (HttpRequestPtr->Verb) - { - case HttpVerbOPTIONS: - m_Verb = HttpVerb::kOptions; - break; - - case HttpVerbGET: - m_Verb = HttpVerb::kGet; - break; - - case HttpVerbHEAD: - m_Verb = HttpVerb::kHead; - break; - - case HttpVerbPOST: - m_Verb = HttpVerb::kPost; - break; - - case HttpVerbPUT: - m_Verb = HttpVerb::kPut; - break; - - case HttpVerbDELETE: - m_Verb = HttpVerb::kDelete; - break; - - case HttpVerbCOPY: - m_Verb = HttpVerb::kCopy; - break; - - default: - // TODO: invalid request? - m_Verb = (HttpVerb)0; - break; - } + switch (HttpRequestPtr->Verb) + { + case HttpVerbOPTIONS: + m_Verb = HttpVerb::kOptions; + break; - const HTTP_KNOWN_HEADER& clh = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentLength]; - std::string_view cl(clh.pRawValue, clh.RawValueLength); - std::from_chars(cl.data(), cl.data() + cl.size(), m_ContentLength); + case HttpVerbGET: + m_Verb = HttpVerb::kGet; + break; - const HTTP_KNOWN_HEADER& CtHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentType]; - m_ContentType = MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); - } + case HttpVerbHEAD: + m_Verb = HttpVerb::kHead; + break; - ~HttpSysServerRequest() {} + case HttpVerbPOST: + m_Verb = HttpVerb::kPost; + break; - virtual IoBuffer ReadPayload() override - { - // This is presently synchronous for simplicity, but we - // need to implement an asynchronous version also + case HttpVerbPUT: + m_Verb = HttpVerb::kPut; + break; - HTTP_REQUEST* const HttpReq = m_HttpTx.HttpRequest(); + case HttpVerbDELETE: + m_Verb = HttpVerb::kDelete; + break; - IoBuffer PayloadBuffer(m_ContentLength); + case HttpVerbCOPY: + m_Verb = HttpVerb::kCopy; + break; - HttpContentType ContentType = RequestContentType(); - PayloadBuffer.SetContentType(ContentType); + default: + // TODO: invalid request? + m_Verb = (HttpVerb)0; + break; + } - uint64_t BytesToRead = m_ContentLength; + const HTTP_KNOWN_HEADER& clh = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentLength]; + std::string_view cl(clh.pRawValue, clh.RawValueLength); + std::from_chars(cl.data(), cl.data() + cl.size(), m_ContentLength); - uint8_t* ReadPointer = reinterpret_cast<uint8_t*>(PayloadBuffer.MutableData()); + const HTTP_KNOWN_HEADER& CtHdr = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderContentType]; + m_ContentType = MapContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +} - // First deal with any payload which has already been copied - // into our request buffer +void +HttpSysServerRequest::ReadPayload(std::function<void(HttpServerRequest&, IoBuffer)>&& CompletionHandler) +{ + ZEN_UNUSED(CompletionHandler); +} - const int EntityChunkCount = HttpReq->EntityChunkCount; +IoBuffer +HttpSysServerRequest::ReadPayload() +{ + // This is presently synchronous for simplicity, but we + // need to implement an asynchronous version also - for (int i = 0; i < EntityChunkCount; ++i) - { - HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; + HTTP_REQUEST* const HttpReq = m_HttpTx.HttpRequest(); - ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); + IoBuffer PayloadBuffer(m_ContentLength); - const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; + HttpContentType ContentType = RequestContentType(); + PayloadBuffer.SetContentType(ContentType); - ZEN_ASSERT(BufferLength <= BytesToRead); + uint64_t BytesToRead = m_ContentLength; - memcpy(ReadPointer, EntityChunk.FromMemory.pBuffer, BufferLength); + uint8_t* ReadPointer = reinterpret_cast<uint8_t*>(PayloadBuffer.MutableData()); - ReadPointer += BufferLength; - BytesToRead -= BufferLength; - } + // First deal with any payload which has already been copied + // into our request buffer - // Call http.sys API to receive the remaining data + const int EntityChunkCount = HttpReq->EntityChunkCount; - static const uint64_t kMaxBytesPerApiCall = 1 * 1024 * 1024; + for (int i = 0; i < EntityChunkCount; ++i) + { + HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; - while (BytesToRead) - { - ULONG BytesRead = 0; + ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); - const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); + const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; - ULONG ApiResult = HttpReceiveRequestEntityBody(m_HttpTx.RequestQueueHandle(), - HttpReq->RequestId, - 0, /* Flags */ - ReadPointer, - gsl::narrow<ULONG>(BytesToReadThisCall), - &BytesRead, - NULL /* Overlapped */ - ); + ZEN_ASSERT(BufferLength <= BytesToRead); - if (ApiResult != NO_ERROR && ApiResult != ERROR_HANDLE_EOF) - { - throw HttpServerException("payload read failed", ApiResult); - } + memcpy(ReadPointer, EntityChunk.FromMemory.pBuffer, BufferLength); - BytesToRead -= BytesRead; - ReadPointer += BytesRead; - } + ReadPointer += BufferLength; + BytesToRead -= BufferLength; + } + if (BytesToRead == 0) + { PayloadBuffer.MakeImmutable(); return PayloadBuffer; } - virtual void WriteResponse(HttpResponse HttpResponseCode) override + // Call http.sys API to receive the remaining data SYNCHRONOUSLY + + static const uint64_t kMaxBytesPerApiCall = 1 * 1024 * 1024; + + while (BytesToRead) { - ZEN_ASSERT(m_IsHandled == false); + ULONG BytesRead = 0; + + const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode); + ULONG ApiResult = HttpReceiveRequestEntityBody(m_HttpTx.RequestQueueHandle(), + HttpReq->RequestId, + 0, /* Flags */ + ReadPointer, + gsl::narrow<ULONG>(BytesToReadThisCall), + &BytesRead, + NULL /* Overlapped */ + ); - if (m_SuppressBody) + if (ApiResult != NO_ERROR && ApiResult != ERROR_HANDLE_EOF) { - m_Response->SuppressResponseBody(); + throw HttpServerException("payload read failed", ApiResult); } - m_IsHandled = true; + BytesToRead -= BytesRead; + ReadPointer += BytesRead; } - virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override - { - ZEN_ASSERT(m_IsHandled == false); - ZEN_UNUSED(ContentType); + PayloadBuffer.MakeImmutable(); - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, Blobs); + return PayloadBuffer; +} - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } +void +HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode) +{ + ZEN_ASSERT(m_IsHandled == false); + + m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode); - m_IsHandled = true; + if (m_SuppressBody) + { + m_Response->SuppressResponseBody(); } - virtual void WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override + m_IsHandled = true; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(m_IsHandled == false); + ZEN_UNUSED(ContentType); + + m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, Blobs); + + if (m_SuppressBody) { - ZEN_ASSERT(m_IsHandled == false); - ZEN_UNUSED(ContentType); + m_Response->SuppressResponseBody(); + } - m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ResponseString.data(), ResponseString.size()); + m_IsHandled = true; +} - if (m_SuppressBody) - { - m_Response->SuppressResponseBody(); - } +void +HttpSysServerRequest::WriteResponse(HttpResponse HttpResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(m_IsHandled == false); + ZEN_UNUSED(ContentType); + + m_Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)HttpResponseCode, ResponseString.data(), ResponseString.size()); - m_IsHandled = true; + if (m_SuppressBody) + { + m_Response->SuppressResponseBody(); } - HttpTransaction& m_HttpTx; - HttpMessageResponseRequest* m_Response = nullptr; -}; + m_IsHandled = true; +} ////////////////////////////////////////////////////////////////////////// PTP_IO -HttpTransaction::Iocp() +HttpSysTransaction::Iocp() { return m_HttpServer.m_ThreadPool.Iocp(); } HANDLE -HttpTransaction::RequestQueueHandle() +HttpSysTransaction::RequestQueueHandle() { return m_HttpServer.m_RequestQueueHandle; } void -HttpTransaction::IssueInitialRequest() +HttpSysTransaction::IssueInitialRequest() { m_InitialHttpHandler.IssueRequest(); } -HttpTransaction::Status -HttpTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +HttpSysTransaction::Status +HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { // We use this to ensure sequential execution of completion handlers // for any given transaction. @@ -1347,6 +1401,7 @@ HttpTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransfe } } + // Ensure new requests are enqueued m_HttpServer.IssueNewRequestMaybe(); if (RequestPending) @@ -1360,13 +1415,13 @@ HttpTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransfe ////////////////////////////////////////////////////////////////////////// void -HttpTransaction::InitialRequestHandler::IssueRequest() +InitialRequestHandler::IssueRequest() { PTP_IO Iocp = Transaction().Iocp(); StartThreadpoolIo(Iocp); - HttpTransaction& Tx = Transaction(); + HttpSysTransaction& Tx = Transaction(); HTTP_REQUEST* HttpReq = Tx.HttpRequest(); @@ -1389,14 +1444,14 @@ HttpTransaction::InitialRequestHandler::IssueRequest() // CleanupHttpIoRequest(pIoRequest); - fprintf(stderr, "HttpReceiveHttpRequest failed, error 0x%lx\n", Result); + spdlog::error("HttpReceiveHttpRequest failed, error {:x}", Result); return; } } HttpSysRequestHandler* -HttpTransaction::InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { ZEN_UNUSED(IoResult); ZEN_UNUSED(NumberOfBytesTransferred); @@ -1536,10 +1591,28 @@ HttpRequestRouter::AddPattern(const char* Id, const char* Regex) void HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs) { - // Expand patterns + ExtendableStringBuilder<128> ExpandedRegex; + ProcessRegexSubstitutions(Regex, ExpandedRegex); + m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex); +} + +void +HttpRequestRouter::RegisterRoute(const char* Regex, PackageEndpointHandler& Handler) +{ ExtendableStringBuilder<128> ExpandedRegex; + ProcessRegexSubstitutions(Regex, ExpandedRegex); + + m_Handlers.emplace_back( + ExpandedRegex.c_str(), + HttpVerb::kPost, + [&Handler](HttpRouterRequest& Request) { Handler.HandleRequest(Request); }, + Regex); +} +void +HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex) +{ size_t RegexLen = strlen(Regex); for (size_t i = 0; i < RegexLen;) @@ -1558,13 +1631,13 @@ HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFu if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) { - ExpandedRegex.Append(it->second.c_str()); + OutExpandedRegex.Append(it->second.c_str()); } else { // Default to anything goes (or should this just be an error?) - ExpandedRegex.Append("(.+?)"); + OutExpandedRegex.Append("(.+?)"); } // skip ahead @@ -1579,11 +1652,9 @@ HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFu if (!matched) { - ExpandedRegex.Append(Regex[i++]); + OutExpandedRegex.Append(Regex[i++]); } } - - m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex); } bool diff --git a/zencore/include/zencore/compactbinary.h b/zencore/include/zencore/compactbinary.h index e20679317..09619be8b 100644 --- a/zencore/include/zencore/compactbinary.h +++ b/zencore/include/zencore/compactbinary.h @@ -1096,15 +1096,15 @@ public: /** Access the field as an object. Defaults to an empty object on error. */ inline CbObject AsObject() &; - - /** Access the field as an object. Defaults to an empty object on error. */ inline CbObject AsObject() &&; /** Access the field as an array. Defaults to an empty array on error. */ inline CbArray AsArray() &; - - /** Access the field as an array. Defaults to an empty array on error. */ inline CbArray AsArray() &&; + + /** Access the field as binary. Returns the provided default on error. */ + inline SharedBuffer AsBinary(const SharedBuffer& Default = SharedBuffer()) &; + inline SharedBuffer AsBinary(const SharedBuffer& Default = SharedBuffer()) &&; }; /** @@ -1266,6 +1266,20 @@ CbField::AsArray() && return IsArray() ? CbArray(AsArrayView(), std::move(*this)) : CbArray(); } +inline SharedBuffer +CbField::AsBinary(const SharedBuffer& Default) & +{ + const MemoryView View = AsBinaryView(); + return !HasError() ? SharedBuffer::MakeView(View, GetOuterBuffer()) : Default; +} + +inline SharedBuffer +CbField::AsBinary(const SharedBuffer& Default) && +{ + const MemoryView View = AsBinaryView(); + return !HasError() ? SharedBuffer::MakeView(View, std::move(*this).GetOuterBuffer()) : Default; +} + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// /** diff --git a/zencore/include/zencore/compactbinarypackage.h b/zencore/include/zencore/compactbinarypackage.h index d60155d1a..57624a3ab 100644 --- a/zencore/include/zencore/compactbinarypackage.h +++ b/zencore/include/zencore/compactbinarypackage.h @@ -38,18 +38,27 @@ public: CbAttachment() = default; /** Construct a compact binary attachment. Value is cloned if not owned. */ - inline explicit CbAttachment(const CbObject& Value) : CbAttachment(Value, nullptr) {} + inline explicit CbAttachment(const CbObject& InValue) : CbAttachment(InValue, nullptr) {} /** Construct a compact binary attachment. Value is cloned if not owned. Hash must match Value. */ - inline explicit CbAttachment(const CbObject& Value, const IoHash& Hash) : CbAttachment(Value, &Hash) {} + inline explicit CbAttachment(const CbObject& InValue, const IoHash& Hash) : CbAttachment(InValue, &Hash) {} - /** Construct a binary attachment. Value is cloned if not owned. */ - ZENCORE_API explicit CbAttachment(const SharedBuffer& Value); + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue); - /** Construct a binary attachment. Value is cloned if not owned. Hash must match Value. */ - ZENCORE_API explicit CbAttachment(const SharedBuffer& Value, const IoHash& Hash); + /** Construct a raw binary attachment. Value is cloned if not owned. Hash must match Value. */ + ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue, const IoHash& Hash); - /** Construct a binary attachment. Value is cloned if not owned. */ + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(const CompositeBuffer& InValue); + + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue); + + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue, const IoHash& Hash); + + /** Construct a compressed binary attachment. Value is cloned if not owned. */ ZENCORE_API explicit CbAttachment(const CompressedBuffer& InValue); ZENCORE_API explicit CbAttachment(CompressedBuffer&& InValue); @@ -66,13 +75,19 @@ public: ZENCORE_API [[nodiscard]] SharedBuffer AsBinary() const; /** Access the attachment as compressed binary. Defaults to a null buffer if the attachment is null. */ + ZENCORE_API [[nodiscard]] CompositeBuffer AsCompositeBinary() const; + + /** Access the attachment as compressed binary. Defaults to a null buffer if the attachment is null. */ ZENCORE_API [[nodiscard]] CompressedBuffer AsCompressedBinary() const; /** Access the attachment as compact binary. Defaults to a field iterator with no value on error. */ ZENCORE_API [[nodiscard]] CbObject AsObject() const; - /** Returns true if the attachment is either binary or an object */ - [[nodiscard]] inline bool IsBinary() const { return !IsNull(); } + /** Returns true if the attachment is binary */ + ZENCORE_API [[nodiscard]] bool IsBinary() const; + + /** Returns true if the attachment is compressed binary */ + ZENCORE_API [[nodiscard]] bool IsCompressedBinary() const; /** Returns whether the attachment is an object. */ ZENCORE_API [[nodiscard]] bool IsObject() const; @@ -122,7 +137,19 @@ private: CbObjectValue(CbObject&& InObject, const IoHash& InHash) : Object(std::move(InObject)), Hash(InHash) {} }; - std::variant<CompressedBuffer, CbObjectValue> Value; + struct BinaryValue + { + CompositeBuffer Buffer; + IoHash Hash; + + BinaryValue(const CompositeBuffer& InBuffer) : Buffer(InBuffer.MakeOwned()), Hash(IoHash::HashBuffer(InBuffer)) {} + BinaryValue(const CompositeBuffer& InBuffer, const IoHash& InHash) : Buffer(InBuffer.MakeOwned()), Hash(InHash) {} + BinaryValue(CompositeBuffer&& InBuffer) : Buffer(std::move(InBuffer)), Hash(IoHash::HashBuffer(Buffer)) {} + BinaryValue(CompositeBuffer&& InBuffer, const IoHash& InHash) : Buffer(std::move(InBuffer)), Hash(InHash) {} + BinaryValue(SharedBuffer&& InBuffer, const IoHash& InHash) : Buffer(std::move(InBuffer)), Hash(InHash) {} + }; + + std::variant<nullptr_t, CbObjectValue, BinaryValue, CompressedBuffer> Value; }; /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/zencore/include/zencore/compactbinaryvalidation.h b/zencore/include/zencore/compactbinaryvalidation.h index 9799c594a..b1fab9572 100644 --- a/zencore/include/zencore/compactbinaryvalidation.h +++ b/zencore/include/zencore/compactbinaryvalidation.h @@ -58,10 +58,13 @@ enum class CbValidateMode : uint32_t Padding = 1 << 3, /** - * Validate that a package or attachment has the expected fields and matches its saved hashes. + * Validate that a package or attachment has the expected fields. */ Package = 1 << 4, + /** + * Validate that a package or attachment matches its saved hashes. + */ PackageHash = 1 << 5, /** Perform all validation described above. */ diff --git a/zencore/include/zencore/httpserver.h b/zencore/include/zencore/httpserver.h index d4d9e21e0..a0be54665 100644 --- a/zencore/include/zencore/httpserver.h +++ b/zencore/include/zencore/httpserver.h @@ -228,6 +228,8 @@ public: */ virtual IoBuffer ReadPayload() = 0; + virtual void ReadPayload(std::function<void(HttpServerRequest&, IoBuffer)>&& CompletionHandler) = 0; + ZENCORE_API CbObject ReadPayloadObject(); ZENCORE_API CbPackage ReadPayloadPackage(); @@ -346,11 +348,31 @@ HttpRouterRequest::GetCapture(uint32_t Index) const ////////////////////////////////////////////////////////////////////////// +class PackageRequestContext +{ +public: + PackageRequestContext(); + ~PackageRequestContext(); + +private: +}; + +class PackageEndpointHandler +{ +public: + virtual void HandleRequest(HttpRouterRequest& Request) = 0; + +private: +}; + /** HTTP request router helper * * This helper class allows a service implementer to register one or more * endpoints using pattern matching (currently using regex matching) * + * This is intended to be initialized once only, there is no thread + * safety so you can absolutely not add or remove endpoints once the handler + * goes live */ class HttpRequestRouter @@ -358,8 +380,37 @@ class HttpRequestRouter public: typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t; + /** + * @brief Add pattern which can be referenced by name, commonly used for URL components + * @param Id String used to identify patterns for replacement + * @param Regex String which will replace the Id string in any registered URL paths + */ void AddPattern(const char* Id, const char* Regex); + + /** + * @brief Register a an endpoint handler for the given route + * @param Regex Regular expression used to match the handler to a request. This may + * contain pattern aliases registered via AddPattern + * @param HandlerFunc Handler function to call for any matching request + * @param SupportedVerbs Supported HTTP verbs for this handler + */ void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs); + /** + * @brief Register CbPackage endpoint handler + * @param Regex Regular expression used to match the handler to a request. This may + * contain pattern aliases registered via AddPattern + * @param Handler Package handler instance + */ + void RegisterRoute(const char* Regex, PackageEndpointHandler& Handler); + + void ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex); + + /** + * @brief HTTP request handling function - this should be called to route the + * request to a registered handler + * @param Request Request to route to a handler + * @return Function returns true if the request was routed successfully + */ bool HandleRequest(zen::HttpServerRequest& Request); private: @@ -379,6 +430,10 @@ private: HttpVerb Verbs; HandlerFunc_t Handler; const char* Pattern; + + private: + HandlerEntry& operator=(const HandlerEntry&) = delete; + HandlerEntry(const HandlerEntry&) = delete; }; std::list<HandlerEntry> m_Handlers; diff --git a/zenserver-test/zenserver-test.cpp b/zenserver-test/zenserver-test.cpp index 455ab2495..e71b7f730 100644 --- a/zenserver-test/zenserver-test.cpp +++ b/zenserver-test/zenserver-test.cpp @@ -1380,4 +1380,97 @@ TEST_CASE("mesh.basic") } } +class ZenServerTestHelper +{ +public: + ZenServerTestHelper(std::string_view HelperId, int ServerCount) + : m_HelperId{HelperId} + , m_ServerCount{ServerCount} + { + } + ~ZenServerTestHelper() {} + + void SpawnServers() + { + SpawnServers([](ZenServerInstance&) {}); + } + + void SpawnServers(auto&& Callback) + { + spdlog::info("{}: spawning {} server instances", m_HelperId, m_ServerCount); + + m_Instances.resize(m_ServerCount); + + for (int i = 0; i < m_ServerCount; ++i) + { + auto& Instance = m_Instances[i]; + + Instance = std::make_unique<ZenServerInstance>(TestEnv); + Instance->SetTestDir(TestEnv.CreateNewTestDir()); + + Callback(*Instance); + + Instance->SpawnServer(13337 + i); + } + + for (int i = 0; i < m_ServerCount; ++i) + { + auto& Instance = m_Instances[i]; + + Instance->WaitUntilReady(); + } + } + + ZenServerInstance& GetInstance(int Index) { return *m_Instances[Index]; } + +private: + std::string m_HelperId; + int m_ServerCount = 0; + std::vector<std::unique_ptr<ZenServerInstance> > m_Instances; +}; + +TEST_CASE("http.basics") +{ + using namespace std::literals; + + ZenServerTestHelper Servers{"http.basics"sv, 1}; + Servers.SpawnServers(); + + ZenServerInstance& Instance = Servers.GetInstance(0); + const std::string BaseUri = Instance.GetBaseUri(); + + { + cpr::Response r = cpr::Get(cpr::Url{"{}/testing/hello"_format(BaseUri)}); + CHECK_EQ(r.status_code, 200); + } + + { + cpr::Response r = cpr::Post(cpr::Url{"{}/testing/hello"_format(BaseUri)}); + CHECK_EQ(r.status_code, 404); + } + + { + cpr::Response r = cpr::Post(cpr::Url{"{}/testing/echo"_format(BaseUri)}, cpr::Body{"yoyoyoyo"}); + CHECK_EQ(r.status_code, 200); + CHECK_EQ(r.text, "yoyoyoyo"); + } +} + +TEST_CASE("http.package") +{ + using namespace std::literals; + + ZenServerTestHelper Servers{"http.package"sv, 1}; + Servers.SpawnServers(); + + ZenServerInstance& Instance = Servers.GetInstance(0); + const std::string BaseUri = Instance.GetBaseUri(); + + { + cpr::Response r = cpr::Post(cpr::Url{"{}/testing/package"_format(BaseUri)}, cpr::Body{"yoyoyoyo"}); + CHECK_EQ(r.status_code, 200); + CHECK_EQ(r.text, "yoyoyoyo"); + } +} + #endif diff --git a/zenserver/testing/httptest.cpp b/zenserver/testing/httptest.cpp new file mode 100644 index 000000000..0639c2b53 --- /dev/null +++ b/zenserver/testing/httptest.cpp @@ -0,0 +1,48 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httptest.h" + +namespace zen { + +HttpTestingService::HttpTestingService() +{ + m_Router.RegisterRoute( + "hello", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponse::OK); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [this](HttpRouterRequest& Req) { + IoBuffer Body = Req.ServerRequest().ReadPayload(); + Req.ServerRequest().WriteResponse(HttpResponse::OK, HttpContentType::kBinary, Body); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute("package", m_PackageHandler); +} + +HttpTestingService::~HttpTestingService() +{ +} + +const char* +HttpTestingService::BaseUri() const +{ + return "/testing/"; +} + +void +HttpTestingService::HandleRequest(HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +void +HttpTestingService::PackageHandler::HandleRequest(HttpRouterRequest& Req) +{ + IoBuffer Body = Req.ServerRequest().ReadPayload(); + Req.ServerRequest().WriteResponse(HttpResponse::OK, HttpContentType::kBinary, Body); +} + +} // namespace zen diff --git a/zenserver/testing/httptest.h b/zenserver/testing/httptest.h new file mode 100644 index 000000000..236d17ce7 --- /dev/null +++ b/zenserver/testing/httptest.h @@ -0,0 +1,34 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/httpserver.h> + +#include <spdlog/spdlog.h> + +namespace zen { + +/** + * Test service to facilitate testing the HTTP framework and client interactions + */ +class HttpTestingService : public HttpService +{ +public: + HttpTestingService(); + ~HttpTestingService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + +private: + HttpRequestRouter m_Router; + + struct PackageHandler : public PackageEndpointHandler + { + virtual void HandleRequest(HttpRouterRequest& Request) override; + }; + + PackageHandler m_PackageHandler; +}; + +} // namespace zen diff --git a/zenserver/zenserver.cpp b/zenserver/zenserver.cpp index c857d4c71..57e691ea1 100644 --- a/zenserver/zenserver.cpp +++ b/zenserver/zenserver.cpp @@ -72,6 +72,7 @@ #include "diag/diagsvcs.h" #include "experimental/usnjournal.h" #include "projectstore.h" +#include "testing/httptest.h" #include "testing/launch.h" #include "upstream/jupiter.h" #include "upstream/upstreamcache.h" @@ -205,7 +206,7 @@ public: else { UpstreamCache.reset(); - spdlog::info("upstream cache NOT active"); + spdlog::info("NOT using upstream cache"); } } @@ -221,10 +222,20 @@ public: { StartMesh(BasePort); } + else + { + spdlog::info("NOT starting mesh"); + } m_Http.Initialize(BasePort); m_Http.AddEndpoint(m_HealthService); - m_Http.AddEndpoint(m_TestService); + + if (m_TestMode) + { + m_Http.AddEndpoint(m_TestService); + m_Http.AddEndpoint(m_TestingService); + } + m_Http.AddEndpoint(m_AdminService); if (m_HttpProjectService) @@ -365,6 +376,7 @@ private: zen::CasGc m_Gc{*m_CasStore}; zen::CasScrubber m_Scrubber{*m_CasStore}; HttpTestService m_TestService; + zen::HttpTestingService m_TestingService; zen::HttpCasService m_CasService{*m_CasStore}; zen::RefPtr<zen::ProjectStore> m_ProjectStore; zen::Ref<zen::LocalProjectService> m_LocalProjectService; diff --git a/zenserver/zenserver.vcxproj b/zenserver/zenserver.vcxproj index e4b40e13e..3c907e2fb 100644 --- a/zenserver/zenserver.vcxproj +++ b/zenserver/zenserver.vcxproj @@ -110,6 +110,7 @@ <ClInclude Include="config.h" /> <ClInclude Include="diag\logging.h" /> <ClInclude Include="sos\sos.h" /> + <ClInclude Include="testing\httptest.h" /> <ClInclude Include="upstream\jupiter.h" /> <ClInclude Include="projectstore.h" /> <ClInclude Include="cache\cacheagent.h" /> @@ -132,6 +133,7 @@ <ClCompile Include="projectstore.cpp" /> <ClCompile Include="cache\cacheagent.cpp" /> <ClCompile Include="sos\sos.cpp" /> + <ClCompile Include="testing\httptest.cpp" /> <ClCompile Include="upstream\jupiter.cpp" /> <ClCompile Include="testing\launch.cpp" /> <ClCompile Include="cache\cachestore.cpp" /> diff --git a/zenserver/zenserver.vcxproj.filters b/zenserver/zenserver.vcxproj.filters index ca16caf77..3a17cbb07 100644 --- a/zenserver/zenserver.vcxproj.filters +++ b/zenserver/zenserver.vcxproj.filters @@ -38,6 +38,7 @@ <ClInclude Include="upstream\upstreamcache.h"> <Filter>upstream</Filter> </ClInclude> + <ClInclude Include="testing\httptest.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="zenserver.cpp" /> @@ -71,6 +72,7 @@ <ClCompile Include="upstream\upstreamcache.cpp"> <Filter>upstream</Filter> </ClCompile> + <ClCompile Include="testing\httptest.cpp" /> </ItemGroup> <ItemGroup> <Filter Include="cache"> diff --git a/zenutil/include/zenserverprocess.h b/zenutil/include/zenserverprocess.h index d0093537e..f7d911a87 100644 --- a/zenutil/include/zenserverprocess.h +++ b/zenutil/include/zenserverprocess.h @@ -55,6 +55,8 @@ struct ZenServerInstance void AttachToRunningServer(int BasePort = 0); + std::string GetBaseUri() const; + private: ZenServerEnvironment& m_Env; zen::ProcessHandle m_Process; @@ -63,6 +65,7 @@ private: bool m_Terminate = false; std::filesystem::path m_TestDir; bool m_MeshEnabled = false; + int m_BasePort = 0; void CreateShutdownEvent(int BasePort); }; diff --git a/zenutil/zenserverprocess.cpp b/zenutil/zenserverprocess.cpp index 4e45ddfae..d0dd0106b 100644 --- a/zenutil/zenserverprocess.cpp +++ b/zenutil/zenserverprocess.cpp @@ -403,6 +403,7 @@ ZenServerInstance::SpawnServer(int BasePort) if (BasePort) { CommandLine << " --port " << BasePort; + m_BasePort = BasePort; } if (!m_TestDir.empty()) @@ -418,7 +419,7 @@ ZenServerInstance::SpawnServer(int BasePort) std::filesystem::path CurrentDirectory = std::filesystem::current_path(); - spdlog::debug("Spawning server"); + spdlog::debug("Spawning server '{}'", LogId); PROCESS_INFORMATION ProcessInfo{}; STARTUPINFO StartupInfo{.cb = sizeof(STARTUPINFO)}; @@ -492,7 +493,7 @@ ZenServerInstance::SpawnServer(int BasePort) } } - spdlog::debug("Server spawned OK"); + spdlog::debug("Server '{}' spawned OK", LogId); if (IsTest) { @@ -558,3 +559,13 @@ ZenServerInstance::WaitUntilReady(int Timeout) { return m_ReadyEvent.Wait(Timeout); } + +std::string +ZenServerInstance::GetBaseUri() const +{ + ZEN_ASSERT(m_BasePort); + + using namespace fmt::literals; + + return "http://localhost:{}"_format(m_BasePort); +} |