diff options
Diffstat (limited to 'src/zencore')
113 files changed, 15049 insertions, 1986 deletions
diff --git a/src/zencore/basicfile.cpp b/src/zencore/basicfile.cpp new file mode 100644 index 000000000..6989da67e --- /dev/null +++ b/src/zencore/basicfile.cpp @@ -0,0 +1,1067 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/basicfile.h> + +#include <zencore/compactbinary.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/memory/memory.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#else +# include <fcntl.h> +# include <sys/file.h> +# include <sys/stat.h> +# include <unistd.h> +#endif + +#include <fmt/format.h> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +BasicFile::~BasicFile() +{ + Close(); +} +BasicFile::BasicFile(const std::filesystem::path& FileName, Mode Mode) +{ + Open(FileName, Mode); +} + +BasicFile::BasicFile(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec) +{ + Open(FileName, Mode, Ec); +} + +BasicFile::BasicFile(const std::filesystem::path& FileName, Mode Mode, std::function<bool(std::error_code& Ec)>&& RetryCallback) +{ + Open(FileName, Mode, std::move(RetryCallback)); +} + +void +BasicFile::Open(const std::filesystem::path& FileName, Mode Mode) +{ + std::error_code Ec; + Open(FileName, Mode, Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to open file '{}', mode: {:x}", FileName, uint32_t(Mode))); + } +} + +void +BasicFile::Open(const std::filesystem::path& FileName, Mode InMode, std::error_code& Ec) +{ + Ec.clear(); + + Mode Mode = InMode & Mode::kModeMask; + +#if ZEN_PLATFORM_WINDOWS + DWORD dwCreationDisposition = 0; + DWORD dwDesiredAccess = 0; + switch (Mode) + { + case Mode::kRead: + dwCreationDisposition |= OPEN_EXISTING; + dwDesiredAccess |= GENERIC_READ; + break; + case Mode::kWrite: + dwCreationDisposition |= OPEN_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE); + break; + case Mode::kDelete: + dwCreationDisposition |= OPEN_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE); + break; + case Mode::kTruncate: + dwCreationDisposition |= CREATE_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE); + break; + case Mode::kTruncateDelete: + dwCreationDisposition |= CREATE_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE); + break; + } + + const DWORD dwShareMode = FILE_SHARE_READ | (EnumHasAllFlags(InMode, Mode::kPreventWrite) ? 0 : FILE_SHARE_WRITE) | + (EnumHasAllFlags(InMode, Mode::kPreventDelete) ? 0 : FILE_SHARE_DELETE); + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL; + const HANDLE hTemplateFile = nullptr; + const HANDLE FileHandle = CreateFile(FileName.c_str(), + dwDesiredAccess, + dwShareMode, + /* lpSecurityAttributes */ nullptr, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + Ec = MakeErrorCodeFromLastError(); + + return; + } +#else + int OpenFlags = O_CLOEXEC; + switch (Mode) + { + case Mode::kRead: + OpenFlags |= O_RDONLY; + break; + case Mode::kWrite: + case Mode::kDelete: + OpenFlags |= (O_RDWR | O_CREAT); + break; + case Mode::kTruncate: + case Mode::kTruncateDelete: + OpenFlags |= (O_RDWR | O_CREAT | O_TRUNC); + break; + } + + int Fd = open(FileName.c_str(), OpenFlags, 0666); + if (Fd < 0) + { + Ec = MakeErrorCodeFromLastError(); + return; + } + if (Mode != Mode::kRead) + { + fchmod(Fd, 0666); + } + + void* FileHandle = (void*)(uintptr_t(Fd)); +#endif + + m_FileHandle = FileHandle; +} + +void +BasicFile::Open(const std::filesystem::path& FileName, Mode Mode, std::function<bool(std::error_code& Ec)>&& RetryCallback) +{ + std::error_code Ec; + Open(FileName, Mode, Ec); + while (Ec && RetryCallback(Ec)) + { + Ec.clear(); + Open(FileName, Mode, Ec); + } + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to open file '{}', mode: {:x}", FileName, uint32_t(Mode))); + } +} + +void +BasicFile::Close() +{ + if (m_FileHandle) + { +#if ZEN_PLATFORM_WINDOWS + ::CloseHandle(m_FileHandle); +#else + int Fd = int(uintptr_t(m_FileHandle)); + close(Fd); +#endif + m_FileHandle = nullptr; + } +} + +IoBuffer +BasicFile::ReadRange(uint64_t FileOffset, uint64_t ByteCount) +{ + return IoBufferBuilder::MakeFromFileHandle(m_FileHandle, FileOffset, ByteCount); +} + +void +BasicFile::Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset) +{ + const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; + std::error_code Ec; + ReadFile(m_FileHandle, Data, BytesToRead, FileOffset, MaxChunkSize, Ec); + if (Ec) + { + std::error_code DummyEc; + throw std::system_error(Ec, + fmt::format("BasicFile::Read: ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x})", + FileOffset, + BytesToRead, + PathFromHandle(m_FileHandle, DummyEc), + FileSizeFromHandle(m_FileHandle))); + } +} + +IoBuffer +BasicFile::ReadAll() +{ + if (const uint64_t Size = FileSize()) + { + IoBuffer Buffer(Size); + Read(Buffer.MutableData(), Size, 0); + return Buffer; + } + else + { + return {}; + } +} + +void +BasicFile::StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun) +{ + StreamByteRange(0, FileSize(), std::move(ChunkFun)); +} + +void +BasicFile::StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun) +{ + const uint64_t ChunkSize = 128 * 1024; + IoBuffer ReadBuffer{ChunkSize}; + void* BufferPtr = ReadBuffer.MutableData(); + + uint64_t RemainBytes = Size; + uint64_t CurrentOffset = FileOffset; + + while (RemainBytes) + { + const uint64_t ThisChunkBytes = zen::Min(ChunkSize, RemainBytes); + + Read(BufferPtr, ThisChunkBytes, CurrentOffset); + + ChunkFun(BufferPtr, ThisChunkBytes); + + CurrentOffset += ThisChunkBytes; + RemainBytes -= ThisChunkBytes; + } +} + +uint64_t +BasicFile::Write(const CompositeBuffer& Data, uint64_t FileOffset) +{ + std::error_code Ec; + uint64_t WrittenBytes = Write(Data, FileOffset, Ec); + + if (Ec) + { + std::error_code Dummy; + throw std::system_error(Ec, fmt::format("Failed to write to file '{}'", zen::PathFromHandle(m_FileHandle, Dummy))); + } + return WrittenBytes; +} + +uint64_t +BasicFile::Write(const CompositeBuffer& Data, uint64_t FileOffset, std::error_code& Ec) +{ + uint64_t WrittenBytes = 0; + for (const SharedBuffer& Buffer : Data.GetSegments()) + { + MemoryView BlockView = Buffer.GetView(); + Write(BlockView, FileOffset + WrittenBytes, Ec); + + if (Ec) + { + return WrittenBytes; + } + + WrittenBytes += BlockView.GetSize(); + } + + return WrittenBytes; +} + +void +BasicFile::Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec) +{ + Write(Data.GetData(), Data.GetSize(), FileOffset, Ec); +} + +void +BasicFile::Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec) +{ + const uint64_t MaxChunkSize = 2u * 1024 * 1024; + + WriteFile(m_FileHandle, Data, Size, FileOffset, MaxChunkSize, Ec); +} + +void +BasicFile::Write(MemoryView Data, uint64_t FileOffset) +{ + Write(Data.GetData(), Data.GetSize(), FileOffset); +} + +void +BasicFile::Write(const void* Data, uint64_t Size, uint64_t Offset) +{ + std::error_code Ec; + Write(Data, Size, Offset, Ec); + + if (Ec) + { + std::error_code Dummy; + throw std::system_error(Ec, fmt::format("Failed to write to file '{}'", zen::PathFromHandle(m_FileHandle, Dummy))); + } +} + +void +BasicFile::WriteAll(IoBuffer Data, std::error_code& Ec) +{ + Write(Data.Data(), Data.Size(), 0, Ec); +} + +void +BasicFile::Flush() +{ + if (m_FileHandle == nullptr) + { + return; + } +#if ZEN_PLATFORM_WINDOWS + FlushFileBuffers(m_FileHandle); +#else + int Fd = int(uintptr_t(m_FileHandle)); + fsync(Fd); +#endif +} + +uint64_t +BasicFile::FileSize() const +{ + std::error_code Ec; + uint64_t FileSize = FileSizeFromHandle(m_FileHandle, Ec); + if (Ec) + { + std::error_code Dummy; + ThrowSystemError(Ec.value(), fmt::format("Failed to get file size from file '{}'", PathFromHandle(m_FileHandle, Dummy))); + } + return FileSize; +} + +uint64_t +BasicFile::FileSize(std::error_code& Ec) const +{ + return FileSizeFromHandle(m_FileHandle, Ec); +} + +void +BasicFile::SetFileSize(uint64_t FileSize) +{ +#if ZEN_PLATFORM_WINDOWS + LARGE_INTEGER liFileSize; + liFileSize.QuadPart = FileSize; + BOOL OK = ::SetFilePointerEx(m_FileHandle, liFileSize, 0, FILE_BEGIN); + if (OK == FALSE) + { + int Error = zen::GetLastError(); + if (Error) + { + std::error_code Dummy; + ThrowSystemError(Error, + fmt::format("Failed to set file pointer to {} for file {}", FileSize, PathFromHandle(m_FileHandle, Dummy))); + } + } + OK = ::SetEndOfFile(m_FileHandle); + if (OK == FALSE) + { + int Error = zen::GetLastError(); + if (Error) + { + std::error_code Dummy; + ThrowSystemError(Error, + fmt::format("Failed to set end of file to {} for file {}", FileSize, PathFromHandle(m_FileHandle, Dummy))); + } + } +#elif ZEN_PLATFORM_MAC + int Fd = int(intptr_t(m_FileHandle)); + if (ftruncate(Fd, (off_t)FileSize) < 0) + { + int Error = zen::GetLastError(); + if (Error) + { + std::error_code Dummy; + ThrowSystemError(Error, + fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle, Dummy))); + } + } +#else + int Fd = int(intptr_t(m_FileHandle)); + if (ftruncate64(Fd, (off64_t)FileSize) < 0) + { + int Error = zen::GetLastError(); + if (Error) + { + std::error_code Dummy; + ThrowSystemError(Error, + fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle, Dummy))); + } + } + if (FileSize > 0) + { + int Error = posix_fallocate64(Fd, 0, (off64_t)FileSize); + if (Error) + { + std::error_code Dummy; + ThrowSystemError(Error, + fmt::format("Failed to allocate space of {} for file {}", FileSize, PathFromHandle(m_FileHandle, Dummy))); + } + } +#endif +} + +void +BasicFile::Attach(void* Handle) +{ + ZEN_ASSERT(Handle != nullptr); + ZEN_ASSERT(m_FileHandle == nullptr); + m_FileHandle = Handle; +} + +void* +BasicFile::Detach() +{ + void* FileHandle = m_FileHandle; + m_FileHandle = 0; + return FileHandle; +} + +////////////////////////////////////////////////////////////////////////// + +TemporaryFile::~TemporaryFile() +{ + Close(); +} + +void +TemporaryFile::Close() +{ + if (m_FileHandle) + { +#if ZEN_PLATFORM_WINDOWS + // Mark file for deletion when final handle is closed + + FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; + + SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); +#else + std::error_code Ec; + std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle, Ec); + if (!Ec) + { + unlink(FilePath.c_str()); + } +#endif + + BasicFile::Close(); + } +} + +void +TemporaryFile::CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec) +{ + StringBuilder<64> TempName; + Oid::NewOid().ToString(TempName); + + m_TempPath = TempDirName / TempName.c_str(); + + Open(m_TempPath, BasicFile::Mode::kTruncateDelete, Ec); +} + +void +TemporaryFile::MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec) +{ + // We intentionally call the base class Close() since otherwise we'll end up + // deleting the temporary file + BasicFile::Close(); + + RenameFile(m_TempPath, FinalFileName, Ec); + + if (Ec) + { + // Try to re-open the temp file so we clean up after us when TemporaryFile is destructed + std::error_code DummyEc; + Open(m_TempPath, BasicFile::Mode::kWrite, DummyEc); + } +} + +////////////////////////////////////////////////////////////////////////// + +void +TemporaryFile::SafeWriteFile(const std::filesystem::path& Path, MemoryView Data) +{ + std::error_code Ec; + SafeWriteFile(Path, Data, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to safely write file '{}'", Path)); + } +} + +void +TemporaryFile::SafeWriteFile(const std::filesystem::path& Path, MemoryView Data, std::error_code& OutEc) +{ + TemporaryFile TempFile; + if (TempFile.CreateTemporary(Path.parent_path(), OutEc); !OutEc) + { + if (TempFile.Write(Data, 0, OutEc); !OutEc) + { + TempFile.MoveTemporaryIntoPlace(Path, OutEc); + } + } +} + +////////////////////////////////////////////////////////////////////////// + +LockFile::LockFile() +{ +} + +LockFile::~LockFile() +{ +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Fd = int(intptr_t(m_FileHandle)); + flock(Fd, LOCK_UN | LOCK_NB); +#endif +} + +void +LockFile::Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + Ec.clear(); + + const DWORD dwCreationDisposition = CREATE_ALWAYS; + DWORD dwDesiredAccess = GENERIC_READ | GENERIC_WRITE | DELETE; + const DWORD dwShareMode = FILE_SHARE_READ; + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_DELETE_ON_CLOSE; + HANDLE hTemplateFile = nullptr; + + HANDLE FileHandle = CreateFile(FileName.c_str(), + dwDesiredAccess, + dwShareMode, + /* lpSecurityAttributes */ nullptr, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + Ec = zen::MakeErrorCodeFromLastError(); + + return; + } +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Fd = open(FileName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666); + if (Fd < 0) + { + Ec = zen::MakeErrorCodeFromLastError(); + return; + } + fchmod(Fd, 0666); + + int LockRet = flock(Fd, LOCK_EX | LOCK_NB); + if (LockRet < 0) + { + Ec = zen::MakeErrorCodeFromLastError(); + close(Fd); + return; + } + + void* FileHandle = (void*)uintptr_t(Fd); +#endif + + m_FileHandle = FileHandle; + + BasicFile::Write(Payload.GetBuffer(), 0, Ec); +} + +void +LockFile::Update(CbObject Payload, std::error_code& Ec) +{ + BasicFile::Write(Payload.GetBuffer(), 0, Ec); +} + +////////////////////////////////////////////////////////////////////////// + +BasicFileBuffer::BasicFileBuffer(BasicFile& Base, uint64_t BufferSize) +: m_Base(Base) +, m_Buffer(nullptr) +, m_BufferSize(BufferSize) +, m_Size(Base.FileSize()) +, m_BufferStart(0) +, m_BufferEnd(0) +{ + m_Buffer = (uint8_t*)Memory::Alloc(m_BufferSize); +} + +BasicFileBuffer::~BasicFileBuffer() +{ + Memory::Free(m_Buffer); +} + +void +BasicFileBuffer::Read(void* Data, uint64_t Size, uint64_t FileOffset) +{ + if (m_Buffer == nullptr || (Size > m_BufferSize) || (FileOffset + Size > m_Size)) + { + m_Base.Read(Data, Size, FileOffset); + return; + } + uint8_t* WritePtr = ((uint8_t*)Data); + uint64_t Begin = FileOffset; + uint64_t End = FileOffset + Size; + if (FileOffset <= m_BufferStart) + { + if (End > m_BufferStart) + { + uint64_t Count = Min(m_BufferEnd, End) - m_BufferStart; + memcpy(WritePtr + End - Count - FileOffset, m_Buffer, Count); + End -= Count; + if (Begin == End) + { + return; + } + } + } + else if (FileOffset < m_BufferEnd) + { + uint64_t Count = Min(m_BufferEnd, End) - FileOffset; + memcpy(WritePtr + Begin - FileOffset, m_Buffer + Begin - m_BufferStart, Count); + Begin += Count; + if (Begin == End) + { + return; + } + } + m_BufferStart = Begin; + m_BufferEnd = Min(Begin + m_BufferSize, m_Size); + m_Base.Read(m_Buffer, m_BufferEnd - m_BufferStart, m_BufferStart); + uint64_t Count = Min(m_BufferEnd, End) - m_BufferStart; + memcpy(WritePtr + Begin - FileOffset, m_Buffer, Count); + ZEN_ASSERT(Begin + Count == End); +} + +MemoryView +BasicFileBuffer::MakeView(uint64_t Size, uint64_t FileOffset) +{ + if (FileOffset < m_BufferStart || (FileOffset + Size) > m_BufferEnd) + { + if (m_Buffer == nullptr || (Size > m_BufferSize) || (FileOffset + Size > m_Size)) + { + return {}; + } + m_BufferStart = FileOffset; + m_BufferEnd = Min(m_BufferStart + m_BufferSize, m_Size); + m_Base.Read(m_Buffer, m_BufferEnd - m_BufferStart, m_BufferStart); + } + return MemoryView(m_Buffer + (FileOffset - m_BufferStart), Size); +} + +////////////////////////////////////////////////////////////////////////// + +BasicFileWriter::BasicFileWriter(BasicFile& Base, uint64_t BufferSize) +: m_Base(Base) +, m_Buffer(nullptr) +, m_BufferSize(BufferSize) +, m_BufferStart(0) +, m_BufferEnd(0) +{ + m_Buffer = (uint8_t*)Memory::Alloc(m_BufferSize); +} + +BasicFileWriter::~BasicFileWriter() +{ + Flush(); + Memory::Free(m_Buffer); +} + +void +BasicFileWriter::AddPadding(uint64_t Padding) +{ + while (Padding) + { + const uint64_t BufferOffset = m_BufferEnd - m_BufferStart; + const uint64_t RemainingBufferCapacity = m_BufferSize - BufferOffset; + const uint64_t BlockPadBytes = Min(RemainingBufferCapacity, Padding); + + memset(m_Buffer + BufferOffset, 0, BlockPadBytes); + m_BufferEnd += BlockPadBytes; + Padding -= BlockPadBytes; + + if ((BufferOffset + BlockPadBytes) == m_BufferSize) + { + Flush(); + } + } +} + +uint64_t +BasicFileWriter::AlignTo(uint64_t Alignment) +{ + uint64_t AlignedPos = RoundUp(m_BufferEnd, Alignment); + uint64_t Padding = AlignedPos - m_BufferEnd; + AddPadding(Padding); + return AlignedPos; +} + +void +BasicFileWriter::Write(const void* Data, uint64_t Size, uint64_t FileOffset) +{ + if (m_Buffer == nullptr || (Size >= m_BufferSize)) + { + if (FileOffset == m_BufferEnd) + { + Flush(); + m_BufferStart = m_BufferEnd = FileOffset + Size; + } + + m_Base.Write(Data, Size, FileOffset); + return; + } + + // Note that this only supports buffering of sequential writes! + + if (FileOffset != m_BufferEnd) + { + Flush(); + m_BufferStart = m_BufferEnd = FileOffset; + } + + const uint8_t* DataPtr = (const uint8_t*)Data; + while (Size) + { + const uint64_t RemainingBufferCapacity = m_BufferStart + m_BufferSize - m_BufferEnd; + const uint64_t BlockWriteBytes = Min(RemainingBufferCapacity, Size); + const uint64_t BufferWriteOffset = FileOffset - m_BufferStart; + + ZEN_ASSERT_SLOW(BufferWriteOffset < m_BufferSize); + ZEN_ASSERT_SLOW((BufferWriteOffset + BlockWriteBytes) <= m_BufferSize); + + memcpy(m_Buffer + BufferWriteOffset, DataPtr, BlockWriteBytes); + + Size -= BlockWriteBytes; + m_BufferEnd += BlockWriteBytes; + FileOffset += BlockWriteBytes; + DataPtr += BlockWriteBytes; + + if ((m_BufferEnd - m_BufferStart) == m_BufferSize) + { + Flush(); + } + } +} + +void +BasicFileWriter::Write(const CompositeBuffer& Data, uint64_t FileOffset) +{ + for (const SharedBuffer& Segment : Data.GetSegments()) + { + const uint64_t SegmentSize = Segment.GetSize(); + Write(Segment.GetData(), SegmentSize, FileOffset); + FileOffset += SegmentSize; + } +} + +void +BasicFileWriter::Flush() +{ + const uint64_t BufferedBytes = m_BufferEnd - m_BufferStart; + + if (BufferedBytes == 0) + return; + + const uint64_t WriteOffset = m_BufferStart; + m_BufferStart = m_BufferEnd; + + m_Base.Write(m_Buffer, BufferedBytes, WriteOffset); +} + +IoBuffer +WriteToTempFile(CompositeBuffer&& Buffer, const std::filesystem::path& Path) +{ + TemporaryFile Temp; + std::error_code Ec; + Temp.CreateTemporary(Path.parent_path(), Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to create temp file for blob at '{}'", Path)); + } + + uint64_t BufferSize = Buffer.GetSize(); + { + uint64_t Offset = 0; + static const uint64_t BufferingSize = 256u * 1024u; + BasicFileWriter BufferedOutput(Temp, Min(BufferingSize, BufferSize)); + for (const SharedBuffer& Segment : Buffer.GetSegments()) + { + size_t SegmentSize = Segment.GetSize(); + + IoBufferFileReference FileRef; + if (SegmentSize >= (BufferingSize + BufferingSize / 2) && Segment.GetFileReference(FileRef)) + { + ScanFile(FileRef.FileHandle, + FileRef.FileChunkOffset, + FileRef.FileChunkSize, + BufferingSize, + [&BufferedOutput, &Offset](const void* Data, size_t Size) { + BufferedOutput.Write(Data, Size, Offset); + Offset += Size; + }); + } + else + { + BufferedOutput.Write(Segment.GetData(), SegmentSize, Offset); + Offset += SegmentSize; + } + } + } + + Temp.MoveTemporaryIntoPlace(Path, Ec); + if (Ec) + { + Ec.clear(); + BasicFile OpenTemp(Path, BasicFile::Mode::kDelete, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to move temp file to '{}'", Path)); + } + if (OpenTemp.FileSize() != BufferSize) + { + throw std::runtime_error(fmt::format("Failed to move temp file to '{}' - mismatching file size already exists", Path)); + } + IoBuffer TmpBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BufferSize, true); + + IoHash ExistingHash = IoHash::HashBuffer(TmpBuffer); + const IoHash ExpectedHash = IoHash::HashBuffer(Buffer); + if (ExistingHash != ExpectedHash) + { + throw std::runtime_error(fmt::format("Failed to move temp file to '{}' - mismatching file hash already exists", Path)); + } + Buffer = CompositeBuffer{}; + TmpBuffer.SetDeleteOnClose(true); + return TmpBuffer; + } + Buffer = CompositeBuffer{}; + BasicFile OpenTemp(Path, BasicFile::Mode::kDelete); + IoBuffer TmpBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BufferSize, true); + TmpBuffer.SetDeleteOnClose(true); + return TmpBuffer; +} + +////////////////////////////////////////////////////////////////////////// + +/* + ___________ __ + \__ ___/___ _______/ |_ ______ + | |_/ __ \ / ___/\ __\/ ___/ + | |\ ___/ \___ \ | | \___ \ + |____| \___ >____ > |__| /____ > + \/ \/ \/ +*/ + +#if ZEN_WITH_TESTS + +TEST_CASE("BasicFile") +{ + ScopedCurrentDirectoryChange _; + + BasicFile File1; + CHECK_THROWS(File1.Open("zonk", BasicFile::Mode::kRead)); + CHECK_NOTHROW(File1.Open("zonk", BasicFile::Mode::kTruncate)); + CHECK_NOTHROW(File1.Write("abcd", 4, 0)); + CHECK(File1.FileSize() == 4); + { + IoBuffer Data = File1.ReadAll(); + CHECK(Data.Size() == 4); + CHECK_EQ(memcmp(Data.Data(), "abcd", 4), 0); + } + CHECK_NOTHROW(File1.Write("efgh", 4, 2)); + CHECK(File1.FileSize() == 6); + { + IoBuffer Data = File1.ReadAll(); + CHECK(Data.Size() == 6); + CHECK_EQ(memcmp(Data.Data(), "abefgh", 6), 0); + } +} + +TEST_CASE("TemporaryFile") +{ + ScopedCurrentDirectoryChange _; + + SUBCASE("DeleteOnClose") + { + std::filesystem::path Path; + { + TemporaryFile TmpFile; + std::error_code Ec; + TmpFile.CreateTemporary(std::filesystem::current_path(), Ec); + CHECK(!Ec); + Path = TmpFile.GetPath(); + CHECK(IsFile(Path)); + } + CHECK(IsFile(Path) == false); + } + + SUBCASE("MoveIntoPlace") + { + TemporaryFile TmpFile; + std::error_code Ec; + TmpFile.CreateTemporary(std::filesystem::current_path(), Ec); + CHECK(!Ec); + std::filesystem::path TempPath = TmpFile.GetPath(); + std::filesystem::path FinalPath = std::filesystem::current_path() / "final"; + CHECK(IsFile(TempPath)); + TmpFile.MoveTemporaryIntoPlace(FinalPath, Ec); + CHECK(!Ec); + CHECK(IsFile(TempPath) == false); + CHECK(IsFile(FinalPath)); + } +} + +TEST_CASE("BasicFileBuffer") +{ + ScopedCurrentDirectoryChange _; + { + BasicFile File1; + const std::string_view Data = "0123456789abcdef"; + File1.Open("buffered", BasicFile::Mode::kTruncate); + for (uint32_t I = 0; I < 16; ++I) + { + File1.Write(Data.data(), Data.size(), I * Data.size()); + } + } + SUBCASE("EvenBuffer") + { + BasicFile File1; + File1.Open("buffered", BasicFile::Mode::kRead); + BasicFileBuffer File1Buffer(File1, 16); + // Non-primed + { + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 1 * 16); + std::string_view Verify(Buffer, 16); + CHECK(Verify == "0123456789abcdef"); + } + // Primed + { + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 1 * 16); + std::string_view Verify(Buffer, 16); + CHECK(Verify == "0123456789abcdef"); + } + } + SUBCASE("UnevenBuffer") + { + BasicFile File1; + File1.Open("buffered", BasicFile::Mode::kRead); + BasicFileBuffer File1Buffer(File1, 16); + // Non-primed + { + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 7); + std::string_view Verify(Buffer, 16); + CHECK(Verify == "789abcdef0123456"); + } + // Primed + { + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 7); + std::string_view Verify(Buffer, 16); + CHECK(Verify == "789abcdef0123456"); + } + } + SUBCASE("BiggerThanBuffer") + { + BasicFile File1; + File1.Open("buffered", BasicFile::Mode::kRead); + BasicFileBuffer File1Buffer(File1, 16); + char Buffer[17] = {0}; + File1Buffer.Read(Buffer, 17, 0 * 16); + std::string_view Verify(Buffer, 17); + CHECK(Verify == "0123456789abcdef0"); + } + SUBCASE("InsideBuffer") + { + BasicFile File1; + File1.Open("buffered", BasicFile::Mode::kRead); + BasicFileBuffer File1Buffer(File1, 16); + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 0 * 16); + + File1Buffer.Read(Buffer, 8, 2); + std::string_view Verify(Buffer, 8); + CHECK(Verify == "23456789"); + } + SUBCASE("BeginningOfBuffer") + { + BasicFile File1; + File1.Open("buffered", BasicFile::Mode::kRead); + BasicFileBuffer File1Buffer(File1, 16); + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 8); + + File1Buffer.Read(Buffer, 8, 8); + std::string_view Verify(Buffer, 8); + CHECK(Verify == "89abcdef"); + } + SUBCASE("EndOfBuffer") + { + BasicFile File1; + File1.Open("buffered", BasicFile::Mode::kRead); + BasicFileBuffer File1Buffer(File1, 16); + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 0 * 16); + + File1Buffer.Read(Buffer, 8, 8); + std::string_view Verify(Buffer, 8); + CHECK(Verify == "89abcdef"); + } + SUBCASE("OverEnd") + { + BasicFile File1; + File1.Open("buffered", BasicFile::Mode::kRead); + BasicFileBuffer File1Buffer(File1, 16); + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 0 * 16); + + File1Buffer.Read(Buffer, 16, 8); + std::string_view Verify(Buffer, 16); + CHECK(Verify == "89abcdef01234567"); + } + SUBCASE("OverBegin") + { + BasicFile File1; + File1.Open("buffered", BasicFile::Mode::kRead); + BasicFileBuffer File1Buffer(File1, 16); + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 1 * 16); + + File1Buffer.Read(Buffer, 16, 8); + std::string_view Verify(Buffer, 16); + CHECK(Verify == "89abcdef01234567"); + } + SUBCASE("EndOfFile") + { + BasicFile File1; + File1.Open("buffered", BasicFile::Mode::kRead); + BasicFileBuffer File1Buffer(File1, 16); + char Buffer[16] = {0}; + File1Buffer.Read(Buffer, 16, 0 * 16); + + File1Buffer.Read(Buffer, 8, 256 - 8); + std::string_view Verify(Buffer, 8); + CHECK(Verify == "89abcdef"); + } +} + +void +basicfile_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zencore/blake3.cpp b/src/zencore/blake3.cpp index 89826ae5d..054f0d3a0 100644 --- a/src/zencore/blake3.cpp +++ b/src/zencore/blake3.cpp @@ -3,6 +3,7 @@ #include <zencore/blake3.h> #include <zencore/compositebuffer.h> +#include <zencore/filesystem.h> #include <zencore/string.h> #include <zencore/testing.h> #include <zencore/zencore.h> @@ -45,7 +46,51 @@ BLAKE3::HashBuffer(const CompositeBuffer& Buffer) for (const SharedBuffer& Segment : Buffer.GetSegments()) { - blake3_hasher_update(&Hasher, Segment.GetData(), Segment.GetSize()); + size_t SegmentSize = Segment.GetSize(); + static const uint64_t BufferingSize = 256u * 1024u; + + IoBufferFileReference FileRef; + if (SegmentSize >= (BufferingSize + BufferingSize / 2) && Segment.GetFileReference(FileRef)) + { + ScanFile(FileRef.FileHandle, + FileRef.FileChunkOffset, + FileRef.FileChunkSize, + BufferingSize, + [&Hasher](const void* Data, size_t Size) { blake3_hasher_update(&Hasher, Data, Size); }); + } + else + { + blake3_hasher_update(&Hasher, Segment.GetData(), SegmentSize); + } + } + + blake3_hasher_finalize(&Hasher, Hash.Hash, sizeof Hash.Hash); + + return Hash; +} + +BLAKE3 +BLAKE3::HashBuffer(const IoBuffer& Buffer) +{ + BLAKE3 Hash; + + blake3_hasher Hasher; + blake3_hasher_init(&Hasher); + + size_t BufferSize = Buffer.GetSize(); + static const uint64_t BufferingSize = 256u * 1024u; + IoBufferFileReference FileRef; + if (BufferSize >= (BufferingSize + BufferingSize / 2) && Buffer.GetFileReference(FileRef)) + { + ScanFile(FileRef.FileHandle, + FileRef.FileChunkOffset, + FileRef.FileChunkSize, + BufferingSize, + [&Hasher](const void* Data, size_t Size) { blake3_hasher_update(&Hasher, Data, Size); }); + } + else + { + blake3_hasher_update(&Hasher, Buffer.GetData(), BufferSize); } blake3_hasher_finalize(&Hasher, Hash.Hash, sizeof Hash.Hash); @@ -106,6 +151,28 @@ BLAKE3Stream::Append(const void* data, size_t byteCount) return *this; } +BLAKE3Stream& +BLAKE3Stream::Append(const IoBuffer& Buffer) +{ + blake3_hasher* b3h = reinterpret_cast<blake3_hasher*>(m_HashState); + + size_t BufferSize = Buffer.GetSize(); + static const uint64_t BufferingSize = 256u * 1024u; + IoBufferFileReference FileRef; + if (BufferSize >= (BufferingSize + BufferingSize / 2) && Buffer.GetFileReference(FileRef)) + { + ScanFile(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, BufferingSize, [&b3h](const void* Data, size_t Size) { + blake3_hasher_update(b3h, Data, Size); + }); + } + else + { + blake3_hasher_update(b3h, Buffer.GetData(), BufferSize); + } + + return *this; +} + BLAKE3 BLAKE3Stream::GetHash() { diff --git a/src/zencore/callstack.cpp b/src/zencore/callstack.cpp new file mode 100644 index 000000000..8aa1111bf --- /dev/null +++ b/src/zencore/callstack.cpp @@ -0,0 +1,282 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/callstack.h> +#include <zencore/filesystem.h> +#include <zencore/thread.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +# include <Dbghelp.h> +#endif + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +# include <execinfo.h> +#endif + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif + +#include <fmt/format.h> + +namespace zen { +#if ZEN_PLATFORM_WINDOWS + +class WinSymbolInit +{ +public: + WinSymbolInit() {} + ~WinSymbolInit() + { + m_CallstackLock.WithExclusiveLock([this]() { + if (m_Initialized) + { + SymCleanup(m_CurrentProcess); + } + }); + } + + bool GetSymbol(void* Frame, SYMBOL_INFO* OutSymbolInfo, DWORD64& OutDisplacement) + { + bool Result = false; + m_CallstackLock.WithExclusiveLock([&]() { + if (!m_Initialized) + { + m_CurrentProcess = GetCurrentProcess(); + std::filesystem::path ProgramBaseDir = GetRunningExecutablePath().parent_path(); + if (SymInitializeW(m_CurrentProcess, ProgramBaseDir.c_str(), TRUE) == TRUE) + { + m_Initialized = true; + } + } + if (m_Initialized) + { + if (SymFromAddr(m_CurrentProcess, (DWORD64)Frame, &OutDisplacement, OutSymbolInfo) == TRUE) + { + Result = true; + } + } + }); + return Result; + } + +private: + HANDLE m_CurrentProcess = NULL; + BOOL m_Initialized = FALSE; + RwLock m_CallstackLock; +}; + +static WinSymbolInit WinSymbols; + +#endif + +CallstackFrames* +CreateCallstack(uint32_t FrameCount, void** Frames) noexcept +{ + if (FrameCount == 0) + { + return nullptr; + } + CallstackFrames* Callstack = (CallstackFrames*)malloc(sizeof(CallstackFrames) + sizeof(void*) * FrameCount); + if (Callstack != nullptr) + { + Callstack->FrameCount = FrameCount; + if (FrameCount == 0) + { + Callstack->Frames = nullptr; + } + else + { + Callstack->Frames = (void**)&Callstack[1]; + memcpy(Callstack->Frames, Frames, sizeof(void*) * FrameCount); + } + } + return Callstack; +} + +CallstackFrames* +CloneCallstack(const CallstackFrames* Callstack) noexcept +{ + if (Callstack == nullptr) + { + return nullptr; + } + return CreateCallstack(Callstack->FrameCount, Callstack->Frames); +} + +void +FreeCallstack(CallstackFrames* Callstack) noexcept +{ + if (Callstack != nullptr) + { + free(Callstack); + } +} + +uint32_t +GetCallstack(int FramesToSkip, int FramesToCapture, void* OutAddresses[]) +{ +#if ZEN_PLATFORM_WINDOWS + return (uint32_t)CaptureStackBackTrace(FramesToSkip, FramesToCapture, OutAddresses, 0); +#endif +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + void* Frames[FramesToSkip + FramesToCapture]; + int FrameCount = backtrace(Frames, FramesToSkip + FramesToCapture); + if (FrameCount > FramesToSkip) + { + for (int Index = FramesToSkip; Index < FrameCount; Index++) + { + OutAddresses[Index - FramesToSkip] = Frames[Index]; + } + return (uint32_t)(FrameCount - FramesToSkip); + } + else + { + return 0; + } +#endif +} + +std::vector<std::string> +GetFrameSymbols(uint32_t FrameCount, void** Frames) +{ + std::vector<std::string> FrameSymbols; + if (FrameCount > 0) + { + FrameSymbols.resize(FrameCount); +#if ZEN_PLATFORM_WINDOWS + char SymbolBuffer[sizeof(SYMBOL_INFO) + 1024]; + SYMBOL_INFO* SymbolInfo = (SYMBOL_INFO*)SymbolBuffer; + SymbolInfo->SizeOfStruct = sizeof(SYMBOL_INFO); + SymbolInfo->MaxNameLen = 1023; + DWORD64 Displacement = 0; + for (uint32_t FrameIndex = 0; FrameIndex < FrameCount; FrameIndex++) + { + if (WinSymbols.GetSymbol(Frames[FrameIndex], SymbolInfo, Displacement)) + { + FrameSymbols[FrameIndex] = fmt::format("{}+{:#x} [{:#x}]", SymbolInfo->Name, Displacement, (uintptr_t)Frames[FrameIndex]); + } + } +#endif +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + char** messages = backtrace_symbols(Frames, (int)FrameCount); + if (messages) + { + for (uint32_t FrameIndex = 0; FrameIndex < FrameCount; FrameIndex++) + { + FrameSymbols[FrameIndex] = messages[FrameIndex]; + } + free(messages); + } +#endif + } + return FrameSymbols; +} + +void +FormatCallstack(const CallstackFrames* Callstack, StringBuilderBase& SB, std::string_view Prefix) +{ + bool First = true; + for (const std::string& Symbol : GetFrameSymbols(Callstack)) + { + try + { + if (!First) + { + SB.Append("\n"); + } + else + { + First = false; + } + if (!Prefix.empty()) + { + SB.Append(Prefix); + } + SB.Append(Symbol); + } + catch (const std::exception&) + { + break; + } + } +} + +std::string +CallstackToString(const CallstackFrames* Callstack, std::string_view Prefix) +{ + StringBuilder<2048> SB; + FormatCallstack(Callstack, SB, Prefix); + return SB.ToString(); +} + +void +CallstackToStringRaw(const CallstackFrames* Callstack, void* CallbackUserData, CallstackRawCallback Callback) +{ + if (Callstack && Callstack->FrameCount > 0) + { +#if ZEN_PLATFORM_WINDOWS + char SymbolBuffer[sizeof(SYMBOL_INFO) + 1024]; + SYMBOL_INFO* SymbolInfo = (SYMBOL_INFO*)SymbolBuffer; + SymbolInfo->SizeOfStruct = sizeof(SYMBOL_INFO); + SymbolInfo->MaxNameLen = 1023; + DWORD64 Displacement = 0; + fmt::basic_memory_buffer<char, 2048> Message; + for (uint32_t FrameIndex = 0; FrameIndex < Callstack->FrameCount; FrameIndex++) + { + if (WinSymbols.GetSymbol(Callstack->Frames[FrameIndex], SymbolInfo, Displacement)) + { + auto Appender = fmt::appender(Message); + fmt::format_to(Appender, "{}+{:#x} [{:#x}]", SymbolInfo->Name, Displacement, (uintptr_t)Callstack->Frames[FrameIndex]); + Message.push_back('\0'); + Callback(CallbackUserData, FrameIndex, Message.data()); + Message.resize(0); + } + } +#endif +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + char** messages = backtrace_symbols(Callstack->Frames, (int)Callstack->FrameCount); + if (messages) + { + for (uint32_t FrameIndex = 0; FrameIndex < Callstack->FrameCount; FrameIndex++) + { + Callback(CallbackUserData, FrameIndex, messages[FrameIndex]); + } + free(messages); + } +#endif + } +} + +CallstackFrames* +GetCallstackRaw(void* CaptureBuffer, int FramesToSkip, int FramesToCapture) +{ + CallstackFrames* Callstack = (CallstackFrames*)CaptureBuffer; + + Callstack->Frames = (void**)&Callstack[1]; + Callstack->FrameCount = GetCallstack(FramesToSkip, FramesToCapture, Callstack->Frames); + return Callstack; +} + +#if ZEN_WITH_TESTS + +TEST_CASE("Callstack.Basic") +{ + void* Addresses[4]; + uint32_t FrameCount = GetCallstack(1, 4, Addresses); + CHECK(FrameCount > 0); + std::vector<std::string> Symbols = GetFrameSymbols(FrameCount, Addresses); + for (const std::string& Symbol : Symbols) + { + CHECK(!Symbol.empty()); + } +} + +void +callstack_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zencore/commandline.cpp b/src/zencore/commandline.cpp new file mode 100644 index 000000000..c801bf151 --- /dev/null +++ b/src/zencore/commandline.cpp @@ -0,0 +1,72 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/commandline.h> +#include <zencore/string.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +ZEN_THIRD_PARTY_INCLUDES_START +# include <shellapi.h> // For command line parsing +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +#include <functional> + +namespace zen { + +void +IterateCommandlineArgs(std::function<void(const std::string_view& Arg)>& ProcessArg) +{ +#if ZEN_PLATFORM_WINDOWS + int ArgC = 0; + const LPWSTR CmdLine = ::GetCommandLineW(); + const LPWSTR* ArgV = ::CommandLineToArgvW(CmdLine, &ArgC); + + if (ArgC > 1) + { + for (int i = 1; i < ArgC; ++i) + { + StringBuilder<4096> ArgString8; + + WideToUtf8(ArgV[i], ArgString8); + + ProcessArg(ArgString8); + } + } + + ::LocalFree(HLOCAL(ArgV)); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + if (FILE* CmdLineFile = fopen("/proc/self/cmdline", "r")) + { + const char* ArgV[255] = {}; + int ArgC = 0; + + char* Arg = nullptr; + size_t Size = 0; + while (getdelim(&Arg, &Size, 0, CmdLineFile) != -1) + { + ArgV[ArgC++] = Arg; + Arg = nullptr; // getdelim will allocate buffer for next Arg + } + fclose(CmdLineFile); + + if (ArgC > 1) + { + for (int i = 1; i < ArgC; ++i) + { + ProcessArg(ArgV[i]); + } + } + + // cleanup after getdelim + while (ArgC > 0) + { + free((void*)ArgV[--ArgC]); + } + } +#else +# error Unknown platform +#endif +} + +} // namespace zen diff --git a/src/zencore/compactbinary.cpp b/src/zencore/compactbinary.cpp index 6677b5a61..b43cc18f1 100644 --- a/src/zencore/compactbinary.cpp +++ b/src/zencore/compactbinary.cpp @@ -15,6 +15,8 @@ #include <zencore/testing.h> #include <zencore/uid.h> +#include <EASTL/fixed_vector.h> + #include <fmt/format.h> #include <string_view> @@ -24,10 +26,6 @@ # include <time.h> #endif -ZEN_THIRD_PARTY_INCLUDES_START -#include <json11.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - namespace zen { const int DaysToMonth[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365}; @@ -411,7 +409,7 @@ CbFieldView::CbFieldView(const void* DataPointer, CbFieldType FieldType) } void -CbFieldView::IterateAttachments(std::function<void(CbFieldView)> Visitor) const +CbFieldView::IterateAttachments(const std::function<void(CbFieldView)>& Visitor) const { switch (CbFieldTypeOps::GetType(Type)) { @@ -1174,7 +1172,7 @@ template class TCbFieldIterator<CbField>; template<typename FieldType> void -TCbFieldIterator<FieldType>::IterateRangeAttachments(std::function<void(CbFieldView)> Visitor) const +TCbFieldIterator<FieldType>::IterateRangeAttachments(const std::function<void(CbFieldView)>& Visitor) const { if (CbFieldTypeOps::HasFieldType(FieldType::GetType())) { @@ -1380,9 +1378,9 @@ TryMeasureCompactBinary(MemoryView View, CbFieldType& OutType, uint64_t& OutSize CbField LoadCompactBinary(BinaryReader& Ar, BufferAllocator Allocator) { - std::vector<uint8_t> HeaderBytes; - CbFieldType FieldType; - uint64_t FieldSize = 1; + eastl::fixed_vector<uint8_t, 32> HeaderBytes; + CbFieldType FieldType; + uint64_t FieldSize = 1; for (const int64_t StartPos = Ar.CurrentOffset(); FieldSize > 0;) { @@ -1397,7 +1395,7 @@ LoadCompactBinary(BinaryReader& Ar, BufferAllocator Allocator) HeaderBytes.resize(ReadOffset + ReadSize); Ar.Read(HeaderBytes.data() + ReadOffset, ReadSize); - if (TryMeasureCompactBinary(MakeMemoryView(HeaderBytes), FieldType, FieldSize)) + if (TryMeasureCompactBinary(MakeMemoryView(HeaderBytes.data(), HeaderBytes.size()), FieldType, FieldSize)) { if (FieldSize <= uint64_t(Ar.Size() - StartPos)) { @@ -1425,25 +1423,43 @@ LoadCompactBinary(BinaryReader& Ar, BufferAllocator Allocator) CbObject LoadCompactBinaryObject(IoBuffer&& Payload) { + if (Payload.GetSize() == 0) + { + return CbObject(); + } return CbObject{SharedBuffer(std::move(Payload))}; } CbObject LoadCompactBinaryObject(const IoBuffer& Payload) { + if (Payload.GetSize() == 0) + { + return CbObject(); + } return CbObject{SharedBuffer(Payload)}; } CbObject LoadCompactBinaryObject(CompressedBuffer&& Payload) { - return CbObject{SharedBuffer(Payload.DecompressToComposite().Flatten())}; + CompositeBuffer Decompressed = std::move(Payload).DecompressToComposite(); + if (Decompressed.GetSize() == 0) + { + return CbObject(); + } + return CbObject{std::move(Decompressed).Flatten()}; } CbObject LoadCompactBinaryObject(const CompressedBuffer& Payload) { - return CbObject{SharedBuffer(Payload.DecompressToComposite().Flatten())}; + CompositeBuffer Decompressed = Payload.DecompressToComposite(); + if (Decompressed.GetSize() == 0) + { + return CbObject(); + } + return CbObject{std::move(Decompressed).Flatten()}; } ////////////////////////////////////////////////////////////////////////// @@ -1468,339 +1484,6 @@ SaveCompactBinary(BinaryWriter& Ar, const CbObjectView& Object) ////////////////////////////////////////////////////////////////////////// -class CbJsonWriter -{ -public: - explicit CbJsonWriter(StringBuilderBase& InBuilder) : Builder(InBuilder) { NewLineAndIndent << LINE_TERMINATOR_ANSI; } - - void BeginObject() - { - Builder << '{'; - NewLineAndIndent << '\t'; - NeedsNewLine = true; - } - - void EndObject() - { - NewLineAndIndent.RemoveSuffix(1); - if (NeedsComma) - { - WriteOptionalNewLine(); - } - Builder << '}'; - } - - void BeginArray() - { - Builder << '['; - NewLineAndIndent << '\t'; - NeedsNewLine = true; - } - - void EndArray() - { - NewLineAndIndent.RemoveSuffix(1); - if (NeedsComma) - { - WriteOptionalNewLine(); - } - Builder << ']'; - } - - void WriteField(CbFieldView Field) - { - using namespace std::literals; - - WriteOptionalComma(); - WriteOptionalNewLine(); - - if (std::u8string_view Name = Field.GetU8Name(); !Name.empty()) - { - AppendQuotedString(Name); - Builder << ": "sv; - } - - switch (CbValue Accessor = Field.GetValue(); Accessor.GetType()) - { - case CbFieldType::Null: - Builder << "null"sv; - break; - case CbFieldType::Object: - case CbFieldType::UniformObject: - { - BeginObject(); - for (CbFieldView It : Field) - { - WriteField(It); - } - EndObject(); - } - break; - case CbFieldType::Array: - case CbFieldType::UniformArray: - { - BeginArray(); - for (CbFieldView It : Field) - { - WriteField(It); - } - EndArray(); - } - break; - case CbFieldType::Binary: - AppendBase64String(Accessor.AsBinary()); - break; - case CbFieldType::String: - AppendQuotedString(Accessor.AsU8String()); - break; - case CbFieldType::IntegerPositive: - Builder << Accessor.AsIntegerPositive(); - break; - case CbFieldType::IntegerNegative: - Builder << Accessor.AsIntegerNegative(); - break; - case CbFieldType::Float32: - { - const float Value = Accessor.AsFloat32(); - if (std::isfinite(Value)) - { - Builder.Append(fmt::format("{:.9g}", Value)); - } - else - { - Builder << "null"sv; - } - } - break; - case CbFieldType::Float64: - { - const double Value = Accessor.AsFloat64(); - if (std::isfinite(Value)) - { - Builder.Append(fmt::format("{:.17g}", Value)); - } - else - { - Builder << "null"sv; - } - } - break; - case CbFieldType::BoolFalse: - Builder << "false"sv; - break; - case CbFieldType::BoolTrue: - Builder << "true"sv; - break; - case CbFieldType::ObjectAttachment: - case CbFieldType::BinaryAttachment: - { - Builder << '"'; - Accessor.AsAttachment().ToHexString(Builder); - Builder << '"'; - } - break; - case CbFieldType::Hash: - { - Builder << '"'; - Accessor.AsHash().ToHexString(Builder); - Builder << '"'; - } - break; - case CbFieldType::Uuid: - { - Builder << '"'; - Accessor.AsUuid().ToString(Builder); - Builder << '"'; - } - break; - case CbFieldType::DateTime: - Builder << '"' << DateTime(Accessor.AsDateTimeTicks()).ToIso8601() << '"'; - break; - case CbFieldType::TimeSpan: - { - const TimeSpan Span(Accessor.AsTimeSpanTicks()); - if (Span.GetDays() == 0) - { - Builder << '"' << Span.ToString("%h:%m:%s.%n") << '"'; - } - else - { - Builder << '"' << Span.ToString("%d.%h:%m:%s.%n") << '"'; - } - break; - } - case CbFieldType::ObjectId: - Builder << '"'; - Accessor.AsObjectId().ToString(Builder); - Builder << '"'; - break; - case CbFieldType::CustomById: - { - CbCustomById Custom = Accessor.AsCustomById(); - Builder << "{ \"Id\": "; - Builder << Custom.Id; - Builder << ", \"Data\": "; - AppendBase64String(Custom.Data); - Builder << " }"; - break; - } - case CbFieldType::CustomByName: - { - CbCustomByName Custom = Accessor.AsCustomByName(); - Builder << "{ \"Name\": "; - AppendQuotedString(Custom.Name); - Builder << ", \"Data\": "; - AppendBase64String(Custom.Data); - Builder << " }"; - break; - } - default: - ZEN_ASSERT_FORMAT(false, "invalid field type: {}", uint8_t(Accessor.GetType())); - break; - } - - NeedsComma = true; - NeedsNewLine = true; - } - -private: - void WriteOptionalComma() - { - if (NeedsComma) - { - NeedsComma = false; - Builder << ','; - } - } - - void WriteOptionalNewLine() - { - if (NeedsNewLine) - { - NeedsNewLine = false; - Builder << NewLineAndIndent; - } - } - - void AppendQuotedString(std::u8string_view Value) - { - using namespace std::literals; - - const AsciiSet EscapeSet( - "\\\"\b\f\n\r\t" - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f" - "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"); - - Builder << '\"'; - while (!Value.empty()) - { - std::u8string_view Verbatim = AsciiSet::FindPrefixWithout(Value, EscapeSet); - Builder << Verbatim; - - Value = Value.substr(Verbatim.size()); - - std::u8string_view Escape = AsciiSet::FindPrefixWith(Value, EscapeSet); - for (char Char : Escape) - { - switch (Char) - { - case '\\': - Builder << "\\\\"sv; - break; - case '\"': - Builder << "\\\""sv; - break; - case '\b': - Builder << "\\b"sv; - break; - case '\f': - Builder << "\\f"sv; - break; - case '\n': - Builder << "\\n"sv; - break; - case '\r': - Builder << "\\r"sv; - break; - case '\t': - Builder << "\\t"sv; - break; - default: - Builder << Char; - break; - } - } - Value = Value.substr(Escape.size()); - } - Builder << '\"'; - } - - void AppendBase64String(MemoryView Value) - { - Builder << '"'; - ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024); - const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); - const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); - Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); - } - -private: - StringBuilderBase& Builder; - ExtendableStringBuilder<32> NewLineAndIndent; - bool NeedsComma{false}; - bool NeedsNewLine{false}; -}; - -void -CompactBinaryToJson(const CbObjectView& Object, StringBuilderBase& Builder) -{ - CbJsonWriter Writer(Builder); - Writer.WriteField(Object.AsFieldView()); -} - -void -CompactBinaryToJson(const CbArrayView& Array, StringBuilderBase& Builder) -{ - CbJsonWriter Writer(Builder); - Writer.WriteField(Array.AsFieldView()); -} - -void -CompactBinaryToJson(MemoryView Data, StringBuilderBase& InBuilder) -{ - std::vector<CbFieldView> Fields = ReadCompactBinaryStream(Data); - CbJsonWriter Writer(InBuilder); - if (!Fields.empty()) - { - if (Fields.size() == 1) - { - Writer.WriteField(Fields[0]); - return; - } - bool UseTopLevelObject = Fields[0].HasName(); - if (UseTopLevelObject) - { - Writer.BeginObject(); - } - else - { - Writer.BeginArray(); - } - for (const CbFieldView& Field : Fields) - { - Writer.WriteField(Field); - } - if (UseTopLevelObject) - { - Writer.EndObject(); - } - else - { - Writer.EndArray(); - } - } -} - std::vector<CbFieldView> ReadCompactBinaryStream(MemoryView Data) { @@ -1823,225 +1506,6 @@ ReadCompactBinaryStream(MemoryView Data) ////////////////////////////////////////////////////////////////////////// -class CbJsonReader -{ -public: - static CbFieldIterator Read(std::string_view JsonText, std::string& Error) - { - using namespace json11; - - const Json Json = Json::parse(std::string(JsonText), Error); - - if (Error.empty()) - { - CbWriter Writer; - if (ReadField(Writer, Json, std::string_view(), Error)) - { - return Writer.Save(); - } - } - - return CbFieldIterator(); - } - -private: - static bool ReadField(CbWriter& Writer, const json11::Json& Json, const std::string_view FieldName, std::string& Error) - { - using namespace json11; - - switch (Json.type()) - { - case Json::Type::OBJECT: - { - if (FieldName.empty()) - { - Writer.BeginObject(); - } - else - { - Writer.BeginObject(FieldName); - } - - for (const auto& Kv : Json.object_items()) - { - const std::string& Name = Kv.first; - const json11::Json& Item = Kv.second; - - if (ReadField(Writer, Item, Name, Error) == false) - { - return false; - } - } - - Writer.EndObject(); - } - break; - case Json::Type::ARRAY: - { - if (FieldName.empty()) - { - Writer.BeginArray(); - } - else - { - Writer.BeginArray(FieldName); - } - - for (const json11::Json& Item : Json.array_items()) - { - if (ReadField(Writer, Item, std::string_view(), Error) == false) - { - return false; - } - } - - Writer.EndArray(); - } - break; - case Json::Type::NUL: - { - if (FieldName.empty()) - { - Writer.AddNull(); - } - else - { - Writer.AddNull(FieldName); - } - } - break; - case Json::Type::BOOL: - { - if (FieldName.empty()) - { - Writer.AddBool(Json.bool_value()); - } - else - { - Writer.AddBool(FieldName, Json.bool_value()); - } - } - break; - case Json::Type::NUMBER: - { - if (FieldName.empty()) - { - Writer.AddFloat(Json.number_value()); - } - else - { - Writer.AddFloat(FieldName, Json.number_value()); - } - } - break; - case Json::Type::STRING: - { - Oid Id; - if (TryParseObjectId(Json.string_value(), Id)) - { - if (FieldName.empty()) - { - Writer.AddObjectId(Id); - } - else - { - Writer.AddObjectId(FieldName, Id); - } - - return true; - } - - IoHash Hash; - if (TryParseIoHash(Json.string_value(), Hash)) - { - if (FieldName.empty()) - { - Writer.AddHash(Hash); - } - else - { - Writer.AddHash(FieldName, Hash); - } - - return true; - } - - if (FieldName.empty()) - { - Writer.AddString(Json.string_value()); - } - else - { - Writer.AddString(FieldName, Json.string_value()); - } - } - break; - default: - break; - } - - return true; - } - - static constexpr AsciiSet HexCharSet = AsciiSet("0123456789abcdefABCDEF"); - - static bool TryParseObjectId(std::string_view Str, Oid& Id) - { - using namespace std::literals; - - if (Str.size() == Oid::StringLength && AsciiSet::HasOnly(Str, HexCharSet)) - { - Id = Oid::FromHexString(Str); - return true; - } - - if (Str.starts_with("0x"sv)) - { - return TryParseObjectId(Str.substr(2), Id); - } - - return false; - } - - static bool TryParseIoHash(std::string_view Str, IoHash& Hash) - { - using namespace std::literals; - - if (Str.size() == IoHash::StringLength && AsciiSet::HasOnly(Str, HexCharSet)) - { - Hash = IoHash::FromHexString(Str); - return true; - } - - if (Str.starts_with("0x"sv)) - { - return TryParseIoHash(Str.substr(2), Hash); - } - - return false; - } -}; - -CbFieldIterator -LoadCompactBinaryFromJson(std::string_view Json, std::string& Error) -{ - if (Json.empty() == false) - { - return CbJsonReader::Read(Json, Error); - } - - return CbFieldIterator(); -} - -CbFieldIterator -LoadCompactBinaryFromJson(std::string_view Json) -{ - std::string Error; - return LoadCompactBinaryFromJson(Json, Error); -} - -////////////////////////////////////////////////////////////////////////// - #if ZEN_WITH_TESTS void uson_forcelink() @@ -2211,130 +1675,6 @@ TEST_CASE("uson.null") } } -TEST_CASE("uson.json") -{ - using namespace std::literals; - - SUBCASE("string") - { - CbObjectWriter Writer; - Writer << "KeyOne" - << "ValueOne"; - Writer << "KeyTwo" - << "ValueTwo"; - CbObject Obj = Writer.Save(); - - StringBuilder<128> Sb; - const char* JsonText = Obj.ToJson(Sb).Data(); - - std::string JsonError; - json11::Json Json = json11::Json::parse(JsonText, JsonError); - - const std::string ValueOne = Json["KeyOne"].string_value(); - const std::string ValueTwo = Json["KeyTwo"].string_value(); - - CHECK(JsonError.empty()); - CHECK(ValueOne == "ValueOne"); - CHECK(ValueTwo == "ValueTwo"); - } - - SUBCASE("number") - { - const float ExpectedFloatValue = 21.21f; - const double ExpectedDoubleValue = 42.42; - - CbObjectWriter Writer; - Writer << "Float" << ExpectedFloatValue; - Writer << "Double" << ExpectedDoubleValue; - - CbObject Obj = Writer.Save(); - - StringBuilder<128> Sb; - const char* JsonText = Obj.ToJson(Sb).Data(); - - std::string JsonError; - json11::Json Json = json11::Json::parse(JsonText, JsonError); - - const float FloatValue = float(Json["Float"].number_value()); - const double DoubleValue = Json["Double"].number_value(); - - CHECK(JsonError.empty()); - CHECK(FloatValue == Approx(ExpectedFloatValue)); - CHECK(DoubleValue == Approx(ExpectedDoubleValue)); - } - - SUBCASE("number.nan") - { - constexpr float FloatNan = std::numeric_limits<float>::quiet_NaN(); - constexpr double DoubleNan = std::numeric_limits<double>::quiet_NaN(); - - CbObjectWriter Writer; - Writer << "FloatNan" << FloatNan; - Writer << "DoubleNan" << DoubleNan; - - CbObject Obj = Writer.Save(); - - StringBuilder<128> Sb; - const char* JsonText = Obj.ToJson(Sb).Data(); - - std::string JsonError; - json11::Json Json = json11::Json::parse(JsonText, JsonError); - - const double FloatValue = Json["FloatNan"].number_value(); - const double DoubleValue = Json["DoubleNan"].number_value(); - - CHECK(JsonError.empty()); - CHECK(FloatValue == 0); - CHECK(DoubleValue == 0); - } - - SUBCASE("stream") - { - const auto MakeObject = [&](std::string_view Name, const std::vector<int>& Fields) -> CbObject { - CbWriter Writer; - Writer.SetName(Name); - Writer.BeginObject(); - for (const auto& Field : Fields) - { - Writer.AddInteger(fmt::format("{}", Field), Field); - } - Writer.EndObject(); - return Writer.Save().AsObject(); - }; - - std::vector<uint8_t> Buffer; - - auto AppendToBuffer = [&](const void* Data, size_t Count) { - const uint8_t* AppendBytes = reinterpret_cast<const uint8_t*>(Data); - Buffer.insert(Buffer.end(), AppendBytes, AppendBytes + Count); - }; - - auto Append = [&](const CbFieldView& Field) { - Field.WriteToStream([&](const void* Data, size_t Count) { - const uint8_t* AppendBytes = reinterpret_cast<const uint8_t*>(Data); - Buffer.insert(Buffer.end(), AppendBytes, AppendBytes + Count); - }); - }; - - CbObject DataObjects[] = {MakeObject("Empty object"sv, {}), - MakeObject("OneField object"sv, {5}), - MakeObject("TwoField object"sv, {-5, 999}), - MakeObject("ThreeField object"sv, {1, 2, -129})}; - for (const CbObject& Object : DataObjects) - { - Object.AsField().WriteToStream(AppendToBuffer); - } - - ExtendableStringBuilder<128> Sb; - CompactBinaryToJson(MemoryView(Buffer.data(), Buffer.size()), Sb); - std::string JsonText = Sb.ToString().c_str(); - std::string JsonError; - json11::Json Json = json11::Json::parse(JsonText, JsonError); - std::string ParsedJsonString = Json.dump(); - CHECK(!ParsedJsonString.empty()); - } -} - TEST_CASE("uson.datetime") { using namespace std::literals; @@ -2362,107 +1702,6 @@ TEST_CASE("uson.datetime") } } -TEST_CASE("json.uson") -{ - using namespace std::literals; - using namespace json11; - - SUBCASE("empty") - { - CbFieldIterator It = LoadCompactBinaryFromJson(""sv); - CHECK(It.HasValue() == false); - } - - SUBCASE("object") - { - const Json JsonObject = Json::object{{"Null", nullptr}, - {"String", "Value1"}, - {"Bool", true}, - {"Number", 46.2}, - {"Array", Json::array{1, 2, 3}}, - {"Object", - Json::object{ - {"String", "Value2"}, - }}}; - - CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); - - CHECK(Cb["Null"].IsNull()); - CHECK(Cb["String"].AsString() == "Value1"sv); - CHECK(Cb["Bool"].AsBool()); - CHECK(Cb["Number"].AsDouble() == 46.2); - CHECK(Cb["Object"].IsObject()); - CbObjectView Object = Cb["Object"].AsObjectView(); - CHECK(Object["String"].AsString() == "Value2"sv); - } - - SUBCASE("array") - { - const Json JsonArray = Json::array{42, 43, 44}; - CbArray Cb = LoadCompactBinaryFromJson(JsonArray.dump()).AsArray(); - - auto It = Cb.CreateIterator(); - CHECK((*It).AsDouble() == 42); - It++; - CHECK((*It).AsDouble() == 43); - It++; - CHECK((*It).AsDouble() == 44); - } - - SUBCASE("objectid") - { - const Oid& Id = Oid::NewOid(); - - StringBuilder<64> Sb; - Id.ToString(Sb); - - Json JsonObject = Json::object{{"value", Sb.ToString()}}; - CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); - - CHECK(Cb["value"sv].IsObjectId()); - CHECK(Cb["value"sv].AsObjectId() == Id); - - Sb.Reset(); - Sb << "0x"; - Id.ToString(Sb); - - JsonObject = Json::object{{"value", Sb.ToString()}}; - Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); - - CHECK(Cb["value"sv].IsObjectId()); - CHECK(Cb["value"sv].AsObjectId() == Id); - } - - SUBCASE("iohash") - { - const uint8_t Data[] = { - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - }; - - const IoHash Hash = IoHash::HashBuffer(Data, sizeof(Data)); - - Json JsonObject = Json::object{{"value", Hash.ToHexString()}}; - CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); - - CHECK(Cb["value"sv].IsHash()); - CHECK(Cb["value"sv].AsHash() == Hash); - - JsonObject = Json::object{{"value", "0x" + Hash.ToHexString()}}; - Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); - - CHECK(Cb["value"sv].IsHash()); - CHECK(Cb["value"sv].AsHash() == Hash); - } -} - ////////////////////////////////////////////////////////////////////////// TEST_SUITE_BEGIN("core.datetime"); diff --git a/src/zencore/compactbinarybuilder.cpp b/src/zencore/compactbinarybuilder.cpp index 5c08d2e6e..63c0b9c5c 100644 --- a/src/zencore/compactbinarybuilder.cpp +++ b/src/zencore/compactbinarybuilder.cpp @@ -15,23 +15,21 @@ namespace zen { -template<typename T> uint64_t -AddUninitialized(std::vector<T>& Vector, uint64_t Count) +AddUninitialized(CbWriter::CbWriterData_t& Vector, uint64_t Count) { const uint64_t Offset = Vector.size(); Vector.resize(Offset + Count); return Offset; } -template<typename T> uint64_t -Append(std::vector<T>& Vector, const T* Data, uint64_t Count) +Append(CbWriter::CbWriterData_t& Vector, const uint8_t* Data, uint64_t Count) { const uint64_t Offset = Vector.size(); Vector.resize(Offset + Count); - memcpy(Vector.data() + Offset, Data, sizeof(T) * Count); + memcpy(Vector.data() + Offset, Data, sizeof(uint8_t) * Count); return Offset; } @@ -76,7 +74,7 @@ IsUniformType(const CbFieldType Type) /** Append the payload from the compact binary value to the array and return its type. */ static inline CbFieldType -AppendCompactBinary(const CbFieldView& Value, std::vector<uint8_t>& OutData) +AppendCompactBinary(const CbFieldView& Value, CbWriter::CbWriterData_t& OutData) { struct FCopy : public CbFieldView { @@ -221,12 +219,11 @@ CbWriter::SetName(const std::string_view Name) State.Flags |= StateFlags::Name; const uint32_t NameLenByteCount = MeasureVarUInt(uint32_t(Name.size())); const int64_t NameLenOffset = Data.size(); - Data.resize(NameLenOffset + NameLenByteCount); + Data.resize(NameLenOffset + NameLenByteCount + Name.size()); - WriteVarUInt(uint64_t(Name.size()), Data.data() + NameLenOffset); + WriteMeasuredVarUInt(uint64_t(Name.size()), NameLenByteCount, Data.data() + NameLenOffset); - const uint8_t* NamePtr = reinterpret_cast<const uint8_t*>(Name.data()); - Data.insert(Data.end(), NamePtr, NamePtr + Name.size()); + memcpy(Data.data() + NameLenOffset + NameLenByteCount, Name.data(), Name.size()); return *this; } @@ -341,7 +338,7 @@ CbWriter::EndObject() const uint64_t Size = uint64_t(Data.size() - PayloadOffset); const uint32_t SizeByteCount = MeasureVarUInt(Size); Data.insert(Data.begin() + PayloadOffset, SizeByteCount, 0); - WriteVarUInt(Size, Data.data() + PayloadOffset); + WriteMeasuredVarUInt(Size, SizeByteCount, Data.data() + PayloadOffset); EndField(bUniform ? CbFieldType::UniformObject : CbFieldType::Object); } @@ -400,8 +397,8 @@ CbWriter::EndArray() const uint64_t Size = uint64_t(Data.size() - PayloadOffset) + CountByteCount; const uint32_t SizeByteCount = MeasureVarUInt(Size); Data.insert(Data.begin() + PayloadOffset, SizeByteCount + CountByteCount, 0); - WriteVarUInt(Size, Data.data() + PayloadOffset); - WriteVarUInt(Count, Data.data() + PayloadOffset + SizeByteCount); + WriteMeasuredVarUInt(Size, SizeByteCount, Data.data() + PayloadOffset); + WriteMeasuredVarUInt(Count, CountByteCount, Data.data() + PayloadOffset + SizeByteCount); EndField(bUniform ? CbFieldType::UniformArray : CbFieldType::Array); } @@ -429,13 +426,13 @@ CbWriter::AddNull() void CbWriter::AddBinary(const void* const Value, const uint64_t Size) { - const size_t SizeByteCount = MeasureVarUInt(Size); + const uint32_t SizeByteCount = MeasureVarUInt(Size); Data.reserve(Data.size() + 1 + SizeByteCount + Size); BeginField(); const size_t SizeOffset = Data.size(); - Data.resize(Data.size() + SizeByteCount); - WriteVarUInt(Size, Data.data() + SizeOffset); - Data.insert(Data.end(), static_cast<const uint8_t*>(Value), static_cast<const uint8_t*>(Value) + Size); + Data.resize(Data.size() + SizeByteCount + Size); + WriteMeasuredVarUInt(Size, SizeByteCount, Data.data() + SizeOffset); + memcpy(Data.data() + SizeOffset + SizeByteCount, Value, Size); EndField(CbFieldType::Binary); } @@ -468,7 +465,7 @@ CbWriter::AddString(const std::string_view Value) Data.resize(Offset + SizeByteCount + Size); uint8_t* StringData = Data.data() + Offset; - WriteVarUInt(Size, StringData); + WriteMeasuredVarUInt(Size, SizeByteCount, StringData); StringData += SizeByteCount; if (Size > 0) { @@ -489,7 +486,7 @@ CbWriter::AddString(const std::wstring_view Value) const int64_t Offset = Data.size(); Data.resize(Offset + SizeByteCount + Size); uint8_t* StringData = Data.data() + Offset; - WriteVarUInt(Size, StringData); + WriteMeasuredVarUInt(Size, SizeByteCount, StringData); StringData += SizeByteCount; if (Size > 0) { @@ -511,7 +508,7 @@ CbWriter::AddInteger(const int32_t Value) const uint32_t MagnitudeByteCount = MeasureVarUInt(Magnitude); const int64_t Offset = Data.size(); Data.resize(Offset + MagnitudeByteCount); - WriteVarUInt(Magnitude, Data.data() + Offset); + WriteMeasuredVarUInt(Magnitude, MagnitudeByteCount, Data.data() + Offset); EndField(CbFieldType::IntegerNegative); } @@ -526,7 +523,7 @@ CbWriter::AddInteger(const int64_t Value) const uint64_t Magnitude = ~uint64_t(Value); const uint32_t MagnitudeByteCount = MeasureVarUInt(Magnitude); const uint64_t Offset = AddUninitialized(Data, MagnitudeByteCount); - WriteVarUInt(Magnitude, Data.data() + Offset); + WriteMeasuredVarUInt(Magnitude, MagnitudeByteCount, Data.data() + Offset); EndField(CbFieldType::IntegerNegative); } @@ -537,7 +534,7 @@ CbWriter::AddInteger(const uint32_t Value) BeginField(); const uint32_t ValueByteCount = MeasureVarUInt(Value); const uint64_t Offset = AddUninitialized(Data, ValueByteCount); - WriteVarUInt(Value, Data.data() + Offset); + WriteMeasuredVarUInt(Value, ValueByteCount, Data.data() + Offset); EndField(CbFieldType::IntegerPositive); } @@ -548,7 +545,7 @@ CbWriter::AddInteger(const uint64_t Value) BeginField(); const uint32_t ValueByteCount = MeasureVarUInt(Value); const uint64_t Offset = AddUninitialized(Data, ValueByteCount); - WriteVarUInt(Value, Data.data() + Offset); + WriteMeasuredVarUInt(Value, ValueByteCount, Data.data() + Offset); EndField(CbFieldType::IntegerPositive); } diff --git a/src/zencore/compactbinaryfile.cpp b/src/zencore/compactbinaryfile.cpp index f2121a0bd..ec2fc3cd5 100644 --- a/src/zencore/compactbinaryfile.cpp +++ b/src/zencore/compactbinaryfile.cpp @@ -1,7 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "zencore/compactbinaryfile.h" -#include "zencore/compactbinaryvalidation.h" +#include "zencore/compactbinaryutil.h" #include <zencore/filesystem.h> @@ -19,12 +19,12 @@ LoadCompactBinaryObject(const std::filesystem::path& FilePath) IoBuffer ObjectBuffer = ObjectFile.Flatten(); - if (CbValidateError Result = ValidateCompactBinary(ObjectBuffer, CbValidateMode::All); Result == CbValidateError::None) + CbValidateError ValidateResult; + CbObject Object = ValidateAndReadCompactBinaryObject(IoBuffer(ObjectBuffer), ValidateResult); + if (ValidateResult == CbValidateError::None) { - CbObject Object = LoadCompactBinaryObject(ObjectBuffer); const IoHash WorkerId = IoHash::HashBuffer(ObjectBuffer); - - return {.Object = Object, .Hash = WorkerId}; + return {.Object = std::move(Object), .Hash = WorkerId}; } return {.Hash = IoHash::Zero}; diff --git a/src/zencore/compactbinaryjson.cpp b/src/zencore/compactbinaryjson.cpp new file mode 100644 index 000000000..02f22ba4d --- /dev/null +++ b/src/zencore/compactbinaryjson.cpp @@ -0,0 +1,884 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/compactbinary.h" +#include "zencore/compactbinarybuilder.h" +#include "zencore/compactbinaryvalue.h" + +#include <zencore/assertfmt.h> +#include <zencore/base64.h> +#include <zencore/fmtutils.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <fmt/format.h> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CbJsonWriter +{ +public: + explicit CbJsonWriter(StringBuilderBase& InBuilder, bool AddTypeComment) : Builder(InBuilder), m_AddTypeComment(AddTypeComment) + { + NewLineAndIndent << LINE_TERMINATOR_ANSI; + } + + void BeginObject() + { + Builder << '{'; + NewLineAndIndent << '\t'; + NeedsNewLine = true; + } + + void EndObject() + { + NewLineAndIndent.RemoveSuffix(1); + if (NeedsComma) + { + WriteOptionalNewLine(); + } + Builder << '}'; + } + + void BeginArray() + { + Builder << '['; + NewLineAndIndent << '\t'; + NeedsNewLine = true; + } + + void EndArray() + { + NewLineAndIndent.RemoveSuffix(1); + if (NeedsComma) + { + WriteOptionalNewLine(); + } + Builder << ']'; + } + + void WriteField(CbFieldView Field) + { + using namespace std::literals; + + WriteOptionalComma(); + WriteOptionalNewLine(); + + if (std::u8string_view Name = Field.GetU8Name(); !Name.empty()) + { + AppendQuotedString(Name); + Builder << ": "sv; + } + + switch (CbValue Accessor = Field.GetValue(); Accessor.GetType()) + { + case CbFieldType::Null: + if (m_AddTypeComment) + { + Builder << "[Null] "; + } + Builder << "null"sv; + break; + case CbFieldType::Object: + { + if (m_AddTypeComment) + { + Builder << "[Object] "; + } + BeginObject(); + for (CbFieldView It : Field) + { + WriteField(It); + } + EndObject(); + } + break; + case CbFieldType::UniformObject: + { + if (m_AddTypeComment) + { + Builder << "[UniformObject] "; + } + BeginObject(); + for (CbFieldView It : Field) + { + WriteField(It); + } + EndObject(); + } + break; + case CbFieldType::Array: + { + if (m_AddTypeComment) + { + Builder << "[Array] "; + } + BeginArray(); + for (CbFieldView It : Field) + { + WriteField(It); + } + EndArray(); + } + break; + case CbFieldType::UniformArray: + { + if (m_AddTypeComment) + { + Builder << "[UniformArray] "; + } + BeginArray(); + for (CbFieldView It : Field) + { + WriteField(It); + } + EndArray(); + } + break; + case CbFieldType::Binary: + if (m_AddTypeComment) + { + Builder << "[Binary] "; + } + AppendBase64String(Accessor.AsBinary()); + break; + case CbFieldType::String: + if (m_AddTypeComment) + { + Builder << "[String] "; + } + AppendQuotedString(Accessor.AsU8String()); + break; + case CbFieldType::IntegerPositive: + if (m_AddTypeComment) + { + Builder << "[IntegerPositive] "; + } + Builder << Accessor.AsIntegerPositive(); + break; + case CbFieldType::IntegerNegative: + if (m_AddTypeComment) + { + Builder << "[IntegerNegative] "; + } + Builder << Accessor.AsIntegerNegative(); + break; + case CbFieldType::Float32: + { + if (m_AddTypeComment) + { + Builder << "[Float32] "; + } + const float Value = Accessor.AsFloat32(); + if (std::isfinite(Value)) + { + Builder.Append(fmt::format("{:.9g}", Value)); + } + else + { + Builder << "null"sv; + } + } + break; + case CbFieldType::Float64: + { + if (m_AddTypeComment) + { + Builder << "[Float64] "; + } + const double Value = Accessor.AsFloat64(); + if (std::isfinite(Value)) + { + Builder.Append(fmt::format("{:.17g}", Value)); + } + else + { + Builder << "null"sv; + } + } + break; + case CbFieldType::BoolFalse: + if (m_AddTypeComment) + { + Builder << "[BoolFalse] "; + } + Builder << "false"sv; + break; + case CbFieldType::BoolTrue: + if (m_AddTypeComment) + { + Builder << "[BoolTrue] "; + } + Builder << "true"sv; + break; + case CbFieldType::ObjectAttachment: + { + if (m_AddTypeComment) + { + Builder << "[ObjectAttachment] "; + } + Builder << '"'; + Accessor.AsAttachment().ToHexString(Builder); + Builder << '"'; + } + break; + case CbFieldType::BinaryAttachment: + { + if (m_AddTypeComment) + { + Builder << "[BinaryAttachment] "; + } + Builder << '"'; + Accessor.AsAttachment().ToHexString(Builder); + Builder << '"'; + } + break; + case CbFieldType::Hash: + { + if (m_AddTypeComment) + { + Builder << "[Hash] "; + } + Builder << '"'; + Accessor.AsHash().ToHexString(Builder); + Builder << '"'; + } + break; + case CbFieldType::Uuid: + { + if (m_AddTypeComment) + { + Builder << "[Uuid] "; + } + Builder << '"'; + Accessor.AsUuid().ToString(Builder); + Builder << '"'; + } + break; + case CbFieldType::DateTime: + if (m_AddTypeComment) + { + Builder << "[DateTime] "; + } + Builder << '"' << DateTime(Accessor.AsDateTimeTicks()).ToIso8601() << '"'; + break; + case CbFieldType::TimeSpan: + { + if (m_AddTypeComment) + { + Builder << "[TimeSpan] "; + } + const TimeSpan Span(Accessor.AsTimeSpanTicks()); + if (Span.GetDays() == 0) + { + Builder << '"' << Span.ToString("%h:%m:%s.%n") << '"'; + } + else + { + Builder << '"' << Span.ToString("%d.%h:%m:%s.%n") << '"'; + } + break; + } + case CbFieldType::ObjectId: + if (m_AddTypeComment) + { + Builder << "[ObjectId] "; + } + Builder << '"'; + Accessor.AsObjectId().ToString(Builder); + Builder << '"'; + break; + case CbFieldType::CustomById: + { + if (m_AddTypeComment) + { + Builder << "[CustomById] "; + } + CbCustomById Custom = Accessor.AsCustomById(); + Builder << "{ \"Id\": "; + Builder << Custom.Id; + Builder << ", \"Data\": "; + AppendBase64String(Custom.Data); + Builder << " }"; + break; + } + case CbFieldType::CustomByName: + { + if (m_AddTypeComment) + { + Builder << "[CustomByName] "; + } + CbCustomByName Custom = Accessor.AsCustomByName(); + Builder << "{ \"Name\": "; + AppendQuotedString(Custom.Name); + Builder << ", \"Data\": "; + AppendBase64String(Custom.Data); + Builder << " }"; + break; + } + default: + ZEN_ASSERT_FORMAT(false, "invalid field type: {}", uint8_t(Accessor.GetType())); + break; + } + + NeedsComma = true; + NeedsNewLine = true; + } + +private: + void WriteOptionalComma() + { + if (NeedsComma) + { + NeedsComma = false; + Builder << ','; + } + } + + void WriteOptionalNewLine() + { + if (NeedsNewLine) + { + NeedsNewLine = false; + Builder << NewLineAndIndent; + } + } + + void AppendQuotedString(std::u8string_view Value) + { + using namespace std::literals; + + const AsciiSet EscapeSet( + "\\\"\b\f\n\r\t" + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f" + "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"); + + Builder << '\"'; + while (!Value.empty()) + { + std::u8string_view Verbatim = AsciiSet::FindPrefixWithout(Value, EscapeSet); + Builder << Verbatim; + + Value = Value.substr(Verbatim.size()); + + std::u8string_view Escape = AsciiSet::FindPrefixWith(Value, EscapeSet); + for (char Char : Escape) + { + switch (Char) + { + case '\\': + Builder << "\\\\"sv; + break; + case '\"': + Builder << "\\\""sv; + break; + case '\b': + Builder << "\\b"sv; + break; + case '\f': + Builder << "\\f"sv; + break; + case '\n': + Builder << "\\n"sv; + break; + case '\r': + Builder << "\\r"sv; + break; + case '\t': + Builder << "\\t"sv; + break; + default: + Builder << Char; + break; + } + } + Value = Value.substr(Escape.size()); + } + Builder << '\"'; + } + + void AppendBase64String(MemoryView Value) + { + Builder << '"'; + ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024); + const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); + const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); + Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); + Builder << '"'; + } + +private: + StringBuilderBase& Builder; + ExtendableStringBuilder<32> NewLineAndIndent; + const bool m_AddTypeComment; + bool NeedsComma{false}; + bool NeedsNewLine{false}; +}; + +void +CompactBinaryToJson(const CbObjectView& Object, StringBuilderBase& Builder, bool AddTypeComment) +{ + CbJsonWriter Writer(Builder, AddTypeComment); + Writer.WriteField(Object.AsFieldView()); +} + +void +CompactBinaryToJson(const CbArrayView& Array, StringBuilderBase& Builder) +{ + CbJsonWriter Writer(Builder, /*AddTypeComment*/ false); + Writer.WriteField(Array.AsFieldView()); +} + +void +CompactBinaryToJson(MemoryView Data, StringBuilderBase& InBuilder, bool AddTypeComment) +{ + std::vector<CbFieldView> Fields = ReadCompactBinaryStream(Data); + CbJsonWriter Writer(InBuilder, AddTypeComment); + if (!Fields.empty()) + { + if (Fields.size() == 1) + { + Writer.WriteField(Fields[0]); + return; + } + bool UseTopLevelObject = Fields[0].HasName(); + if (UseTopLevelObject) + { + Writer.BeginObject(); + } + else + { + Writer.BeginArray(); + } + for (const CbFieldView& Field : Fields) + { + Writer.WriteField(Field); + } + if (UseTopLevelObject) + { + Writer.EndObject(); + } + else + { + Writer.EndArray(); + } + } +} + +class CbJsonReader +{ +public: + static CbFieldIterator Read(std::string_view JsonText, std::string& Error) + { + using namespace json11; + + const Json Json = Json::parse(std::string(JsonText), Error); + + if (Error.empty()) + { + CbWriter Writer; + if (ReadField(Writer, Json, std::string_view(), Error)) + { + return Writer.Save(); + } + } + + return CbFieldIterator(); + } + +private: + static bool ReadField(CbWriter& Writer, const json11::Json& Json, const std::string_view FieldName, std::string& Error) + { + using namespace json11; + + switch (Json.type()) + { + case Json::Type::OBJECT: + { + if (FieldName.empty()) + { + Writer.BeginObject(); + } + else + { + Writer.BeginObject(FieldName); + } + + for (const auto& Kv : Json.object_items()) + { + const std::string& Name = Kv.first; + const json11::Json& Item = Kv.second; + + if (ReadField(Writer, Item, Name, Error) == false) + { + return false; + } + } + + Writer.EndObject(); + } + break; + case Json::Type::ARRAY: + { + if (FieldName.empty()) + { + Writer.BeginArray(); + } + else + { + Writer.BeginArray(FieldName); + } + + for (const json11::Json& Item : Json.array_items()) + { + if (ReadField(Writer, Item, std::string_view(), Error) == false) + { + return false; + } + } + + Writer.EndArray(); + } + break; + case Json::Type::NUL: + { + if (FieldName.empty()) + { + Writer.AddNull(); + } + else + { + Writer.AddNull(FieldName); + } + } + break; + case Json::Type::BOOL: + { + if (FieldName.empty()) + { + Writer.AddBool(Json.bool_value()); + } + else + { + Writer.AddBool(FieldName, Json.bool_value()); + } + } + break; + case Json::Type::NUMBER: + { + if (FieldName.empty()) + { + Writer.AddFloat(Json.number_value()); + } + else + { + Writer.AddFloat(FieldName, Json.number_value()); + } + } + break; + case Json::Type::STRING: + { + Oid Id; + if (Oid::TryParse(Json.string_value(), Id)) + { + if (FieldName.empty()) + { + Writer.AddObjectId(Id); + } + else + { + Writer.AddObjectId(FieldName, Id); + } + + return true; + } + + IoHash Hash; + if (IoHash::TryParse(Json.string_value(), Hash)) + { + if (FieldName.empty()) + { + Writer.AddHash(Hash); + } + else + { + Writer.AddHash(FieldName, Hash); + } + + return true; + } + + if (FieldName.empty()) + { + Writer.AddString(Json.string_value()); + } + else + { + Writer.AddString(FieldName, Json.string_value()); + } + } + break; + default: + break; + } + + return true; + } +}; + +CbFieldIterator +LoadCompactBinaryFromJson(std::string_view Json, std::string& Error) +{ + if (Json.empty() == false) + { + return CbJsonReader::Read(Json, Error); + } + + return CbFieldIterator(); +} + +CbFieldIterator +LoadCompactBinaryFromJson(std::string_view Json) +{ + std::string Error; + return LoadCompactBinaryFromJson(Json, Error); +} + +#if ZEN_WITH_TESTS +void +cbjson_forcelink() +{ +} + +TEST_CASE("uson.json") +{ + using namespace std::literals; + + SUBCASE("string") + { + CbObjectWriter Writer; + Writer << "KeyOne" + << "ValueOne"; + Writer << "KeyTwo" + << "ValueTwo"; + CbObject Obj = Writer.Save(); + + StringBuilder<128> Sb; + const char* JsonText = Obj.ToJson(Sb).Data(); + + std::string JsonError; + json11::Json Json = json11::Json::parse(JsonText, JsonError); + + const std::string ValueOne = Json["KeyOne"].string_value(); + const std::string ValueTwo = Json["KeyTwo"].string_value(); + + CHECK(JsonError.empty()); + CHECK(ValueOne == "ValueOne"); + CHECK(ValueTwo == "ValueTwo"); + } + + SUBCASE("number") + { + const float ExpectedFloatValue = 21.21f; + const double ExpectedDoubleValue = 42.42; + + CbObjectWriter Writer; + Writer << "Float" << ExpectedFloatValue; + Writer << "Double" << ExpectedDoubleValue; + + CbObject Obj = Writer.Save(); + + StringBuilder<128> Sb; + const char* JsonText = Obj.ToJson(Sb).Data(); + + std::string JsonError; + json11::Json Json = json11::Json::parse(JsonText, JsonError); + + const float FloatValue = float(Json["Float"].number_value()); + const double DoubleValue = Json["Double"].number_value(); + + CHECK(JsonError.empty()); + CHECK(FloatValue == Approx(ExpectedFloatValue)); + CHECK(DoubleValue == Approx(ExpectedDoubleValue)); + } + + SUBCASE("number.nan") + { + constexpr float FloatNan = std::numeric_limits<float>::quiet_NaN(); + constexpr double DoubleNan = std::numeric_limits<double>::quiet_NaN(); + + CbObjectWriter Writer; + Writer << "FloatNan" << FloatNan; + Writer << "DoubleNan" << DoubleNan; + + CbObject Obj = Writer.Save(); + + StringBuilder<128> Sb; + const char* JsonText = Obj.ToJson(Sb).Data(); + + std::string JsonError; + json11::Json Json = json11::Json::parse(JsonText, JsonError); + + const double FloatValue = Json["FloatNan"].number_value(); + const double DoubleValue = Json["DoubleNan"].number_value(); + + CHECK(JsonError.empty()); + CHECK(FloatValue == 0); + CHECK(DoubleValue == 0); + } + + SUBCASE("stream") + { + const auto MakeObject = [&](std::string_view Name, const std::vector<int>& Fields) -> CbObject { + CbWriter Writer; + Writer.SetName(Name); + Writer.BeginObject(); + for (const auto& Field : Fields) + { + Writer.AddInteger(fmt::format("{}", Field), Field); + } + Writer.EndObject(); + return Writer.Save().AsObject(); + }; + + std::vector<uint8_t> Buffer; + + auto AppendToBuffer = [&](const void* Data, size_t Count) { + const uint8_t* AppendBytes = reinterpret_cast<const uint8_t*>(Data); + Buffer.insert(Buffer.end(), AppendBytes, AppendBytes + Count); + }; + + auto Append = [&](const CbFieldView& Field) { + Field.WriteToStream([&](const void* Data, size_t Count) { + const uint8_t* AppendBytes = reinterpret_cast<const uint8_t*>(Data); + Buffer.insert(Buffer.end(), AppendBytes, AppendBytes + Count); + }); + }; + + CbObject DataObjects[] = {MakeObject("Empty object"sv, {}), + MakeObject("OneField object"sv, {5}), + MakeObject("TwoField object"sv, {-5, 999}), + MakeObject("ThreeField object"sv, {1, 2, -129})}; + for (const CbObject& Object : DataObjects) + { + Object.AsField().WriteToStream(AppendToBuffer); + } + + ExtendableStringBuilder<128> Sb; + CompactBinaryToJson(MemoryView(Buffer.data(), Buffer.size()), Sb); + std::string JsonText = Sb.ToString().c_str(); + std::string JsonError; + json11::Json Json = json11::Json::parse(JsonText, JsonError); + std::string ParsedJsonString = Json.dump(); + CHECK(!ParsedJsonString.empty()); + } +} + +TEST_CASE("json.uson") +{ + using namespace std::literals; + using namespace json11; + + SUBCASE("empty") + { + CbFieldIterator It = LoadCompactBinaryFromJson(""sv); + CHECK(It.HasValue() == false); + } + + SUBCASE("object") + { + const Json JsonObject = Json::object{{"Null", nullptr}, + {"String", "Value1"}, + {"Bool", true}, + {"Number", 46.2}, + {"Array", Json::array{1, 2, 3}}, + {"Object", + Json::object{ + {"String", "Value2"}, + }}}; + + CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["Null"].IsNull()); + CHECK(Cb["String"].AsString() == "Value1"sv); + CHECK(Cb["Bool"].AsBool()); + CHECK(Cb["Number"].AsDouble() == 46.2); + CHECK(Cb["Object"].IsObject()); + CbObjectView Object = Cb["Object"].AsObjectView(); + CHECK(Object["String"].AsString() == "Value2"sv); + } + + SUBCASE("array") + { + const Json JsonArray = Json::array{42, 43, 44}; + CbArray Cb = LoadCompactBinaryFromJson(JsonArray.dump()).AsArray(); + + auto It = Cb.CreateIterator(); + CHECK((*It).AsDouble() == 42); + It++; + CHECK((*It).AsDouble() == 43); + It++; + CHECK((*It).AsDouble() == 44); + } + + SUBCASE("objectid") + { + const Oid& Id = Oid::NewOid(); + + StringBuilder<64> Sb; + Id.ToString(Sb); + + Json JsonObject = Json::object{{"value", Sb.ToString()}}; + CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["value"sv].IsObjectId()); + CHECK(Cb["value"sv].AsObjectId() == Id); + + Sb.Reset(); + Sb << "0x"; + Id.ToString(Sb); + + JsonObject = Json::object{{"value", Sb.ToString()}}; + Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["value"sv].IsObjectId()); + CHECK(Cb["value"sv].AsObjectId() == Id); + } + + SUBCASE("iohash") + { + const uint8_t Data[] = { + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + }; + + const IoHash Hash = IoHash::HashBuffer(Data, sizeof(Data)); + + Json JsonObject = Json::object{{"value", Hash.ToHexString()}}; + CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["value"sv].IsHash()); + CHECK(Cb["value"sv].AsHash() == Hash); + + JsonObject = Json::object{{"value", "0x" + Hash.ToHexString()}}; + Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["value"sv].IsHash()); + CHECK(Cb["value"sv].AsHash() == Hash); + } +} + +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp index a4fa38a1d..ffe64f2e9 100644 --- a/src/zencore/compactbinarypackage.cpp +++ b/src/zencore/compactbinarypackage.cpp @@ -3,40 +3,56 @@ #include "zencore/compactbinarypackage.h" #include <zencore/compactbinarybuilder.h> #include <zencore/compactbinaryvalidation.h> +#include <zencore/eastlutil.h> #include <zencore/endian.h> #include <zencore/stream.h> #include <zencore/testing.h> +#include <EASTL/span.h> + namespace zen { /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -CbAttachment::CbAttachment(const CompressedBuffer& InValue, const IoHash& Hash) : CbAttachment(InValue.MakeOwned(), Hash) +CbAttachment::CbAttachment(const CbObject& InValue, const IoHash* const InHash) { -} + auto SetValue = [&](const CbObject& ValueToSet) { + if (InHash) + { + Value.emplace<CbObject>(ValueToSet); + Hash = *InHash; + } + else + { + Value.emplace<CbObject>(ValueToSet); + Hash = ValueToSet.GetHash(); + } + }; -CbAttachment::CbAttachment(const SharedBuffer& InValue) : CbAttachment(CompositeBuffer(InValue)) -{ + MemoryView View; + if (!InValue.IsOwned() || !InValue.TryGetSerializedView(View)) + { + SetValue(CbObject::Clone(InValue)); + } + else + { + SetValue(InValue); + } } -CbAttachment::CbAttachment(const SharedBuffer& InValue, const IoHash& InHash) : CbAttachment(CompositeBuffer(InValue), InHash) +CbAttachment::CbAttachment(const CompressedBuffer& InValue, const IoHash& Hash) : Hash(Hash), Value(InValue) { + ZEN_ASSERT(!std::get<CompressedBuffer>(Value).IsNull()); } -CbAttachment::CbAttachment(const CompositeBuffer& InValue) -: Hash(InValue.IsNull() ? IoHash::Zero : IoHash::HashBuffer(InValue)) -, Value(InValue) +CbAttachment::CbAttachment(CompressedBuffer&& InValue, const IoHash& InHash) : Hash(InHash), Value(std::move(InValue)) { - if (std::get<CompositeBuffer>(Value).IsNull()) - { - Value.emplace<std::nullptr_t>(); - } + ZEN_ASSERT(!std::get<CompressedBuffer>(Value).IsNull()); } CbAttachment::CbAttachment(CompositeBuffer&& InValue) : Hash(InValue.IsNull() ? IoHash::Zero : IoHash::HashBuffer(InValue)) , Value(std::move(InValue)) - { if (std::get<CompositeBuffer>(Value).IsNull()) { @@ -44,7 +60,7 @@ CbAttachment::CbAttachment(CompositeBuffer&& InValue) } } -CbAttachment::CbAttachment(CompositeBuffer&& InValue, const IoHash& InHash) : Hash(InHash), Value(InValue) +CbAttachment::CbAttachment(CompositeBuffer&& InValue, const IoHash& InHash) : Hash(InHash), Value(std::move(InValue)) { if (std::get<CompositeBuffer>(Value).IsNull()) { @@ -52,40 +68,6 @@ CbAttachment::CbAttachment(CompositeBuffer&& InValue, const IoHash& InHash) : Ha } } -CbAttachment::CbAttachment(CompressedBuffer&& InValue, const IoHash& InHash) : Hash(InHash), Value(InValue) -{ - if (std::get<CompressedBuffer>(Value).IsNull()) - { - Value.emplace<std::nullptr_t>(); - } -} - -CbAttachment::CbAttachment(const CbObject& InValue, const IoHash* const InHash) -{ - auto SetValue = [&](const CbObject& ValueToSet) { - if (InHash) - { - Value.emplace<CbObject>(ValueToSet); - Hash = *InHash; - } - else - { - Value.emplace<CbObject>(ValueToSet); - Hash = ValueToSet.GetHash(); - } - }; - - MemoryView View; - if (!InValue.IsOwned() || !InValue.TryGetSerializedView(View)) - { - SetValue(CbObject::Clone(InValue)); - } - else - { - SetValue(InValue); - } -} - bool CbAttachment::TryLoad(IoBuffer& InBuffer, BufferAllocator Allocator) { @@ -186,7 +168,7 @@ TryLoad_ArchiveFieldIntoAttachment(CbAttachment& TargetAttachment, CbField&& Fie { return false; } - TargetAttachment = CbAttachment(CompositeBuffer(Buffer), BinaryAttachmentHash); + TargetAttachment = CbAttachment(std::move(Buffer), BinaryAttachmentHash); } else if (SharedBuffer Buffer = Field.AsBinary(); !Field.HasError()) { @@ -201,7 +183,7 @@ TryLoad_ArchiveFieldIntoAttachment(CbAttachment& TargetAttachment, CbField&& Fie else { // Is an uncompressed empty binary blob - TargetAttachment = CbAttachment(CompositeBuffer(Buffer), IoHash::HashBuffer(nullptr, 0)); + TargetAttachment = CbAttachment(std::move(Buffer), IoHash::HashBuffer(nullptr, 0)); } } else @@ -282,7 +264,7 @@ CbAttachment::GetHash() const return Hash; } -CompositeBuffer +const CompositeBuffer& CbAttachment::AsCompositeBinary() const { if (const CompositeBuffer* BinValue = std::get_if<CompositeBuffer>(&Value)) @@ -304,7 +286,7 @@ CbAttachment::AsBinary() const return {}; } -CompressedBuffer +const CompressedBuffer& CbAttachment::AsCompressedBinary() const { if (const CompressedBuffer* CompValue = std::get_if<CompressedBuffer>(&Value)) @@ -329,6 +311,11 @@ CbAttachment::AsObject() const /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +CbPackage::CbPackage() +{ + Attachments.reserve(16); +} + void CbPackage::SetObject(CbObject InObject, const IoHash* InObjectHash, AttachmentResolver* InResolver) { @@ -357,6 +344,12 @@ CbPackage::SetObject(CbObject InObject, const IoHash* InObjectHash, AttachmentRe } void +CbPackage::ReserveAttachments(size_t Count) +{ + Attachments.reserve(Count); +} + +void CbPackage::AddAttachment(const CbAttachment& Attachment, AttachmentResolver* Resolver) { if (!Attachment.IsNull()) @@ -386,17 +379,22 @@ CbPackage::AddAttachments(std::span<const CbAttachment> InAttachments) { return; } + for (const CbAttachment& Attachment : InAttachments) + { + ZEN_ASSERT(!Attachment.IsNull()); + } + // Assume we have no duplicates! Attachments.insert(Attachments.end(), InAttachments.begin(), InAttachments.end()); std::sort(Attachments.begin(), Attachments.end()); - ZEN_ASSERT_SLOW(std::unique(Attachments.begin(), Attachments.end()) == Attachments.end()); + ZEN_ASSERT_SLOW(eastl::unique(Attachments.begin(), Attachments.end()) == Attachments.end()); } int32_t CbPackage::RemoveAttachment(const IoHash& Hash) { return gsl::narrow_cast<int32_t>( - std::erase_if(Attachments, [&Hash](const CbAttachment& Attachment) -> bool { return Attachment.GetHash() == Hash; })); + erase_if(Attachments, [&Hash](const CbAttachment& Attachment) -> bool { return Attachment.GetHash() == Hash; })); } bool @@ -741,7 +739,7 @@ namespace legacy { } else { - Package.AddAttachment(CbAttachment(CompositeBuffer(std::move(Buffer)), Hash)); + Package.AddAttachment(CbAttachment(std::move(Buffer), Hash)); } } } diff --git a/src/zencore/compactbinaryutil.cpp b/src/zencore/compactbinaryutil.cpp new file mode 100644 index 000000000..074bdaffd --- /dev/null +++ b/src/zencore/compactbinaryutil.cpp @@ -0,0 +1,39 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/compactbinaryutil.h> + +#include <zencore/compress.h> +#include <zencore/filesystem.h> + +namespace zen { + +CbObject +ValidateAndReadCompactBinaryObject(const SharedBuffer&& Payload, CbValidateError& OutError) +{ + if (Payload.GetSize() > 0) + { + if (OutError = ValidateCompactBinary(Payload.GetView(), CbValidateMode::Default); OutError == CbValidateError::None) + { + CbObject Object(std::move(Payload)); + if (Object.GetView().GetSize() != Payload.GetSize()) + { + OutError |= CbValidateError::OutOfBounds; + return {}; + } + return Object; + } + } + return CbObject(); +} + +CbObject +ValidateAndReadCompactBinaryObject(const CompressedBuffer&& Payload, CbValidateError& OutError) +{ + if (CompositeBuffer Decompressed = Payload.DecompressToComposite()) + { + return ValidateAndReadCompactBinaryObject(std::move(Decompressed).Flatten(), OutError); + } + return CbObject(); +} + +} // namespace zen diff --git a/src/zencore/compactbinaryvalidation.cpp b/src/zencore/compactbinaryvalidation.cpp index 462978f63..d7292f405 100644 --- a/src/zencore/compactbinaryvalidation.cpp +++ b/src/zencore/compactbinaryvalidation.cpp @@ -4,7 +4,7 @@ #include <zencore/compactbinarypackage.h> #include <zencore/endian.h> -#include <zencore/memory.h> +#include <zencore/memoryview.h> #include <zencore/string.h> #include <zencore/testing.h> @@ -86,23 +86,24 @@ ValidateCbFieldType(MemoryView& View, CbValidateMode Mode, CbValidateError& Erro static uint64_t ValidateCbUInt(MemoryView& View, CbValidateMode Mode, CbValidateError& Error) { - if (View.GetSize() > 0 && View.GetSize() >= MeasureVarUInt(View.GetData())) + size_t ViewSize = View.GetSize(); + if (ViewSize > 0) { - uint32_t ValueByteCount; - const uint64_t Value = ReadVarUInt(View.GetData(), ValueByteCount); - if (EnumHasAnyFlags(Mode, CbValidateMode::Format) && ValueByteCount > MeasureVarUInt(Value)) + uint32_t ValueByteCount = MeasureVarUInt(View.GetData()); + if (ViewSize >= ValueByteCount) { - AddError(Error, CbValidateError::InvalidInteger); + const uint64_t Value = ReadMeasuredVarUInt(View.GetData(), ValueByteCount); + if (EnumHasAnyFlags(Mode, CbValidateMode::Format) && ValueByteCount != MeasureVarUInt(Value)) + { + AddError(Error, CbValidateError::InvalidInteger); + } + View += ValueByteCount; + return Value; } - View += ValueByteCount; - return Value; - } - else - { - AddError(Error, CbValidateError::OutOfBounds); - View.Reset(); - return 0; } + AddError(Error, CbValidateError::OutOfBounds); + View.Reset(); + return 0; } /** @@ -134,6 +135,37 @@ ValidateCbFloat64(MemoryView& View, CbValidateMode Mode, CbValidateError& Error) } /** + * Validate and read a fixed-size value from the view. + * + * Modifies the view to start at the end of the value, and adds error flags if applicable. + */ +static MemoryView +ValidateCbFixedValue(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, uint64_t Size) +{ + ZEN_UNUSED(Mode); + + const MemoryView Value = View.Left(Size); + View += Size; + if (Value.GetSize() < Size) + { + AddError(Error, CbValidateError::OutOfBounds); + } + return Value; +}; + +/** + * Validate and read a value from the view where the view begins with the value size. + * + * Modifies the view to start at the end of the value, and adds error flags if applicable. + */ +static MemoryView +ValidateCbDynamicValue(MemoryView& View, CbValidateMode Mode, CbValidateError& Error) +{ + const uint64_t ValueSize = ValidateCbUInt(View, Mode, Error); + return ValidateCbFixedValue(View, Mode, Error, ValueSize); +} + +/** * Validate and read a string from the view. * * Modifies the view to start at the end of the string, and adds error flags if applicable. @@ -377,8 +409,20 @@ ValidateCbField(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, c ValidateFixedPayload(12); break; case CbFieldType::CustomById: + { + MemoryView Value = ValidateCbDynamicValue(View, Mode, Error); + ValidateCbUInt(Value, Mode, Error); + } + break; case CbFieldType::CustomByName: - ZEN_NOT_IMPLEMENTED(); // TODO: FIX! + { + MemoryView Value = ValidateCbDynamicValue(View, Mode, Error); + const std::string_view TypeName = ValidateCbString(Value, Mode, Error); + if (TypeName.empty() && !EnumHasAnyFlags(Error, CbValidateError::OutOfBounds)) + { + AddError(Error, CbValidateError::InvalidType); + } + } break; } @@ -544,10 +588,17 @@ ValidateCompactBinary(MemoryView View, CbValidateMode Mode, CbFieldType Type) CbValidateError Error = CbValidateError::None; if (EnumHasAnyFlags(Mode, CbValidateMode::All)) { - ValidateCbField(View, Mode, Error, Type); - if (!View.IsEmpty() && EnumHasAnyFlags(Mode, CbValidateMode::Padding)) + if (View.IsEmpty()) { - AddError(Error, CbValidateError::Padding); + AddError(Error, CbValidateError::OutOfBounds); + } + else + { + ValidateCbField(View, Mode, Error, Type); + if (!View.IsEmpty() && EnumHasAnyFlags(Mode, CbValidateMode::Padding)) + { + AddError(Error, CbValidateError::Padding); + } } } return Error; @@ -654,10 +705,11 @@ ToString(const CbValidateError Error) ExtendableStringBuilder<128> Out; - auto AppendFlag = [&, IsFirst = false](std::string_view FlagString) { + auto AppendFlag = [&, IsFirst = true](std::string_view FlagString) mutable { if (!IsFirst) Out.Append('|'); Out.Append(FlagString); + IsFirst = false; }; #define _ENUM_CASE(V) \ @@ -686,7 +738,11 @@ ToString(const CbValidateError Error) #undef _ENUM_CASE - return "Error"; + if (Out.Size() == 0) + { + return "Error"; + } + return Out.ToString(); } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/zencore/compactbinaryyaml.cpp b/src/zencore/compactbinaryyaml.cpp new file mode 100644 index 000000000..3a9705684 --- /dev/null +++ b/src/zencore/compactbinaryyaml.cpp @@ -0,0 +1,352 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/compactbinary.h" +#include "zencore/compactbinarybuilder.h" +#include "zencore/compactbinaryvalue.h" + +#include <zencore/assertfmt.h> +#include <zencore/base64.h> +#include <zencore/fmtutils.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <fmt/format.h> +#include <string_view> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <ryml/ryml.hpp> +#include <ryml/ryml_std.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +class CbYamlWriter +{ +public: + explicit CbYamlWriter(StringBuilderBase& InBuilder) : m_StrBuilder(InBuilder) { m_NodeStack.push_back(m_Tree.rootref()); } + + void WriteField(CbFieldView Field) + { + ryml::NodeRef Node; + + if (m_IsFirst) + { + Node = Top(); + + m_IsFirst = false; + } + else + { + Node = Top().append_child(); + } + + if (std::u8string_view Name = Field.GetU8Name(); !Name.empty()) + { + Node.set_key_serialized(ryml::csubstr((const char*)Name.data(), Name.size())); + } + + switch (CbValue Accessor = Field.GetValue(); Accessor.GetType()) + { + case CbFieldType::Null: + Node.set_val("null"); + break; + case CbFieldType::Object: + case CbFieldType::UniformObject: + Node |= ryml::MAP; + m_NodeStack.push_back(Node); + for (CbFieldView It : Field) + { + WriteField(It); + } + m_NodeStack.pop_back(); + break; + case CbFieldType::Array: + case CbFieldType::UniformArray: + Node |= ryml::SEQ; + m_NodeStack.push_back(Node); + for (CbFieldView It : Field) + { + WriteField(It); + } + m_NodeStack.pop_back(); + break; + case CbFieldType::Binary: + { + ExtendableStringBuilder<256> Builder; + const MemoryView Value = Accessor.AsBinary(); + ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024); + const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); + const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); + Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); + + Node.set_key_serialized(Builder.c_str()); + } + break; + case CbFieldType::String: + { + const std::u8string_view U8String = Accessor.AsU8String(); + Node.set_val(ryml::csubstr((const char*)U8String.data(), U8String.size())); + } + break; + case CbFieldType::IntegerPositive: + Node << Accessor.AsIntegerPositive(); + break; + case CbFieldType::IntegerNegative: + Node << Accessor.AsIntegerNegative(); + break; + case CbFieldType::Float32: + if (const float Value = Accessor.AsFloat32(); std::isfinite(Value)) + { + Node << Value; + } + else + { + Node << "null"; + } + break; + case CbFieldType::Float64: + if (const double Value = Accessor.AsFloat64(); std::isfinite(Value)) + { + Node << Value; + } + else + { + Node << "null"; + } + break; + case CbFieldType::BoolFalse: + Node << "false"; + break; + case CbFieldType::BoolTrue: + Node << "true"; + break; + case CbFieldType::ObjectAttachment: + case CbFieldType::BinaryAttachment: + Node << Accessor.AsAttachment().ToHexString(); + break; + case CbFieldType::Hash: + Node << Accessor.AsHash().ToHexString(); + break; + case CbFieldType::Uuid: + Node << fmt::format("{}", Accessor.AsUuid()); + break; + case CbFieldType::DateTime: + Node << DateTime(Accessor.AsDateTimeTicks()).ToIso8601(); + break; + case CbFieldType::TimeSpan: + if (const TimeSpan Span(Accessor.AsTimeSpanTicks()); Span.GetDays() == 0) + { + Node << Span.ToString("%h:%m:%s.%n"); + } + else + { + Node << Span.ToString("%d.%h:%m:%s.%n"); + } + break; + case CbFieldType::ObjectId: + Node << fmt::format("{}", Accessor.AsObjectId()); + break; + case CbFieldType::CustomById: + { + CbCustomById Custom = Accessor.AsCustomById(); + + Node |= ryml::MAP; + + ryml::NodeRef IdNode = Node.append_child(); + IdNode.set_key("Id"); + IdNode.set_val_serialized(fmt::format("{}", Custom.Id)); + + ryml::NodeRef DataNode = Node.append_child(); + DataNode.set_key("Data"); + + ExtendableStringBuilder<256> Builder; + const MemoryView& Value = Custom.Data; + const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); + const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); + Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); + + DataNode.set_val_serialized(Builder.c_str()); + } + break; + case CbFieldType::CustomByName: + { + CbCustomByName Custom = Accessor.AsCustomByName(); + + Node |= ryml::MAP; + + ryml::NodeRef NameNode = Node.append_child(); + NameNode.set_key("Name"); + std::string_view Name = std::string_view((const char*)Custom.Name.data(), Custom.Name.size()); + NameNode.set_val_serialized(std::string(Name)); + + ryml::NodeRef DataNode = Node.append_child(); + DataNode.set_key("Data"); + + ExtendableStringBuilder<256> Builder; + const MemoryView& Value = Custom.Data; + const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); + const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); + Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); + + DataNode.set_val_serialized(Builder.c_str()); + } + break; + default: + ZEN_ASSERT_FORMAT(false, "invalid field type: {}", uint8_t(Accessor.GetType())); + break; + } + + if (m_NodeStack.size() == 1) + { + std::string Yaml = ryml::emitrs_yaml<std::string>(m_Tree); + m_StrBuilder << Yaml; + } + } + +private: + StringBuilderBase& m_StrBuilder; + bool m_IsFirst = true; + + ryml::Tree m_Tree; + std::vector<ryml::NodeRef> m_NodeStack; + ryml::NodeRef& Top() { return m_NodeStack.back(); } +}; + +void +CompactBinaryToYaml(const CbObjectView& Object, StringBuilderBase& Builder) +{ + CbYamlWriter Writer(Builder); + Writer.WriteField(Object.AsFieldView()); +} + +void +CompactBinaryToYaml(const CbArrayView& Array, StringBuilderBase& Builder) +{ + CbYamlWriter Writer(Builder); + Writer.WriteField(Array.AsFieldView()); +} + +#if ZEN_WITH_TESTS +void +cbyaml_forcelink() +{ +} + +TEST_CASE("uson.yaml") +{ + using namespace std::literals; + + SUBCASE("simple") + { + CbObjectWriter Writer; + Writer << "KeyOne" + << "ValueOne"; + Writer << "KeyTwo" + << "ValueTwo"; + CbObject Obj = Writer.Save(); + + StringBuilder<128> Sb; + CbYamlWriter YamlWriter(Sb); + YamlWriter.WriteField(Obj.AsFieldView()); + + CHECK_EQ(Sb.ToView(), "KeyOne: ValueOne\nKeyTwo: ValueTwo\n"sv); + } + + SUBCASE("scalar_fields") + { + CbObjectWriter Writer; + Writer << "small_int"sv << 10; + Writer << "neg_small_int"sv << -10; + Writer << "small_real"sv << 10.5f; + Writer << "neg_small_real"sv << -10.5f; + Writer.AddNull("null_val"sv); + Writer.AddDateTimeTicks("date"sv, 622'033'000'000'000'000ull); + Writer.AddHash("hash"sv, IoHash::FromHexString("0011223344556677889900112233445566778899"sv)); + Writer.AddObjectId("oid"sv, Oid::FromHexString("112233445566778899001122"sv)); + Writer.AddTimeSpanTicks("dt"sv, 3'000'000'000'000ull); + Writer.AddUuid("guid"sv, Guid::FromString("E0596ADC-996A-4BA4-ACA3-A2A378AB2796")); + Writer.AddBool("yes"sv, true); + Writer.AddBool("no"sv, false); + CbObject Obj = Writer.Save(); + + ExtendableStringBuilder<128> Sb; + CbYamlWriter YamlWriter(Sb); + YamlWriter.WriteField(Obj.AsFieldView()); + + CHECK_EQ(Sb.ToView(), + "small_int: 10\n" + "neg_small_int: -10\n" + "small_real: 10.5\n" + "neg_small_real: -10.5\n" + "null_val: null\n" + "date: '1972-02-23T14:26:40.000Z'\n" + "hash: 0011223344556677889900112233445566778899\n" + "oid: 112233445566778899001122\n" + "dt: '+3.11:20:00.000000000'\n" + "guid: 'e0596adc-996a-4ba4-aca3-a2a378ab2796'\n" + "yes: true\n" + "no: false\n"sv); + } + + SUBCASE("complex_fields") + { + CbObjectWriter Writer; + Writer.BeginObject("sub"); + Writer.AddBool("no"sv, false); + Writer.BeginObject("sub"); + Writer.AddBool("yes"sv, true); + Writer.BeginArray("seq"); + Writer.AddInteger(1); + Writer.AddInteger(2); + Writer.AddInteger(3); + Writer.EndArray(); + Writer.EndObject(); + Writer.EndObject(); + Writer.BeginArray("seq"); + Writer.AddInteger(1); + Writer.AddInteger(2); + Writer.AddInteger(3); + Writer.EndArray(); + Writer.BeginArray("mixed_seq"); + Writer.AddInteger(1); + Writer.AddString("hello"sv); + Writer.AddFloat(44.4f); + Writer.BeginObject(); + Writer.AddBool("yes"sv, true); + Writer.AddBool("no"sv, false); + Writer.EndObject(); + Writer.EndArray(); + CbObject Obj = Writer.Save(); + + ExtendableStringBuilder<128> Sb; + CbYamlWriter YamlWriter(Sb); + YamlWriter.WriteField(Obj.AsFieldView()); + + CHECK_EQ(Sb.ToView(), + R"(sub: + no: false + sub: + yes: true + seq: + - 1 + - 2 + - 3 +seq: + - 1 + - 2 + - 3 +mixed_seq: + - 1 + - hello + - 44.4 + - yes: true + no: false +)"sv); + } +} +#endif + +} // namespace zen diff --git a/src/zencore/compositebuffer.cpp b/src/zencore/compositebuffer.cpp index f29f6e810..252ac9045 100644 --- a/src/zencore/compositebuffer.cpp +++ b/src/zencore/compositebuffer.cpp @@ -93,10 +93,36 @@ CompositeBuffer::Mid(uint64_t Offset, uint64_t Size) const const uint64_t BufferSize = GetSize(); Offset = Min(Offset, BufferSize); Size = Min(Size, BufferSize - Offset); + CompositeBuffer Buffer; - IterateRange(Offset, Size, [&Buffer](MemoryView View, const SharedBuffer& ViewOuter) { - Buffer.m_Segments.push_back(SharedBuffer::MakeView(View, ViewOuter)); - }); + { + for (const SharedBuffer& Segment : m_Segments) + { + if (const uint64_t SegmentSize = Segment.GetSize(); Offset <= SegmentSize) + { + size_t PartSize = Min(Size, SegmentSize - Offset); + if (PartSize == SegmentSize) + { + Buffer.m_Segments.push_back(Segment); + } + else if (PartSize > 0 || Size == 0) + { + // We need to add the segment even if PartSize is zero if we are picking up zero bytes. + Buffer.m_Segments.push_back(SharedBuffer(IoBuffer(Segment.AsIoBuffer(), Offset, PartSize))); + } + Offset = 0; + Size -= PartSize; + if (Size == 0) + { + break; + } + } + else + { + Offset -= SegmentSize; + } + } + } return Buffer; } @@ -107,24 +133,28 @@ CompositeBuffer::ViewOrCopyRange(uint64_t Offset, std::function<UniqueBuffer(uint64_t Size)> Allocator) const { MemoryView View; - IterateRange(Offset, Size, [Size, &View, &CopyBuffer, &Allocator, WriteView = MutableMemoryView()](MemoryView Segment) mutable { - if (Size == Segment.GetSize()) - { - View = Segment; - } - else - { - if (WriteView.IsEmpty()) + IterateRange( + Offset, + Size, + [Size, &View, &CopyBuffer, &Allocator, WriteView = MutableMemoryView()](MemoryView Segment, const SharedBuffer& ViewOuter) mutable { + if (Segment.GetSize() == ViewOuter.GetSize()) { - if (CopyBuffer.GetSize() < Size) + // We assume that the segment of the buffer is kept in memory + View = Segment; + } + else + { + if (WriteView.IsEmpty()) { - CopyBuffer = Allocator(Size); + if (CopyBuffer.GetSize() < Size) + { + CopyBuffer = Allocator(Size); + } + View = WriteView = CopyBuffer.GetMutableView().Left(Size); } - View = WriteView = CopyBuffer.GetMutableView().Left(Size); + WriteView = WriteView.CopyFrom(Segment); } - WriteView = WriteView.CopyFrom(Segment); - } - }); + }); return View; } @@ -155,7 +185,11 @@ CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& Copy // A hot path for this code is when we call CompressedBuffer::FromCompressed which // is only interested in reading the header (first 64 bytes or so) and then throws // away the materialized data. - MutableMemoryView WriteView; + if (CopyBuffer.GetSize() < Size) + { + CopyBuffer = UniqueBuffer::Alloc(Size); + } + MutableMemoryView WriteView = CopyBuffer.GetMutableView(); size_t SegmentCount = m_Segments.size(); ZEN_ASSERT(It.SegmentIndex < SegmentCount); uint64_t SizeLeft = Size; @@ -163,31 +197,10 @@ CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& Copy { const SharedBuffer& Segment = m_Segments[It.SegmentIndex]; size_t SegmentSize = Segment.GetSize(); - if (Size == SizeLeft && Size <= (SegmentSize - It.OffsetInSegment)) - { - IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, SizeLeft); - MemoryView View = SubSegment.GetView(); - It.OffsetInSegment += SizeLeft; - ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize); - if (It.OffsetInSegment == SegmentSize) - { - It.SegmentIndex++; - It.OffsetInSegment = 0; - } - return View; - } - if (WriteView.GetSize() == 0) - { - if (CopyBuffer.GetSize() < Size) - { - CopyBuffer = UniqueBuffer::Alloc(Size); - } - WriteView = CopyBuffer.GetMutableView(); - } - size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft); - IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize); - MemoryView ReadView = SubSegment.GetView(); - WriteView = WriteView.CopyFrom(ReadView); + size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft); + IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize); + MemoryView ReadView = SubSegment.GetView(); + WriteView = WriteView.CopyFrom(ReadView); It.OffsetInSegment += CopySize; ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize); if (It.OffsetInSegment == SegmentSize) @@ -254,7 +267,15 @@ CompositeBuffer::IterateRange(uint64_t Offset, ZEN_ASSERT(Offset + Size <= GetSize()); for (const SharedBuffer& Segment : m_Segments) { - if (const uint64_t SegmentSize = Segment.GetSize(); Offset <= SegmentSize) + const uint64_t SegmentSize = Segment.GetSize(); + if (Size == 0 && Offset == SegmentSize) + { + // Special case for getting the zero size end of a composite buffer + const MemoryView View = Segment.GetView().Mid(Offset, 0); + Visitor(View, Segment); + break; + } + else if (Offset <= SegmentSize) { const MemoryView View = Segment.GetView().Mid(Offset, Size); Offset = 0; diff --git a/src/zencore/compress.cpp b/src/zencore/compress.cpp index c41bdac42..d9f381811 100644 --- a/src/zencore/compress.cpp +++ b/src/zencore/compress.cpp @@ -2,10 +2,12 @@ #include <zencore/compress.h> +#include <zencore/basicfile.h> #include <zencore/blake3.h> #include <zencore/compositebuffer.h> #include <zencore/crc32.h> #include <zencore/endian.h> +#include <zencore/filesystem.h> #include <zencore/intmath.h> #include <zencore/iohash.h> #include <zencore/stream.h> @@ -157,6 +159,10 @@ class BaseEncoder { public: [[nodiscard]] virtual CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t BlockSize = DefaultBlockSize) const = 0; + [[nodiscard]] virtual bool CompressToStream( + const CompositeBuffer& RawData, + std::function<void(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback, + uint64_t BlockSize = DefaultBlockSize) const = 0; }; class BaseDecoder @@ -184,6 +190,14 @@ public: const MemoryView HeaderView, uint64_t RawOffset, uint64_t RawSize) const = 0; + + virtual bool DecompressToStream( + const BufferHeader& Header, + const CompositeBuffer& CompressedData, + uint64_t RawOffset, + uint64_t RawSize, + std::function<bool(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback) + const = 0; }; /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -191,19 +205,58 @@ public: class NoneEncoder final : public BaseEncoder { public: - [[nodiscard]] CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t /* BlockSize */) const final + [[nodiscard]] virtual CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t /* BlockSize */) const final { - BufferHeader Header; - Header.Method = CompressionMethod::None; - Header.BlockCount = 1; - Header.TotalRawSize = RawData.GetSize(); - Header.TotalCompressedSize = Header.TotalRawSize + sizeof(BufferHeader); - Header.RawHash = BLAKE3::HashBuffer(RawData); - - UniqueBuffer HeaderData = UniqueBuffer::Alloc(sizeof(BufferHeader)); - Header.Write(HeaderData); + UniqueBuffer HeaderData = CompressedBuffer::CreateHeaderForNoneEncoder(RawData.GetSize(), BLAKE3::HashBuffer(RawData)); return CompositeBuffer(HeaderData.MoveToShared(), RawData.MakeOwned()); } + + [[nodiscard]] virtual bool CompressToStream( + const CompositeBuffer& RawData, + std::function<void(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback, + uint64_t /* BlockSize */) const final + { + const uint64_t HeaderSize = CompressedBuffer::GetHeaderSizeForNoneEncoder(); + + uint64_t RawOffset = 0; + BLAKE3Stream HashStream; + + for (const SharedBuffer& Segment : RawData.GetSegments()) + { + IoBufferFileReference FileRef = {nullptr, 0, 0}; + IoBuffer SegmentBuffer = Segment.AsIoBuffer(); + if (SegmentBuffer.GetFileReference(FileRef)) + { + ZEN_ASSERT(FileRef.FileHandle != nullptr); + + ScanFile(FileRef.FileHandle, + FileRef.FileChunkOffset, + FileRef.FileChunkSize, + 512u * 1024u, + [&](const void* Data, size_t Size) { + HashStream.Append(Data, Size); + CompositeBuffer Tmp(SharedBuffer::MakeView(Data, Size)); + Callback(RawOffset, Size, HeaderSize + RawOffset, Tmp); + RawOffset += Size; + }); + } + else + { + const uint64_t Size = SegmentBuffer.GetSize(); + HashStream.Append(SegmentBuffer); + Callback(RawOffset, Size, HeaderSize + RawOffset, CompositeBuffer(Segment)); + RawOffset += Size; + } + } + + ZEN_ASSERT(RawOffset == RawData.GetSize()); + + UniqueBuffer HeaderData = CompressedBuffer::CreateHeaderForNoneEncoder(RawData.GetSize(), HashStream.GetHash()); + ZEN_ASSERT(HeaderData.GetSize() == HeaderSize); + Callback(0, 0, 0, CompositeBuffer(HeaderData.MoveToShared())); + + return true; + } }; class NoneDecoder final : public BaseDecoder @@ -270,6 +323,45 @@ public: } [[nodiscard]] uint64_t GetHeaderSize(const BufferHeader&) const final { return sizeof(BufferHeader); } + + virtual bool DecompressToStream( + const BufferHeader& Header, + const CompositeBuffer& CompressedData, + uint64_t RawOffset, + uint64_t RawSize, + std::function<bool(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback) + const final + { + if (Header.Method == CompressionMethod::None && Header.TotalCompressedSize == CompressedData.GetSize() && + Header.TotalCompressedSize == Header.TotalRawSize + sizeof(BufferHeader) && RawOffset < Header.TotalRawSize && + (RawOffset + RawSize) <= Header.TotalRawSize) + { + bool Result = true; + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if ((CompressedData.GetSegments().size() == 1) && CompressedData.GetSegments()[0].AsIoBuffer().GetFileReference(FileRef)) + { + ZEN_ASSERT(FileRef.FileHandle != nullptr); + uint64_t CallbackOffset = 0; + ScanFile(FileRef.FileHandle, sizeof(BufferHeader) + RawOffset, RawSize, 512u * 1024u, [&](const void* Data, size_t Size) { + if (Result) + { + CompositeBuffer Tmp(SharedBuffer::MakeView(Data, Size)); + Result = Callback(sizeof(BufferHeader) + RawOffset + CallbackOffset, Size, CallbackOffset, Tmp); + } + CallbackOffset += Size; + }); + return Result; + } + else + { + return Callback(sizeof(BufferHeader) + RawOffset, + RawSize, + 0, + CompressedData.Mid(sizeof(BufferHeader) + RawOffset, RawSize)); + } + } + return false; + } }; ////////////////////////////////////////////////////////////////////////// @@ -277,7 +369,11 @@ public: class BlockEncoder : public BaseEncoder { public: - CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t BlockSize = DefaultBlockSize) const final; + virtual CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t BlockSize) const final; + virtual bool CompressToStream( + const CompositeBuffer& RawData, + std::function<void(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback, + uint64_t BlockSize) const final; protected: virtual CompressionMethod GetMethod() const = 0; @@ -322,37 +418,77 @@ BlockEncoder::Compress(const CompositeBuffer& RawData, const uint64_t BlockSize) CompressedBlockSizes.reserve(BlockCount); uint64_t CompressedSize = 0; { - UniqueBuffer RawBlockCopy; MutableMemoryView CompressedBlocksView = CompressedData.GetMutableView() + sizeof(BufferHeader) + MetaSize; - CompositeBuffer::Iterator It = RawData.GetIterator(0); - - for (uint64_t RawOffset = 0; RawOffset < RawSize;) + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if ((RawData.GetSegments().size() == 1) && RawData.GetSegments()[0].AsIoBuffer().GetFileReference(FileRef)) { - const uint64_t RawBlockSize = zen::Min(RawSize - RawOffset, BlockSize); - const MemoryView RawBlock = RawData.ViewOrCopyRange(It, RawBlockSize, RawBlockCopy); - RawHash.Append(RawBlock); - - MutableMemoryView CompressedBlock = CompressedBlocksView; - if (!CompressBlock(CompressedBlock, RawBlock)) + ZEN_ASSERT(FileRef.FileHandle != nullptr); + UniqueBuffer RawBlockCopy = UniqueBuffer::Alloc(BlockSize); + BasicFile Source; + Source.Attach(FileRef.FileHandle); + for (uint64_t RawOffset = 0; RawOffset < RawSize;) { - return CompositeBuffer(); - } + const uint64_t RawBlockSize = zen::Min(RawSize - RawOffset, BlockSize); + Source.Read(RawBlockCopy.GetData(), RawBlockSize, FileRef.FileChunkOffset + RawOffset); + const MemoryView RawBlock = RawBlockCopy.GetView().Left(RawBlockSize); + RawHash.Append(RawBlock); + MutableMemoryView CompressedBlock = CompressedBlocksView; + if (!CompressBlock(CompressedBlock, RawBlock)) + { + Source.Detach(); + return CompositeBuffer(); + } - uint64_t CompressedBlockSize = CompressedBlock.GetSize(); - if (RawBlockSize <= CompressedBlockSize) - { - CompressedBlockSize = RawBlockSize; - CompressedBlocksView = CompressedBlocksView.CopyFrom(RawBlock); + uint64_t CompressedBlockSize = CompressedBlock.GetSize(); + if (RawBlockSize <= CompressedBlockSize) + { + CompressedBlockSize = RawBlockSize; + CompressedBlocksView = CompressedBlocksView.CopyFrom(RawBlock); + } + else + { + CompressedBlocksView += CompressedBlockSize; + } + + CompressedBlockSizes.push_back(static_cast<uint32_t>(CompressedBlockSize)); + CompressedSize += CompressedBlockSize; + RawOffset += RawBlockSize; } - else + Source.Detach(); + } + else + { + UniqueBuffer RawBlockCopy; + CompositeBuffer::Iterator It = RawData.GetIterator(0); + + for (uint64_t RawOffset = 0; RawOffset < RawSize;) { - CompressedBlocksView += CompressedBlockSize; - } + const uint64_t RawBlockSize = zen::Min(RawSize - RawOffset, BlockSize); + const MemoryView RawBlock = RawData.ViewOrCopyRange(It, RawBlockSize, RawBlockCopy); + RawHash.Append(RawBlock); + + MutableMemoryView CompressedBlock = CompressedBlocksView; + if (!CompressBlock(CompressedBlock, RawBlock)) + { + return CompositeBuffer(); + } - CompressedBlockSizes.push_back(static_cast<uint32_t>(CompressedBlockSize)); - CompressedSize += CompressedBlockSize; - RawOffset += RawBlockSize; + uint64_t CompressedBlockSize = CompressedBlock.GetSize(); + if (RawBlockSize <= CompressedBlockSize) + { + CompressedBlockSize = RawBlockSize; + CompressedBlocksView = CompressedBlocksView.CopyFrom(RawBlock); + } + else + { + CompressedBlocksView += CompressedBlockSize; + } + + CompressedBlockSizes.push_back(static_cast<uint32_t>(CompressedBlockSize)); + CompressedSize += CompressedBlockSize; + RawOffset += RawBlockSize; + } } } @@ -385,6 +521,143 @@ BlockEncoder::Compress(const CompositeBuffer& RawData, const uint64_t BlockSize) return CompositeBuffer(SharedBuffer::MakeView(CompositeView, CompressedData.MoveToShared())); } +bool +BlockEncoder::CompressToStream( + const CompositeBuffer& RawData, + std::function<void(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback, + uint64_t BlockSize = DefaultBlockSize) const +{ + ZEN_ASSERT(IsPow2(BlockSize) && (BlockSize <= (1u << 31))); + + const uint64_t RawSize = RawData.GetSize(); + BLAKE3Stream RawHash; + + const uint64_t BlockCount = RoundUp(RawSize, BlockSize) / BlockSize; + ZEN_ASSERT(BlockCount <= ~uint32_t(0)); + + const uint64_t MetaSize = BlockCount * sizeof(uint32_t); + const uint64_t FullHeaderSize = sizeof(BufferHeader) + MetaSize; + + std::vector<uint32_t> CompressedBlockSizes; + CompressedBlockSizes.reserve(BlockCount); + uint64_t CompressedSize = 0; + { + UniqueBuffer CompressedBlockBuffer = UniqueBuffer::Alloc(GetCompressedBlocksBound(1, BlockSize, Min(RawSize, BlockSize))); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if ((RawData.GetSegments().size() == 1) && RawData.GetSegments()[0].AsIoBuffer().GetFileReference(FileRef)) + { + ZEN_ASSERT(FileRef.FileHandle != nullptr); + UniqueBuffer RawBlockCopy = UniqueBuffer::Alloc(BlockSize); + BasicFile Source; + Source.Attach(FileRef.FileHandle); + for (uint64_t RawOffset = 0; RawOffset < RawSize;) + { + const uint64_t RawBlockSize = zen::Min(RawSize - RawOffset, BlockSize); + Source.Read(RawBlockCopy.GetData(), RawBlockSize, FileRef.FileChunkOffset + RawOffset); + const MemoryView RawBlock = RawBlockCopy.GetView().Left(RawBlockSize); + RawHash.Append(RawBlock); + MutableMemoryView CompressedBlock = CompressedBlockBuffer.GetMutableView(); + if (!CompressBlock(CompressedBlock, RawBlock)) + { + Source.Detach(); + return false; + } + + uint64_t CompressedBlockSize = CompressedBlock.GetSize(); + if (RawBlockSize <= CompressedBlockSize) + { + Callback(FileRef.FileChunkOffset + RawOffset, + RawBlockSize, + FullHeaderSize + CompressedSize, + CompositeBuffer(IoBuffer(IoBuffer::Wrap, RawBlockCopy.GetView().GetData(), RawBlockSize))); + CompressedBlockSize = RawBlockSize; + } + else + { + Callback(FileRef.FileChunkOffset + RawOffset, + RawBlockSize, + FullHeaderSize + CompressedSize, + CompositeBuffer(IoBuffer(IoBuffer::Wrap, CompressedBlock.GetData(), CompressedBlockSize))); + } + + CompressedBlockSizes.push_back(static_cast<uint32_t>(CompressedBlockSize)); + CompressedSize += CompressedBlockSize; + RawOffset += RawBlockSize; + } + Source.Detach(); + } + else + { + UniqueBuffer RawBlockCopy; + CompositeBuffer::Iterator It = RawData.GetIterator(0); + + for (uint64_t RawOffset = 0; RawOffset < RawSize;) + { + const uint64_t RawBlockSize = zen::Min(RawSize - RawOffset, BlockSize); + const MemoryView RawBlock = RawData.ViewOrCopyRange(It, RawBlockSize, RawBlockCopy); + RawHash.Append(RawBlock); + + MutableMemoryView CompressedBlock = CompressedBlockBuffer.GetMutableView(); + if (!CompressBlock(CompressedBlock, RawBlock)) + { + return false; + } + + uint64_t CompressedBlockSize = CompressedBlock.GetSize(); + if (RawBlockSize <= CompressedBlockSize) + { + Callback(RawOffset, + RawBlockSize, + FullHeaderSize + CompressedSize, + CompositeBuffer(IoBuffer(IoBuffer::Wrap, RawBlock.GetData(), RawBlockSize))); + CompressedBlockSize = RawBlockSize; + } + else + { + Callback(RawOffset, + RawBlockSize, + FullHeaderSize + CompressedSize, + CompositeBuffer(IoBuffer(IoBuffer::Wrap, CompressedBlock.GetData(), CompressedBlockSize))); + } + + CompressedBlockSizes.push_back(static_cast<uint32_t>(CompressedBlockSize)); + CompressedSize += CompressedBlockSize; + RawOffset += RawBlockSize; + } + } + } + + // Return failure if the compressed data is larger than the raw data. + if (RawSize <= MetaSize + CompressedSize) + { + return false; + } + + // Write the header and calculate the CRC-32. + for (uint32_t& Size : CompressedBlockSizes) + { + Size = ByteSwap(Size); + } + UniqueBuffer HeaderBuffer = UniqueBuffer::Alloc(sizeof(BufferHeader) + MetaSize); + + BufferHeader Header; + Header.Method = GetMethod(); + Header.Compressor = GetCompressor(); + Header.CompressionLevel = GetCompressionLevel(); + Header.BlockSizeExponent = static_cast<uint8_t>(zen::FloorLog2_64(BlockSize)); + Header.BlockCount = static_cast<uint32_t>(BlockCount); + Header.TotalRawSize = RawSize; + Header.TotalCompressedSize = sizeof(BufferHeader) + MetaSize + CompressedSize; + Header.RawHash = RawHash.GetHash(); + + HeaderBuffer.GetMutableView().Mid(sizeof(BufferHeader), MetaSize).CopyFrom(MakeMemoryView(CompressedBlockSizes)); + Header.Write(HeaderBuffer.GetMutableView()); + + Callback(0, 0, 0, CompositeBuffer(IoBuffer(IoBuffer::Wrap, HeaderBuffer.GetData(), HeaderBuffer.GetSize()))); + return true; +} + class BlockDecoder : public BaseDecoder { public: @@ -414,6 +687,14 @@ public: MutableMemoryView RawView, uint64_t RawOffset) const final; + virtual bool DecompressToStream( + const BufferHeader& Header, + const CompositeBuffer& CompressedData, + uint64_t RawOffset, + uint64_t RawSize, + std::function<bool(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback) + const final; + protected: virtual bool DecompressBlock(MutableMemoryView RawData, MemoryView CompressedData) const = 0; }; @@ -536,6 +817,168 @@ BlockDecoder::DecompressToComposite(const BufferHeader& Header, const CompositeB } bool +BlockDecoder::DecompressToStream( + const BufferHeader& Header, + const CompositeBuffer& CompressedData, + uint64_t RawOffset, + uint64_t RawSize, + std::function<bool(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback) const +{ + if (Header.TotalCompressedSize != CompressedData.GetSize()) + { + return false; + } + + const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent; + + UniqueBuffer BlockSizeBuffer; + MemoryView BlockSizeView = CompressedData.ViewOrCopyRange(sizeof(BufferHeader), Header.BlockCount * sizeof(uint32_t), BlockSizeBuffer); + std::span<uint32_t const> CompressedBlockSizes(reinterpret_cast<const uint32_t*>(BlockSizeView.GetData()), Header.BlockCount); + + UniqueBuffer CompressedBlockCopy; + + const size_t FirstBlockIndex = uint64_t(RawOffset / BlockSize); + const size_t LastBlockIndex = uint64_t((RawOffset + RawSize - 1) / BlockSize); + const uint64_t LastBlockSize = BlockSize - ((Header.BlockCount * BlockSize) - Header.TotalRawSize); + uint64_t OffsetInFirstBlock = RawOffset % BlockSize; + uint64_t CompressedOffset = sizeof(BufferHeader) + uint64_t(Header.BlockCount) * sizeof(uint32_t); + uint64_t RemainingRawSize = RawSize; + + for (size_t BlockIndex = 0; BlockIndex < FirstBlockIndex; BlockIndex++) + { + const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); + CompressedOffset += CompressedBlockSize; + } + + UniqueBuffer RawDataBuffer; + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if ((CompressedData.GetSegments().size() == 1) && CompressedData.GetSegments()[0].AsIoBuffer().GetFileReference(FileRef)) + { + ZEN_ASSERT(FileRef.FileHandle != nullptr); + BasicFile Source; + Source.Attach(FileRef.FileHandle); + + for (size_t BlockIndex = FirstBlockIndex; BlockIndex <= LastBlockIndex; BlockIndex++) + { + const uint64_t UncompressedBlockSize = BlockIndex == Header.BlockCount - 1 ? LastBlockSize : BlockSize; + const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); + const bool IsCompressed = CompressedBlockSize < UncompressedBlockSize; + + const uint64_t BytesToUncompress = OffsetInFirstBlock > 0 ? zen::Min(RawSize, UncompressedBlockSize - OffsetInFirstBlock) + : zen::Min(RemainingRawSize, BlockSize); + + if (CompressedBlockCopy.GetSize() < CompressedBlockSize) + { + CompressedBlockCopy = UniqueBuffer::Alloc(CompressedBlockSize); + } + Source.Read(CompressedBlockCopy.GetData(), CompressedBlockSize, FileRef.FileChunkOffset + CompressedOffset); + + MemoryView CompressedBlock = CompressedBlockCopy.GetView().Left(CompressedBlockSize); + + if (IsCompressed) + { + if (RawDataBuffer.IsNull()) + { + RawDataBuffer = UniqueBuffer::Alloc(zen::Min(RawSize, UncompressedBlockSize)); + } + else + { + ZEN_ASSERT(RawDataBuffer.GetSize() >= UncompressedBlockSize); + } + MutableMemoryView UncompressedBlock = RawDataBuffer.GetMutableView().Left(UncompressedBlockSize); + if (!DecompressBlock(UncompressedBlock, CompressedBlock)) + { + Source.Detach(); + return false; + } + if (!Callback(FileRef.FileChunkOffset + CompressedOffset, + CompressedBlockSize, + BlockIndex * BlockSize + OffsetInFirstBlock, + CompositeBuffer(IoBuffer(IoBuffer::Wrap, RawDataBuffer.GetData(), BytesToUncompress)))) + { + Source.Detach(); + return false; + } + } + else + { + if (!Callback( + FileRef.FileChunkOffset + CompressedOffset, + BytesToUncompress, + BlockIndex * BlockSize + OffsetInFirstBlock, + CompositeBuffer( + IoBuffer(IoBuffer::Wrap, CompressedBlockCopy.GetView().Mid(OffsetInFirstBlock).GetData(), BytesToUncompress)))) + { + Source.Detach(); + return false; + } + } + + OffsetInFirstBlock = 0; + RemainingRawSize -= BytesToUncompress; + CompressedOffset += CompressedBlockSize; + } + Source.Detach(); + } + else + { + for (size_t BlockIndex = FirstBlockIndex; BlockIndex <= LastBlockIndex; BlockIndex++) + { + const uint64_t UncompressedBlockSize = BlockIndex == Header.BlockCount - 1 ? LastBlockSize : BlockSize; + const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); + const bool IsCompressed = CompressedBlockSize < UncompressedBlockSize; + + const uint64_t BytesToUncompress = OffsetInFirstBlock > 0 ? zen::Min(RawSize, UncompressedBlockSize - OffsetInFirstBlock) + : zen::Min(RemainingRawSize, BlockSize); + + MemoryView CompressedBlock = CompressedData.ViewOrCopyRange(CompressedOffset, CompressedBlockSize, CompressedBlockCopy); + + if (IsCompressed) + { + if (RawDataBuffer.IsNull()) + { + RawDataBuffer = UniqueBuffer::Alloc(zen::Min(RawSize, UncompressedBlockSize)); + } + else + { + ZEN_ASSERT(RawDataBuffer.GetSize() >= UncompressedBlockSize); + } + MutableMemoryView UncompressedBlock = RawDataBuffer.GetMutableView().Left(UncompressedBlockSize); + if (!DecompressBlock(UncompressedBlock, CompressedBlock)) + { + return false; + } + if (!Callback(CompressedOffset, + UncompressedBlockSize, + BlockIndex * BlockSize + OffsetInFirstBlock, + CompositeBuffer(IoBuffer(IoBuffer::Wrap, RawDataBuffer.GetData(), BytesToUncompress)))) + { + return false; + } + } + else + { + if (!Callback( + CompressedOffset, + BytesToUncompress, + BlockIndex * BlockSize + OffsetInFirstBlock, + CompositeBuffer( + IoBuffer(IoBuffer::Wrap, CompressedBlockCopy.GetView().Mid(OffsetInFirstBlock).GetData(), BytesToUncompress)))) + { + return false; + } + } + + OffsetInFirstBlock = 0; + RemainingRawSize -= BytesToUncompress; + CompressedOffset += CompressedBlockSize; + } + } + return true; +} + +bool BlockDecoder::TryDecompressTo(const BufferHeader& Header, const CompositeBuffer& CompressedData, MutableMemoryView RawView, @@ -568,51 +1011,118 @@ BlockDecoder::TryDecompressTo(const BufferHeader& Header, CompressedOffset += CompressedBlockSize; } - for (size_t BlockIndex = FirstBlockIndex; BlockIndex <= LastBlockIndex; BlockIndex++) + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if ((CompressedData.GetSegments().size() == 1) && CompressedData.GetSegments()[0].AsIoBuffer().GetFileReference(FileRef)) { - const uint64_t UncompressedBlockSize = BlockIndex == Header.BlockCount - 1 ? LastBlockSize : BlockSize; - const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); - const bool IsCompressed = CompressedBlockSize < UncompressedBlockSize; + ZEN_ASSERT(FileRef.FileHandle != nullptr); + BasicFile Source; + Source.Attach(FileRef.FileHandle); - const uint64_t BytesToUncompress = OffsetInFirstBlock > 0 ? zen::Min(RawView.GetSize(), UncompressedBlockSize - OffsetInFirstBlock) - : zen::Min(RemainingRawSize, BlockSize); + for (size_t BlockIndex = FirstBlockIndex; BlockIndex <= LastBlockIndex; BlockIndex++) + { + const uint64_t UncompressedBlockSize = BlockIndex == Header.BlockCount - 1 ? LastBlockSize : BlockSize; + const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); + const bool IsCompressed = CompressedBlockSize < UncompressedBlockSize; - MemoryView CompressedBlock = CompressedData.ViewOrCopyRange(CompressedOffset, CompressedBlockSize, CompressedBlockCopy); + const uint64_t BytesToUncompress = OffsetInFirstBlock > 0 + ? zen::Min(RawView.GetSize(), UncompressedBlockSize - OffsetInFirstBlock) + : zen::Min(RemainingRawSize, BlockSize); - if (IsCompressed) - { - MutableMemoryView UncompressedBlock = RawView.Left(BytesToUncompress); + if (CompressedBlockCopy.GetSize() < CompressedBlockSize) + { + CompressedBlockCopy = UniqueBuffer::Alloc(CompressedBlockSize); + } + Source.Read(CompressedBlockCopy.GetData(), CompressedBlockSize, FileRef.FileChunkOffset + CompressedOffset); - const bool IsAligned = BytesToUncompress == UncompressedBlockSize; - if (!IsAligned) + MemoryView CompressedBlock = CompressedBlockCopy.GetView().Left(CompressedBlockSize); + + if (IsCompressed) { - // Decompress to a temporary buffer when the first or the last block reads are not aligned with the block boundaries. - if (UncompressedBlockCopy.IsNull()) + MutableMemoryView UncompressedBlock = RawView.Left(BytesToUncompress); + + const bool IsAligned = BytesToUncompress == UncompressedBlockSize; + if (!IsAligned) { - UncompressedBlockCopy = UniqueBuffer::Alloc(BlockSize); + // Decompress to a temporary buffer when the first or the last block reads are not aligned with the block boundaries. + if (UncompressedBlockCopy.IsNull()) + { + UncompressedBlockCopy = UniqueBuffer::Alloc(BlockSize); + } + UncompressedBlock = UncompressedBlockCopy.GetMutableView().Mid(0, UncompressedBlockSize); } - UncompressedBlock = UncompressedBlockCopy.GetMutableView().Mid(0, UncompressedBlockSize); - } - if (!DecompressBlock(UncompressedBlock, CompressedBlock)) - { - return false; - } + if (!DecompressBlock(UncompressedBlock, CompressedBlock)) + { + Source.Detach(); + return false; + } - if (!IsAligned) + if (!IsAligned) + { + RawView.CopyFrom(UncompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress)); + } + } + else { - RawView.CopyFrom(UncompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress)); + RawView.CopyFrom(CompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress)); } + + OffsetInFirstBlock = 0; + RemainingRawSize -= BytesToUncompress; + CompressedOffset += CompressedBlockSize; + RawView += BytesToUncompress; } - else + Source.Detach(); + } + else + { + for (size_t BlockIndex = FirstBlockIndex; BlockIndex <= LastBlockIndex; BlockIndex++) { - RawView.CopyFrom(CompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress)); - } + const uint64_t UncompressedBlockSize = BlockIndex == Header.BlockCount - 1 ? LastBlockSize : BlockSize; + const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); + const bool IsCompressed = CompressedBlockSize < UncompressedBlockSize; - OffsetInFirstBlock = 0; - RemainingRawSize -= BytesToUncompress; - CompressedOffset += CompressedBlockSize; - RawView += BytesToUncompress; + const uint64_t BytesToUncompress = OffsetInFirstBlock > 0 + ? zen::Min(RawView.GetSize(), UncompressedBlockSize - OffsetInFirstBlock) + : zen::Min(RemainingRawSize, BlockSize); + + MemoryView CompressedBlock = CompressedData.ViewOrCopyRange(CompressedOffset, CompressedBlockSize, CompressedBlockCopy); + + if (IsCompressed) + { + MutableMemoryView UncompressedBlock = RawView.Left(BytesToUncompress); + + const bool IsAligned = BytesToUncompress == UncompressedBlockSize; + if (!IsAligned) + { + // Decompress to a temporary buffer when the first or the last block reads are not aligned with the block boundaries. + if (UncompressedBlockCopy.IsNull()) + { + UncompressedBlockCopy = UniqueBuffer::Alloc(BlockSize); + } + UncompressedBlock = UncompressedBlockCopy.GetMutableView().Mid(0, UncompressedBlockSize); + } + + if (!DecompressBlock(UncompressedBlock, CompressedBlock)) + { + return false; + } + + if (!IsAligned) + { + RawView.CopyFrom(UncompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress)); + } + } + else + { + RawView.CopyFrom(CompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress)); + } + + OffsetInFirstBlock = 0; + RemainingRawSize -= BytesToUncompress; + CompressedOffset += CompressedBlockSize; + RawView += BytesToUncompress; + } } return RemainingRawSize == 0; @@ -713,10 +1223,13 @@ BlockDecoder::DecompressToComposite(DecoderContext& Context, const uint64_t RawOffset, const uint64_t RawSize) const { - UniqueBuffer Buffer = UniqueBuffer::Alloc(RawSize); - if (TryDecompressTo(Context, Source, Header, HeaderView, RawOffset, Buffer)) + if (RawSize > 0) { - return CompositeBuffer(Buffer.MoveToShared()); + UniqueBuffer Buffer = UniqueBuffer::Alloc(RawSize); + if (TryDecompressTo(Context, Source, Header, HeaderView, RawOffset, Buffer)) + { + return CompositeBuffer(Buffer.MoveToShared()); + } } return CompositeBuffer(); } @@ -871,7 +1384,7 @@ GetDecoder(CompressionMethod Method) ////////////////////////////////////////////////////////////////////////// bool -BufferHeader::IsValid(const CompositeBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +ReadHeader(const CompositeBuffer& CompressedData, BufferHeader& OutHeader, UniqueBuffer* OutHeaderData) { const uint64_t CompressedDataSize = CompressedData.GetSize(); if (CompressedDataSize < sizeof(BufferHeader)) @@ -879,61 +1392,89 @@ BufferHeader::IsValid(const CompositeBuffer& CompressedData, IoHash& OutRawHash, return false; } - const size_t StackBufferSize = 256; - uint8_t StackBuffer[StackBufferSize]; - uint64_t ReadSize = Min(CompressedDataSize, StackBufferSize); - BufferHeader* Header = reinterpret_cast<BufferHeader*>(StackBuffer); + const size_t HeaderBufferSize = 1024; + uint8_t HeaderBuffer[HeaderBufferSize]; + uint64_t ReadSize = Min(CompressedDataSize, HeaderBufferSize); + uint64_t FirstSegmentSize = CompressedData.GetSegments()[0].GetSize(); + if (FirstSegmentSize >= sizeof(BufferHeader)) { - CompositeBuffer::Iterator It; - CompressedData.CopyTo(MutableMemoryView(StackBuffer, StackBuffer + StackBufferSize), It); + // Keep first read inside first segment if possible + ReadSize = Min(ReadSize, FirstSegmentSize); } - Header->ByteSwap(); - if (Header->Magic != BufferHeader::ExpectedMagic) + + MutableMemoryView HeaderMemory(HeaderBuffer, &HeaderBuffer[ReadSize]); + CompositeBuffer::Iterator It = CompressedData.GetIterator(0); + CompressedData.CopyTo(HeaderMemory, It); + + OutHeader = *reinterpret_cast<BufferHeader*>(HeaderMemory.GetData()); + OutHeader.ByteSwap(); + if (OutHeader.Magic != BufferHeader::ExpectedMagic) { return false; } - - const BaseDecoder* const Decoder = GetDecoder(Header->Method); + if (OutHeader.TotalCompressedSize > CompressedDataSize) + { + return false; + } + const BaseDecoder* const Decoder = GetDecoder(OutHeader.Method); if (!Decoder) { return false; } - - uint32_t Crc32 = Header->Crc32; - OutRawHash = IoHash::FromBLAKE3(Header->RawHash); - OutRawSize = Header->TotalRawSize; - uint64_t HeaderSize = Decoder->GetHeaderSize(*Header); - - if (Header->TotalCompressedSize > CompressedDataSize) + uint64_t FullHeaderSize = Decoder->GetHeaderSize(OutHeader); + if (FullHeaderSize > CompressedDataSize) { return false; } - - Header->ByteSwap(); - - if (HeaderSize > ReadSize) + if (OutHeaderData) + { + *OutHeaderData = UniqueBuffer::Alloc(FullHeaderSize); + MutableMemoryView RemainingHeaderView = OutHeaderData->GetMutableView().CopyFrom(HeaderMemory.Mid(0, FullHeaderSize)); + if (!RemainingHeaderView.IsEmpty()) + { + CompressedData.CopyTo(RemainingHeaderView, It); + } + if (OutHeader.Crc32 != BufferHeader::CalculateCrc32(OutHeaderData->GetView())) + { + return false; + } + } + else if (FullHeaderSize < ReadSize) { - UniqueBuffer HeaderCopy = UniqueBuffer::Alloc(HeaderSize); - CompositeBuffer::Iterator It; - CompressedData.CopyTo(HeaderCopy.GetMutableView(), It); - const MemoryView HeaderView = HeaderCopy.GetView(); - if (Crc32 != BufferHeader::CalculateCrc32(HeaderView)) + if (OutHeader.Crc32 != BufferHeader::CalculateCrc32(HeaderMemory.Mid(0, FullHeaderSize))) { return false; } } else { - MemoryView FullHeaderView(StackBuffer, StackBuffer + HeaderSize); - if (Crc32 != BufferHeader::CalculateCrc32(FullHeaderView)) + UniqueBuffer HeaderData = UniqueBuffer::Alloc(FullHeaderSize); + MutableMemoryView RemainingHeaderView = HeaderData.GetMutableView().CopyFrom(HeaderMemory.Mid(0, FullHeaderSize)); + if (!RemainingHeaderView.IsEmpty()) + { + CompressedData.CopyTo(RemainingHeaderView, It); + } + if (OutHeader.Crc32 != BufferHeader::CalculateCrc32(HeaderData.GetView())) { return false; } } - return true; } +bool +BufferHeader::IsValid(const CompositeBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +{ + detail::BufferHeader Header; + if (ReadHeader(CompressedData, Header, nullptr)) + { + OutRawHash = IoHash::FromBLAKE3(Header.RawHash); + OutRawSize = Header.TotalRawSize; + return true; + } + return false; +} + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// static bool @@ -1105,7 +1646,95 @@ ValidBufferOrEmpty(BufferType&& CompressedData, IoHash& OutRawHash, uint64_t& Ou } CompositeBuffer -CopyCompressedRange(const BufferHeader& Header, const CompositeBuffer& CompressedData, uint64_t RawOffset, uint64_t RawSize) +GetCompressedRange(const BufferHeader& Header, + MemoryView HeaderRawData, + const CompositeBuffer& CompressedData, + uint64_t RawOffset, + uint64_t RawSize) +{ + if (Header.TotalRawSize < RawOffset + RawSize) + { + return CompositeBuffer(); + } + if (Header.Method == CompressionMethod::None) + { + BufferHeader NewHeader = Header; + NewHeader.Crc32 = 0; + NewHeader.TotalRawSize = RawSize; + NewHeader.TotalCompressedSize = NewHeader.TotalRawSize + sizeof(BufferHeader); + NewHeader.RawHash = BLAKE3(); + + UniqueBuffer HeaderData = UniqueBuffer::Alloc(sizeof(BufferHeader)); + NewHeader.Write(HeaderData); + + return CompositeBuffer(HeaderData.MoveToShared(), CompressedData.Mid(sizeof(BufferHeader) + RawOffset, RawSize).MakeOwned()); + } + else + { + MemoryView BlockSizeView = HeaderRawData.Mid(sizeof(Header), Header.BlockCount * sizeof(uint32_t)); + std::span<uint32_t const> CompressedBlockSizes(reinterpret_cast<const uint32_t*>(BlockSizeView.GetData()), Header.BlockCount); + + const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent; + const uint64_t LastBlockSize = BlockSize - ((Header.BlockCount * BlockSize) - Header.TotalRawSize); + const size_t FirstBlock = uint64_t(RawOffset / BlockSize); + const size_t LastBlock = uint64_t((RawOffset + RawSize - 1) / BlockSize); + uint64_t CompressedOffset = sizeof(BufferHeader) + uint64_t(Header.BlockCount) * sizeof(uint32_t); + + const uint64_t NewBlockCount = LastBlock - FirstBlock + 1; + const uint64_t NewMetaSize = NewBlockCount * sizeof(uint32_t); + uint64_t NewCompressedSize = 0; + uint64_t NewTotalRawSize = 0; + std::vector<uint32_t> NewCompressedBlockSizes; + + NewCompressedBlockSizes.reserve(NewBlockCount); + for (size_t BlockIndex = FirstBlock; BlockIndex <= LastBlock; ++BlockIndex) + { + const uint64_t UncompressedBlockSize = (BlockIndex == Header.BlockCount - 1) ? LastBlockSize : BlockSize; + NewTotalRawSize += UncompressedBlockSize; + + const uint32_t CompressedBlockSize = CompressedBlockSizes[BlockIndex]; + NewCompressedBlockSizes.push_back(CompressedBlockSize); + NewCompressedSize += ByteSwap(CompressedBlockSize); + } + + const uint64_t NewTotalCompressedSize = sizeof(BufferHeader) + NewBlockCount * sizeof(uint32_t) + NewCompressedSize; + const uint64_t NewCompressedHeaderSize = sizeof(BufferHeader) + NewBlockCount * sizeof(uint32_t); + UniqueBuffer NewCompressedHeaderData = UniqueBuffer::Alloc(NewCompressedHeaderSize); + + // Seek to first compressed block + for (size_t BlockIndex = 0; BlockIndex < FirstBlock; ++BlockIndex) + { + const uint64_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); + CompressedOffset += CompressedBlockSize; + } + + CompositeBuffer NewCompressedData = CompressedData.Mid(CompressedOffset, NewCompressedSize).MakeOwned(); + + // Copy block sizes + NewCompressedHeaderData.GetMutableView().Mid(sizeof(BufferHeader), NewMetaSize).CopyFrom(MakeMemoryView(NewCompressedBlockSizes)); + + BufferHeader NewHeader; + NewHeader.Crc32 = 0; + NewHeader.Method = Header.Method; + NewHeader.Compressor = Header.Compressor; + NewHeader.CompressionLevel = Header.CompressionLevel; + NewHeader.BlockSizeExponent = Header.BlockSizeExponent; + NewHeader.BlockCount = static_cast<uint32_t>(NewBlockCount); + NewHeader.TotalRawSize = NewTotalRawSize; + NewHeader.TotalCompressedSize = NewTotalCompressedSize; + NewHeader.RawHash = BLAKE3(); + NewHeader.Write(NewCompressedHeaderData.GetMutableView().Left(sizeof(BufferHeader) + NewMetaSize)); + + return CompositeBuffer(NewCompressedHeaderData.MoveToShared(), NewCompressedData); + } +} + +CompositeBuffer +CopyCompressedRange(const BufferHeader& Header, + MemoryView HeaderRawData, + const CompositeBuffer& CompressedData, + uint64_t RawOffset, + uint64_t RawSize) { if (Header.TotalRawSize < RawOffset + RawSize) { @@ -1130,9 +1759,7 @@ CopyCompressedRange(const BufferHeader& Header, const CompositeBuffer& Compresse } else { - UniqueBuffer BlockSizeBuffer; - MemoryView BlockSizeView = - CompressedData.ViewOrCopyRange(sizeof(BufferHeader), Header.BlockCount * sizeof(uint32_t), BlockSizeBuffer); + MemoryView BlockSizeView = HeaderRawData.Mid(sizeof(Header), Header.BlockCount * sizeof(uint32_t)); std::span<uint32_t const> CompressedBlockSizes(reinterpret_cast<const uint32_t*>(BlockSizeView.GetData()), Header.BlockCount); const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent; @@ -1233,6 +1860,31 @@ CompressedBuffer::Compress(const SharedBuffer& RawData, return Compress(CompositeBuffer(RawData), Compressor, CompressionLevel, BlockSize); } +bool +CompressedBuffer::CompressToStream( + const CompositeBuffer& RawData, + std::function<void(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback, + OodleCompressor Compressor, + OodleCompressionLevel CompressionLevel, + uint64_t BlockSize) +{ + using namespace detail; + + if (BlockSize == 0) + { + BlockSize = DefaultBlockSize; + } + + if (CompressionLevel == OodleCompressionLevel::None) + { + return NoneEncoder().CompressToStream(RawData, std::move(Callback), BlockSize); + } + else + { + return OodleEncoder(Compressor, CompressionLevel).CompressToStream(RawData, std::move(Callback), BlockSize); + } +} + CompressedBuffer CompressedBuffer::FromCompressed(const CompositeBuffer& InCompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) { @@ -1301,6 +1953,26 @@ CompressedBuffer::ValidateCompressedHeader(const IoBuffer& CompressedData, IoHas return detail::BufferHeader::IsValid(SharedBuffer(CompressedData), OutRawHash, OutRawSize); } +size_t +CompressedBuffer::GetHeaderSizeForNoneEncoder() +{ + return sizeof(detail::BufferHeader); +} + +UniqueBuffer +CompressedBuffer::CreateHeaderForNoneEncoder(uint64_t RawSize, const BLAKE3& RawHash) +{ + detail::BufferHeader Header; + Header.Method = detail::CompressionMethod::None; + Header.BlockCount = 1; + Header.TotalRawSize = RawSize; + Header.TotalCompressedSize = Header.TotalRawSize + sizeof(detail::BufferHeader); + Header.RawHash = RawHash; + UniqueBuffer HeaderData = UniqueBuffer::Alloc(sizeof(detail::BufferHeader)); + Header.Write(HeaderData); + return HeaderData; +} + uint64_t CompressedBuffer::DecodeRawSize() const { @@ -1316,13 +1988,34 @@ CompressedBuffer::DecodeRawHash() const CompressedBuffer CompressedBuffer::CopyRange(uint64_t RawOffset, uint64_t RawSize) const { - using namespace detail; - const BufferHeader Header = BufferHeader::Read(CompressedData); - const uint64_t TotalRawSize = RawSize < ~uint64_t(0) ? RawSize : Header.TotalRawSize - RawOffset; - CompressedBuffer Range; - Range.CompressedData = CopyCompressedRange(Header, CompressedData, RawOffset, TotalRawSize); + if (RawSize > 0) + { + detail::BufferHeader Header; + UniqueBuffer RawHeaderData; + if (ReadHeader(CompressedData, Header, &RawHeaderData)) + { + const uint64_t TotalRawSize = RawSize < ~uint64_t(0) ? RawSize : Header.TotalRawSize - RawOffset; + Range.CompressedData = CopyCompressedRange(Header, RawHeaderData.GetView(), CompressedData, RawOffset, TotalRawSize); + } + } + return Range; +} +CompressedBuffer +CompressedBuffer::GetRange(uint64_t RawOffset, uint64_t RawSize) const +{ + CompressedBuffer Range; + if (RawSize > 0) + { + detail::BufferHeader Header; + UniqueBuffer RawHeaderData; + if (ReadHeader(CompressedData, Header, &RawHeaderData)) + { + const uint64_t TotalRawSize = RawSize < ~uint64_t(0) ? RawSize : Header.TotalRawSize - RawOffset; + Range.CompressedData = GetCompressedRange(Header, RawHeaderData.GetView(), CompressedData, RawOffset, TotalRawSize); + } + } return Range; } @@ -1330,7 +2023,7 @@ bool CompressedBuffer::TryDecompressTo(MutableMemoryView RawView, uint64_t RawOffset) const { using namespace detail; - if (CompressedData) + if (CompressedData && RawView.GetSize() > 0) { const BufferHeader Header = BufferHeader::Read(CompressedData); if (Header.Magic == BufferHeader::ExpectedMagic) @@ -1386,6 +2079,28 @@ CompressedBuffer::DecompressToComposite() const } bool +CompressedBuffer::DecompressToStream( + uint64_t RawOffset, + uint64_t RawSize, + std::function<bool(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback) const +{ + using namespace detail; + if (CompressedData) + { + const BufferHeader Header = BufferHeader::Read(CompressedData); + if (Header.Magic == BufferHeader::ExpectedMagic) + { + if (const BaseDecoder* const Decoder = GetDecoder(Header.Method)) + { + const uint64_t TotalRawSize = RawSize < ~uint64_t(0) ? RawSize : Header.TotalRawSize - RawOffset; + return Decoder->DecompressToStream(Header, CompressedData, RawOffset, TotalRawSize, std::move(Callback)); + } + } + } + return false; +} + +bool CompressedBuffer::TryGetCompressParameters(OodleCompressor& OutCompressor, OodleCompressionLevel& OutCompressionLevel, uint64_t& OutBlockSize) const @@ -1554,28 +2269,31 @@ CompressedBufferReader::TryDecompressTo(const MutableMemoryView RawView, const u SharedBuffer CompressedBufferReader::Decompress(const uint64_t RawOffset, const uint64_t RawSize) { - using namespace detail; - BufferHeader Header; - MemoryView HeaderView; - if (TryReadHeader(Header, HeaderView)) + if (RawSize > 0) { - const uint64_t TotalRawSize = Header.TotalRawSize; - const uint64_t RawSizeToCopy = RawSize == MAX_uint64 ? TotalRawSize - RawOffset : RawSize; - if (RawOffset <= TotalRawSize && RawSizeToCopy <= TotalRawSize - RawOffset) + using namespace detail; + BufferHeader Header; + MemoryView HeaderView; + if (TryReadHeader(Header, HeaderView)) { - if (const BaseDecoder* const Decoder = GetDecoder(Header.Method)) + const uint64_t TotalRawSize = Header.TotalRawSize; + const uint64_t RawSizeToCopy = RawSize == MAX_uint64 ? TotalRawSize - RawOffset : RawSize; + if (RawOffset <= TotalRawSize && RawSizeToCopy <= TotalRawSize - RawOffset) { - UniqueBuffer RawData = UniqueBuffer::Alloc(RawSizeToCopy); - if (Decoder->TryDecompressTo( - Context, - SourceArchive ? static_cast<const DecoderSource&>(ArchiveDecoderSource(*SourceArchive, Context.HeaderOffset)) - : static_cast<const DecoderSource&>(BufferDecoderSource(SourceBuffer->GetCompressed())), - Header, - HeaderView, - RawOffset, - RawData)) + if (const BaseDecoder* const Decoder = GetDecoder(Header.Method)) { - return RawData.MoveToShared(); + UniqueBuffer RawData = UniqueBuffer::Alloc(RawSizeToCopy); + if (Decoder->TryDecompressTo( + Context, + SourceArchive ? static_cast<const DecoderSource&>(ArchiveDecoderSource(*SourceArchive, Context.HeaderOffset)) + : static_cast<const DecoderSource&>(BufferDecoderSource(SourceBuffer->GetCompressed())), + Header, + HeaderView, + RawOffset, + RawData)) + { + return RawData.MoveToShared(); + } } } } @@ -1908,6 +2626,66 @@ TEST_CASE("CompressedBuffer") } } + SUBCASE("get range") + { + const uint64_t BlockSize = 64 * sizeof(uint64_t); + const uint64_t N = 1000; + std::vector<uint64_t> ExpectedValues = GenerateData(N); + + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)), + OodleCompressor::Mermaid, + OodleCompressionLevel::Optimal4, + BlockSize); + + { + const uint64_t OffsetCount = 0; + const uint64_t Count = N; + SharedBuffer Uncompressed = Compressed.GetRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 64; + const uint64_t Count = N - 64; + SharedBuffer Uncompressed = Compressed.GetRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 64 * 2 + 32; + const uint64_t Count = N - OffsetCount; + const uint64_t RawOffset = OffsetCount * sizeof(uint64_t); + const uint64_t RawSize = Count * sizeof(uint64_t); + uint64_t FirstBlockOffset = RawOffset % BlockSize; + + SharedBuffer Uncompressed = Compressed.GetRange(RawOffset, RawSize).Decompress(); + std::span<uint64_t const> AllValues((const uint64_t*)Uncompressed.GetData(), RawSize / sizeof(uint64_t)); + std::span<uint64_t const> Values((const uint64_t*)(((const uint8_t*)(Uncompressed.GetData()) + FirstBlockOffset)), + RawSize / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 64 * 2 + 63; + const uint64_t Count = N - OffsetCount - 5; + const uint64_t RawOffset = OffsetCount * sizeof(uint64_t); + const uint64_t RawSize = Count * sizeof(uint64_t); + uint64_t FirstBlockOffset = RawOffset % BlockSize; + + SharedBuffer Uncompressed = Compressed.GetRange(RawOffset, RawSize).Decompress(); + std::span<uint64_t const> AllValues((const uint64_t*)Uncompressed.GetData(), RawSize / sizeof(uint64_t)); + std::span<uint64_t const> Values((const uint64_t*)(((const uint8_t*)(Uncompressed.GetData()) + FirstBlockOffset)), + RawSize / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + } + SUBCASE("copy uncompressed range") { const uint64_t N = 1000; @@ -1944,6 +2722,43 @@ TEST_CASE("CompressedBuffer") ValidateData(Values, ExpectedValues, OffsetCount); } } + + SUBCASE("get uncompressed range") + { + const uint64_t N = 1000; + std::vector<uint64_t> ExpectedValues = GenerateData(N); + + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)), + OodleCompressor::NotSet, + OodleCompressionLevel::None); + + { + const uint64_t OffsetCount = 0; + const uint64_t Count = N; + SharedBuffer Uncompressed = Compressed.GetRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 1; + const uint64_t Count = N - OffsetCount; + SharedBuffer Uncompressed = Compressed.GetRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 42; + const uint64_t Count = 100; + SharedBuffer Uncompressed = Compressed.GetRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + } } TEST_CASE("CompressedBufferReader") diff --git a/src/zencore/crypto.cpp b/src/zencore/crypto.cpp index 8403a35f4..78bea0c17 100644 --- a/src/zencore/crypto.cpp +++ b/src/zencore/crypto.cpp @@ -2,6 +2,7 @@ #include <zencore/crypto.h> #include <zencore/intmath.h> +#include <zencore/memory/memory.h> #include <zencore/scopeguard.h> #include <zencore/testing.h> diff --git a/src/zencore/except.cpp b/src/zencore/except.cpp index d5eabea9d..610b0ced5 100644 --- a/src/zencore/except.cpp +++ b/src/zencore/except.cpp @@ -47,7 +47,7 @@ ThrowSystemException([[maybe_unused]] HRESULT hRes, [[maybe_unused]] std::string { if (HRESULT_FACILITY(hRes) == FACILITY_WIN32) { - throw std::system_error(std::error_code(hRes & 0xffff, std::system_category()), std::string(Message)); + throw std::system_error(std::error_code(HRESULT_CODE(hRes), std::system_category()), std::string(Message)); } else { diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 29ec14e0c..d18f21dbe 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -7,13 +7,19 @@ #include <zencore/fmtutils.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> +#include <zencore/memory/memory.h> #include <zencore/process.h> +#include <zencore/scopeguard.h> #include <zencore/stream.h> #include <zencore/string.h> #include <zencore/testing.h> +#include <zencore/workthreadpool.h> #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> +# include <ShlObj.h> +# pragma comment(lib, "shell32.lib") +# pragma comment(lib, "ole32.lib") #endif #if ZEN_PLATFORM_WINDOWS @@ -27,7 +33,9 @@ ZEN_THIRD_PARTY_INCLUDES_END # include <dirent.h> # include <fcntl.h> # include <sys/resource.h> +# include <sys/mman.h> # include <sys/stat.h> +# include <pwd.h> # include <unistd.h> #endif @@ -36,8 +44,10 @@ ZEN_THIRD_PARTY_INCLUDES_END # include <fcntl.h> # include <libproc.h> # include <sys/resource.h> +# include <sys/mman.h> # include <sys/stat.h> # include <sys/syslimits.h> +# include <pwd.h> # include <unistd.h> #endif @@ -78,21 +88,16 @@ DeleteReparsePoint(const wchar_t* Path, DWORD dwReparseTag) } bool -CreateDirectories(const wchar_t* Dir) +CreateDirectories(const wchar_t* Path) { - // This may be suboptimal, in that it appears to try and create directories - // from the root on up instead of from some directory which is known to - // be present - // - // We should implement a smarter version at some point since this can be - // pretty expensive in aggregate - - return std::filesystem::create_directories(Dir); + return CreateDirectories(std::filesystem::path(Path)); } // Erase all files and directories in a given directory, leaving an empty directory // behind +bool DeleteDirectoriesInternal(const wchar_t* DirPath); + static bool WipeDirectory(const wchar_t* DirPath, bool KeepDotFiles) { @@ -159,7 +164,7 @@ WipeDirectory(const wchar_t* DirPath, bool KeepDotFiles) } } - bool Succeeded = DeleteDirectories(Path.c_str()); + bool Succeeded = DeleteDirectoriesInternal(Path.c_str()); if (!Succeeded) { @@ -193,104 +198,386 @@ WipeDirectory(const wchar_t* DirPath, bool KeepDotFiles) } bool -DeleteDirectories(const wchar_t* DirPath) +DeleteDirectoriesInternal(const wchar_t* DirPath) { const bool KeepDotFiles = false; return WipeDirectory(DirPath, KeepDotFiles) && RemoveDirectoryW(DirPath) == TRUE; } bool -CleanDirectory(const wchar_t* DirPath) +CleanDirectory(const wchar_t* DirPath, bool KeepDotFiles) { - if (std::filesystem::exists(DirPath)) + if (IsDir(DirPath)) { - const bool KeepDotFiles = false; - return WipeDirectory(DirPath, KeepDotFiles); } return CreateDirectories(DirPath); } +#endif // ZEN_PLATFORM_WINDOWS + +#if ZEN_PLATFORM_WINDOWS +const uint32_t FileAttributesSystemReadOnlyFlag = FILE_ATTRIBUTE_READONLY; +#else +const uint32_t FileAttributesSystemReadOnlyFlag = 0x00000001; +#endif // ZEN_PLATFORM_WINDOWS + +const uint32_t FileModeWriteEnableFlags = 0222; + bool -CleanDirectory(const wchar_t* DirPath, bool KeepDotFiles) +IsFileAttributeReadOnly(uint32_t FileAttributes) +{ +#if ZEN_PLATFORM_WINDOWS + return (FileAttributes & FileAttributesSystemReadOnlyFlag) != 0; +#else + return (FileAttributes & 0x00000001) != 0; +#endif // ZEN_PLATFORM_WINDOWS +} + +bool +IsFileModeReadOnly(uint32_t FileMode) +{ + return (FileMode & FileModeWriteEnableFlags) == 0; +} + +uint32_t +MakeFileAttributeReadOnly(uint32_t FileAttributes, bool ReadOnly) +{ + return ReadOnly ? (FileAttributes | FileAttributesSystemReadOnlyFlag) : (FileAttributes & ~FileAttributesSystemReadOnlyFlag); +} + +uint32_t +MakeFileModeReadOnly(uint32_t FileMode, bool ReadOnly) { - if (std::filesystem::exists(DirPath)) + return ReadOnly ? (FileMode & ~FileModeWriteEnableFlags) : (FileMode | FileModeWriteEnableFlags); +} + +#if ZEN_PLATFORM_WINDOWS + +static DWORD +WinGetFileAttributes(const std::filesystem::path& Path, std::error_code& Ec) +{ + DWORD Attributes = ::GetFileAttributes(Path.native().c_str()); + if (Attributes == INVALID_FILE_ATTRIBUTES) { - return WipeDirectory(DirPath, KeepDotFiles); + DWORD LastError = GetLastError(); + switch (LastError) + { + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + case ERROR_BAD_NETPATH: + case ERROR_INVALID_DRIVE: + break; + case ERROR_ACCESS_DENIED: + { + WIN32_FIND_DATA FindData; + HANDLE FindHandle = ::FindFirstFile(Path.native().c_str(), &FindData); + if (FindHandle == INVALID_HANDLE_VALUE) + { + DWORD LastFindError = GetLastError(); + if (LastFindError != ERROR_FILE_NOT_FOUND) + { + Ec = MakeErrorCode(LastError); + } + } + else + { + FindClose(FindHandle); + Attributes = FindData.dwFileAttributes; + } + } + break; + default: + Ec = MakeErrorCode(LastError); + break; + } } - - return CreateDirectories(DirPath); + return Attributes; } #endif // ZEN_PLATFORM_WINDOWS bool -CreateDirectories(const std::filesystem::path& Dir) +RemoveDirNative(const std::filesystem::path& Path, std::error_code& Ec) { - if (Dir.string().ends_with(":")) +#if ZEN_PLATFORM_WINDOWS + BOOL Success = ::RemoveDirectory(Path.native().c_str()); + if (!Success) { + DWORD LastError = GetLastError(); + switch (LastError) + { + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + break; + default: + Ec = MakeErrorCode(LastError); + break; + } return false; } - while (!std::filesystem::is_directory(Dir)) + return true; +#else + return std::filesystem::remove(Path, Ec); +#endif // ZEN_PLATFORM_WINDOWS +} + +bool +RemoveFileNative(const std::filesystem::path& Path, bool ForceRemoveReadOnlyFiles, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + const std::filesystem::path::value_type* NativePath = Path.native().c_str(); + BOOL Success = ::DeleteFile(NativePath); + if (!Success) + { + if (ForceRemoveReadOnlyFiles) + { + DWORD FileAttributes = WinGetFileAttributes(NativePath, Ec); + if (Ec) + { + return false; + } + + if ((FileAttributes != INVALID_FILE_ATTRIBUTES) && IsFileAttributeReadOnly(FileAttributes) != 0) + { + ::SetFileAttributes(NativePath, MakeFileAttributeReadOnly(FileAttributes, false)); + Success = ::DeleteFile(NativePath); + } + } + if (!Success) + { + DWORD LastError = GetLastError(); + switch (LastError) + { + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + break; + default: + Ec = MakeErrorCode(LastError); + break; + } + return false; + } + } + return true; +#else + if (!ForceRemoveReadOnlyFiles) { - if (Dir.has_parent_path()) + struct stat Stat; + int err = stat(Path.native().c_str(), &Stat); + if (err != 0) { - CreateDirectories(Dir.parent_path()); + int32_t err = errno; + if (err == ENOENT) + { + Ec.clear(); + return false; + } } - std::error_code ErrorCode; - std::filesystem::create_directory(Dir, ErrorCode); - if (ErrorCode) + const uint32_t Mode = (uint32_t)Stat.st_mode; + if (IsFileModeReadOnly(Mode)) { - throw std::system_error(ErrorCode, fmt::format("Failed to create directories for '{}'", Dir.string())); + Ec = MakeErrorCode(EACCES); + return false; + } + } + return std::filesystem::remove(Path, Ec); +#endif // ZEN_PLATFORM_WINDOWS +} + +static void +WipeDirectoryContentInternal(const std::filesystem::path& Path, bool ForceRemoveReadOnlyFiles, std::error_code& Ec) +{ + DirectoryContent LocalDirectoryContent; + GetDirectoryContent(Path, DirectoryContentFlags::IncludeDirs | DirectoryContentFlags::IncludeFiles, LocalDirectoryContent); + for (const std::filesystem::path& LocalFilePath : LocalDirectoryContent.Files) + { + RemoveFileNative(LocalFilePath, ForceRemoveReadOnlyFiles, Ec); + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + Sleep(100 + int(Retries * 50)); + Ec.clear(); + if (IsFile(LocalFilePath)) + { + RemoveFileNative(LocalFilePath, ForceRemoveReadOnlyFiles, Ec); + } + } + if (Ec) + { + return; + } + } + + for (std::filesystem::path& LocalDirPath : LocalDirectoryContent.Directories) + { + WipeDirectoryContentInternal(LocalDirPath, ForceRemoveReadOnlyFiles, Ec); + if (Ec) + { + return; + } + + RemoveDirNative(LocalDirPath, Ec); + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + Sleep(100 + int(Retries * 50)); + Ec.clear(); + if (IsDir(LocalDirPath)) + { + RemoveDirNative(LocalDirPath, Ec); + } + } + if (Ec) + { + return; } - return true; } - return false; } bool -DeleteDirectories(const std::filesystem::path& Dir) +CreateDirectory(const std::filesystem::path& Path, std::error_code& Ec) { #if ZEN_PLATFORM_WINDOWS - return DeleteDirectories(Dir.c_str()); + BOOL Success = ::CreateDirectory(Path.native().c_str(), nullptr); + if (!Success) + { + DWORD LastError = GetLastError(); + switch (LastError) + { + case ERROR_FILE_EXISTS: + case ERROR_ALREADY_EXISTS: + break; + default: + Ec = MakeErrorCode(LastError); + break; + } + return false; + } + return Success; #else - std::error_code ErrorCode; - return std::filesystem::remove_all(Dir, ErrorCode); -#endif + return std::filesystem::create_directory(Path, Ec); +#endif // ZEN_PLATFORM_WINDOWS } bool -CleanDirectory(const std::filesystem::path& Dir) +CreateDirectories(const std::filesystem::path& Path) { -#if ZEN_PLATFORM_WINDOWS - return CleanDirectory(Dir.c_str()); -#else - if (std::filesystem::exists(Dir)) + std::error_code Ec; + bool Success = CreateDirectories(Path, Ec); + if (Ec) { - bool Success = true; + throw std::system_error(Ec, fmt::format("Failed to create directories for '{}'", Path.string())); + } + return Success; +} - std::error_code ErrorCode; - for (const auto& Item : std::filesystem::directory_iterator(Dir)) +bool +CreateDirectories(const std::filesystem::path& Path, std::error_code& Ec) +{ + if (Path.string().ends_with(":")) + { + return false; + } + bool Exists = IsDir(Path, Ec); + if (Ec) + { + return false; + } + if (Exists) + { + return false; + } + + if (Path.has_parent_path()) + { + bool Result = CreateDirectories(Path.parent_path(), Ec); + if (Ec) { - Success &= std::filesystem::remove_all(Item, ErrorCode); + return Result; } + } + return CreateDirectory(Path, Ec); +} - return Success; +bool +CleanDirectory(const std::filesystem::path& Path, bool ForceRemoveReadOnlyFiles) +{ + std::error_code Ec; + bool Result = CleanDirectory(Path, ForceRemoveReadOnlyFiles, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to clean directory for '{}'", Path.string())); } + return Result; +} - return CreateDirectories(Dir); -#endif +bool +CleanDirectory(const std::filesystem::path& Path, bool ForceRemoveReadOnlyFiles, std::error_code& Ec) +{ + bool Exists = IsDir(Path, Ec); + if (Ec) + { + return Exists; + } + if (Exists) + { + WipeDirectoryContentInternal(Path, ForceRemoveReadOnlyFiles, Ec); + return false; + } + return CreateDirectory(Path, Ec); +} + +bool +DeleteDirectories(const std::filesystem::path& Path) +{ + std::error_code Ec; + bool Result = DeleteDirectories(Path, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to delete directories for '{}'", Path.string())); + } + return Result; +} + +bool +DeleteDirectories(const std::filesystem::path& Path, std::error_code& Ec) +{ + bool Exists = IsDir(Path, Ec); + if (Ec) + { + return Exists; + } + + if (Exists) + { + WipeDirectoryContentInternal(Path, false, Ec); + if (Ec) + { + return false; + } + bool Result = RemoveDirNative(Path, Ec); + for (size_t Retries = 0; Ec && Retries < 3; Retries++) + { + Sleep(100 + int(Retries * 50)); + Ec.clear(); + if (IsDir(Path)) + { + Result = RemoveDirNative(Path, Ec); + } + } + return Result; + } + return false; } bool -CleanDirectoryExceptDotFiles(const std::filesystem::path& Dir) +CleanDirectoryExceptDotFiles(const std::filesystem::path& Path) { #if ZEN_PLATFORM_WINDOWS const bool KeepDotFiles = true; - return CleanDirectory(Dir.c_str(), KeepDotFiles); + return CleanDirectory(Path.c_str(), KeepDotFiles); #else - ZEN_UNUSED(Dir); + ZEN_UNUSED(Path); ZEN_NOT_IMPLEMENTED(); #endif @@ -541,7 +828,10 @@ CloneFile(std::filesystem::path FromPath, std::filesystem::path ToPath) } void -CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const CopyFileOptions& Options, std::error_code& OutErrorCode) +CopyFile(const std::filesystem::path& FromPath, + const std::filesystem::path& ToPath, + const CopyFileOptions& Options, + std::error_code& OutErrorCode) { OutErrorCode.clear(); @@ -554,7 +844,7 @@ CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop } bool -CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const CopyFileOptions& Options) +CopyFile(const std::filesystem::path& FromPath, const std::filesystem::path& ToPath, const CopyFileOptions& Options) { bool Success = false; @@ -597,7 +887,7 @@ CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop ScopedFd $From = {FromFd}; // To file - int ToFd = open(ToPath.c_str(), O_WRONLY | O_CREAT | O_EXCL | O_CLOEXEC, 0666); + int ToFd = open(ToPath.c_str(), O_WRONLY | O_CREAT | O_CLOEXEC, 0666); if (ToFd < 0) { ThrowLastError(fmt::format("failed to create file {}", ToPath)); @@ -605,9 +895,16 @@ CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop fchmod(ToFd, 0666); ScopedFd $To = {ToFd}; + struct stat Stat; + fstat(FromFd, &Stat); + + size_t FileSizeBytes = Stat.st_size; + + fchown(ToFd, Stat.st_uid, Stat.st_gid); + // Copy impl - static const size_t BufferSize = 64 << 10; - void* Buffer = malloc(BufferSize); + const size_t BufferSize = Min(FileSizeBytes, 64u << 10); + void* Buffer = malloc(BufferSize); while (true) { int BytesRead = read(FromFd, Buffer, BufferSize); @@ -617,7 +914,7 @@ CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop break; } - if (write(ToFd, Buffer, BytesRead) != BufferSize) + if (write(ToFd, Buffer, BytesRead) != BytesRead) { Success = false; break; @@ -628,7 +925,7 @@ CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop if (!Success) { - ThrowLastError("file copy failed"sv); + ThrowLastError(fmt::format("file copy from {} to {} failed", FromPath, ToPath)); } return true; @@ -639,7 +936,7 @@ CopyTree(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop { // Validate arguments - if (FromPath.empty() || !std::filesystem::is_directory(FromPath)) + if (FromPath.empty() || !IsDir(FromPath)) throw std::runtime_error("invalid CopyTree source directory specified"); if (ToPath.empty()) @@ -648,16 +945,13 @@ CopyTree(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop if (Options.MustClone && !SupportsBlockRefCounting(FromPath)) throw std::runtime_error(fmt::format("cloning not possible from '{}'", FromPath)); - if (std::filesystem::exists(ToPath)) + if (IsFile(ToPath)) { - if (!std::filesystem::is_directory(ToPath)) - { - throw std::runtime_error(fmt::format("specified CopyTree target '{}' is not a directory", ToPath)); - } + throw std::runtime_error(fmt::format("specified CopyTree target '{}' is not a directory", ToPath)); } - else + if (!IsDir(ToPath)) { - std::filesystem::create_directories(ToPath); + CreateDirectories(ToPath); } if (Options.MustClone && !SupportsBlockRefCounting(ToPath)) @@ -693,7 +987,7 @@ CopyTree(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop { } - virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t FileSize) override + virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t FileSize, uint32_t, uint64_t) override { std::error_code Ec; const std::filesystem::path Relative = std::filesystem::relative(Parent, BasePath, Ec); @@ -730,7 +1024,7 @@ CopyTree(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop throw std::runtime_error("CopyFile failed in an unexpected way"); } } - catch (std::exception& Ex) + catch (const std::exception& Ex) { ++FailedFileCount; @@ -739,7 +1033,7 @@ CopyTree(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop } } - virtual bool VisitDirectory(const std::filesystem::path&, const path_view&) override { return true; } + virtual bool VisitDirectory(const std::filesystem::path&, const path_view&, uint32_t) override { return true; } std::filesystem::path BasePath; std::filesystem::path TargetPath; @@ -762,6 +1056,100 @@ CopyTree(std::filesystem::path FromPath, std::filesystem::path ToPath, const Cop } void +WriteFile(void* NativeHandle, const void* Data, uint64_t Size, uint64_t FileOffset, uint64_t ChunkSize, std::error_code& Ec) +{ + ZEN_ASSERT(NativeHandle != nullptr); + + Ec.clear(); + + while (Size) + { + const uint64_t NumberOfBytesToWrite = Min(Size, ChunkSize); + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(FileOffset >> 32); + + DWORD dwNumberOfBytesWritten = 0; + + BOOL Success = ::WriteFile(NativeHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl); +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(NativeHandle)); + int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, FileOffset); + bool Success = (BytesWritten > 0); +#endif + + if (!Success) + { + Ec = MakeErrorCodeFromLastError(); + return; + } + + Size -= NumberOfBytesToWrite; + FileOffset += NumberOfBytesToWrite; + Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesToWrite; + } +} + +void +ReadFile(void* NativeHandle, void* Data, uint64_t Size, uint64_t FileOffset, uint64_t ChunkSize, std::error_code& Ec) +{ + while (Size) + { + const uint64_t NumberOfBytesToRead = Min(Size, ChunkSize); + size_t BytesRead = 0; + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(FileOffset >> 32); + + DWORD dwNumberOfBytesRead = 0; + BOOL Success = ::ReadFile(NativeHandle, Data, DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl); + if (Success) + { + BytesRead = size_t(dwNumberOfBytesRead); + } + else if ((BytesRead != NumberOfBytesToRead)) + { + Ec = MakeErrorCode(ERROR_HANDLE_EOF); + return; + } + else + { + Ec = MakeErrorCodeFromLastError(); + return; + } +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(NativeHandle)); + ssize_t ReadResult = pread(Fd, Data, NumberOfBytesToRead, FileOffset); + if (ReadResult != -1) + { + BytesRead = size_t(ReadResult); + } + else if ((BytesRead != NumberOfBytesToRead)) + { + Ec = MakeErrorCode(EIO); + return; + } + else + { + Ec = MakeErrorCodeFromLastError(); + return; + } +#endif + Size -= NumberOfBytesToRead; + FileOffset += NumberOfBytesToRead; + Data = reinterpret_cast<uint8_t*>(Data) + NumberOfBytesToRead; + } +} + +void WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t BufferCount) { #if ZEN_PLATFORM_WINDOWS @@ -811,11 +1199,17 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer hRes = Outfile.Write(DataPtr, gsl::narrow_cast<uint32_t>(WriteSize)); if (FAILED(hRes)) { + Outfile.Close(); + std::error_code DummyEc; + RemoveFile(Path, DummyEc); ThrowSystemException(hRes, fmt::format("File write failed for '{}'", Path).c_str()); } #else if (write(Fd, DataPtr, WriteSize) != int64_t(WriteSize)) { + close(Fd); + std::error_code DummyEc; + RemoveFile(Path, DummyEc); ThrowLastError(fmt::format("File write failed for '{}'", Path)); } #endif // ZEN_PLATFORM_WINDOWS @@ -825,7 +1219,9 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer } } -#if !ZEN_PLATFORM_WINDOWS +#if ZEN_PLATFORM_WINDOWS + Outfile.Close(); +#else close(Fd); #endif } @@ -894,6 +1290,27 @@ MoveToFile(std::filesystem::path Path, IoBuffer Data) zen::CreateDirectories(Path.parent_path()); Success = SetFileInformationByHandle(ChunkFileHandle, FileRenameInfo, RenameInfo, BufferSize); } + if (!Success && (LastError == ERROR_ACCESS_DENIED)) + { + // Fallback to regular rename + std::error_code Ec; + std::filesystem::path SourcePath = PathFromHandle(FileRef.FileHandle, Ec); + if (!Ec) + { + auto NativeSourcePath = SourcePath.native().c_str(); + auto NativeTargetPath = Path.native().c_str(); + Success = ::MoveFile(NativeSourcePath, NativeTargetPath); + if (!Success) + { + LastError = GetLastError(); + if (LastError == ERROR_PATH_NOT_FOUND) + { + zen::CreateDirectories(Path.parent_path()); + Success = ::MoveFile(NativeSourcePath, NativeTargetPath); + } + } + } + } } Memory::Free(RenameInfo); if (!Success) @@ -901,15 +1318,20 @@ MoveToFile(std::filesystem::path Path, IoBuffer Data) return false; } #elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC - std::filesystem::path SourcePath = PathFromHandle(FileRef.FileHandle); - int Ret = link(SourcePath.c_str(), Path.c_str()); + std::error_code Ec; + std::filesystem::path SourcePath = PathFromHandle(FileRef.FileHandle, Ec); + if (Ec) + { + return false; + } + int Ret = rename(SourcePath.c_str(), Path.c_str()); if (Ret < 0) { int32_t err = errno; if (err == ENOENT) { zen::CreateDirectories(Path.parent_path()); - Ret = link(SourcePath.c_str(), Path.c_str()); + Ret = rename(SourcePath.c_str(), Path.c_str()); } } if (Ret < 0) @@ -1007,6 +1429,46 @@ ReadFile(std::filesystem::path Path) return Contents; } +ZENCORE_API void +ScanFile(void* NativeHandle, + uint64_t Offset, + uint64_t Size, + uint64_t ChunkSize, + std::function<void(const void* Data, size_t Size)>&& ProcessFunc) +{ + ZEN_ASSERT(NativeHandle != nullptr); + uint64_t BufferSize = Min(ChunkSize, Size); + std::vector<uint8_t> ReadBuffer(BufferSize); + uint64_t ReadOffset = 0; + while (ReadOffset < Size) + { + const uint64_t NumberOfBytesToRead = Min(Size - ReadOffset, BufferSize); + uint64_t FileOffset = Offset + ReadOffset; + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(FileOffset >> 32); + + DWORD BytesRead = 0; + BOOL Success = ::ReadFile(NativeHandle, ReadBuffer.data(), DWORD(NumberOfBytesToRead), &BytesRead, &Ovl); + if (!Success) + { + throw std::system_error(std::error_code(::GetLastError(), std::system_category()), "file scan failed"); + } +#else + int BytesRead = pread(int(intptr_t(NativeHandle)), ReadBuffer.data(), size_t(NumberOfBytesToRead), off_t(FileOffset)); + if (BytesRead < 0) + { + throw std::system_error(std::error_code(errno, std::system_category()), "file scan failed"); + } +#endif + ProcessFunc(ReadBuffer.data(), (size_t)BytesRead); + ReadOffset += (uint64_t)BytesRead; + } +} + bool ScanFile(std::filesystem::path Path, const uint64_t ChunkSize, std::function<void(const void* Data, size_t Size)>&& ProcessFunc) { @@ -1042,7 +1504,7 @@ ScanFile(std::filesystem::path Path, const uint64_t ChunkSize, std::function<voi ProcessFunc(ReadBuffer.data(), dwBytesRead); } #else - int Fd = open(Path.c_str(), O_RDONLY | O_CLOEXEC); + int Fd = open(Path.c_str(), O_RDONLY | O_CLOEXEC); if (Fd < 0) { return false; @@ -1121,7 +1583,7 @@ void FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, TreeVisitor& Visitor) { #if ZEN_PLATFORM_WINDOWS - uint64_t FileInfoBuffer[8 * 1024]; + std::vector<uint64_t> FileInfoBuffer(8 * 1024); FILE_INFO_BY_HANDLE_CLASS FibClass = FileIdBothDirectoryRestartInfo; bool Continue = true; @@ -1132,7 +1594,7 @@ FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, Tr if (FAILED(hRes)) { - if (hRes == ERROR_FILE_NOT_FOUND) + if (HRESULT_CODE(hRes) == ERROR_FILE_NOT_FOUND || HRESULT_CODE(hRes) == ERROR_PATH_NOT_FOUND) { // Directory no longer exist, treat it as empty return; @@ -1142,8 +1604,9 @@ FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, Tr while (Continue) { - BOOL Success = GetFileInformationByHandleEx(RootDirHandle, FibClass, FileInfoBuffer, sizeof FileInfoBuffer); - FibClass = FileIdBothDirectoryInfo; // Set up for next iteration + BOOL Success = + GetFileInformationByHandleEx(RootDirHandle, FibClass, FileInfoBuffer.data(), (DWORD)(FileInfoBuffer.size() * sizeof(uint64_t))); + FibClass = FileIdBothDirectoryInfo; // Set up for next iteration uint64_t EntryOffset = 0; @@ -1162,7 +1625,7 @@ FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, Tr for (;;) { const FILE_ID_BOTH_DIR_INFO* DirInfo = - reinterpret_cast<const FILE_ID_BOTH_DIR_INFO*>(reinterpret_cast<const uint8_t*>(FileInfoBuffer) + EntryOffset); + reinterpret_cast<const FILE_ID_BOTH_DIR_INFO*>(reinterpret_cast<const uint8_t*>(FileInfoBuffer.data()) + EntryOffset); std::wstring_view FileName(DirInfo->FileName, DirInfo->FileNameLength / sizeof(wchar_t)); @@ -1174,7 +1637,7 @@ FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, Tr } else { - const bool ShouldDescend = Visitor.VisitDirectory(RootDir, FileName); + const bool ShouldDescend = Visitor.VisitDirectory(RootDir, FileName, gsl::narrow<uint32_t>(DirInfo->FileAttributes)); if (ShouldDescend) { @@ -1193,7 +1656,11 @@ FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, Tr } else { - Visitor.VisitFile(RootDir, FileName, DirInfo->EndOfFile.QuadPart); + Visitor.VisitFile(RootDir, + FileName, + DirInfo->EndOfFile.QuadPart, + gsl::narrow<uint32_t>(DirInfo->FileAttributes), + (uint64_t)DirInfo->LastWriteTime.QuadPart); } const uint64_t NextOffset = DirInfo->NextEntryOffset; @@ -1235,14 +1702,14 @@ FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, Tr { /* nop */ } - else if (Visitor.VisitDirectory(RootDir, FileName)) + else if (Visitor.VisitDirectory(RootDir, FileName, gsl::narrow<uint32_t>(Stat.st_mode))) { TraverseFileSystem(FullPath, Visitor); } } else if (S_ISREG(Stat.st_mode)) { - Visitor.VisitFile(RootDir, FileName, Stat.st_size); + Visitor.VisitFile(RootDir, FileName, Stat.st_size, gsl::narrow<uint32_t>(Stat.st_mode), gsl::narrow<uint64_t>(Stat.st_mtime)); } else { @@ -1283,6 +1750,156 @@ CanonicalPath(std::filesystem::path InPath, std::error_code& Ec) #endif } +bool +IsFile(const std::filesystem::path& Path) +{ + std::error_code Ec; + bool Result = IsFile(Path, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to test if path '{}' is a file", Path.string())); + } + return Result; +} + +bool +IsFile(const std::filesystem::path& Path, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + DWORD Attributes = WinGetFileAttributes(Path, Ec); + if (Ec) + { + return false; + } + if (Attributes == INVALID_FILE_ATTRIBUTES) + { + return false; + } + return (Attributes & FILE_ATTRIBUTE_DIRECTORY) == 0; +#else + struct stat Stat; + int err = stat(Path.native().c_str(), &Stat); + if (err != 0) + { + int32_t err = errno; + if (err == ENOENT) + { + Ec.clear(); + return false; + } + } + if (S_ISREG(Stat.st_mode)) + { + return true; + } + return false; +#endif // ZEN_PLATFORM_WINDOWS +} + +bool +IsDir(const std::filesystem::path& Path) +{ + std::error_code Ec; + bool Result = IsDir(Path, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to test if path '{}' is a directory", Path.string())); + } + return Result; +} + +bool +IsDir(const std::filesystem::path& Path, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + DWORD Attributes = WinGetFileAttributes(Path, Ec); + if (Ec) + { + return false; + } + if (Attributes == INVALID_FILE_ATTRIBUTES) + { + return false; + } + return (Attributes & FILE_ATTRIBUTE_DIRECTORY) == FILE_ATTRIBUTE_DIRECTORY; +#else + struct stat Stat; + int err = stat(Path.native().c_str(), &Stat); + if (err != 0) + { + int32_t err = errno; + if (err == ENOENT) + { + Ec.clear(); + return false; + } + } + if (S_ISDIR(Stat.st_mode)) + { + return true; + } + return false; +#endif // ZEN_PLATFORM_WINDOWS +} + +bool +RemoveFile(const std::filesystem::path& Path) +{ + std::error_code Ec; + bool Success = RemoveFile(Path, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to remove file '{}'", Path.string())); + } + return Success; +} + +bool +RemoveFile(const std::filesystem::path& Path, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + return RemoveFileNative(Path, false, Ec); +#else + bool IsDirectory = std::filesystem::is_directory(Path, Ec); + if (IsDirectory) + { + Ec = MakeErrorCode(EPERM); + return false; + } + Ec.clear(); + return RemoveFileNative(Path, false, Ec); +#endif // ZEN_PLATFORM_WINDOWS +} + +bool +RemoveDir(const std::filesystem::path& Path) +{ + std::error_code Ec; + bool Success = RemoveDir(Path, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to remove directory '{}'", Path.string())); + } + return Success; +} + +bool +RemoveDir(const std::filesystem::path& Path, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + return RemoveDirNative(Path, Ec); +#else + bool IsFile = std::filesystem::is_regular_file(Path, Ec); + if (IsFile) + { + Ec = MakeErrorCode(EPERM); + return false; + } + Ec.clear(); + return RemoveDirNative(Path, Ec); +#endif // ZEN_PLATFORM_WINDOWS +} + std::filesystem::path PathFromHandle(void* NativeHandle, std::error_code& Ec) { @@ -1379,39 +1996,242 @@ PathFromHandle(void* NativeHandle, std::error_code& Ec) #endif // ZEN_PLATFORM_WINDOWS } -std::filesystem::path -PathFromHandle(void* NativeHandle) +uint64_t +FileSizeFromPath(const std::filesystem::path& Path) { - std::error_code Ec; - std::filesystem::path Result = PathFromHandle(NativeHandle, Ec); + std::error_code Ec; + uint64_t Size = FileSizeFromPath(Path, Ec); if (Ec) { - throw std::system_error(Ec, fmt::format("failed to get path from file handle '{}'", NativeHandle)); + throw std::system_error(Ec, fmt::format("Failed to get file size for path '{}'", Path.string())); } - return Result; + return Size; } uint64_t -FileSizeFromHandle(void* NativeHandle) +FileSizeFromPath(const std::filesystem::path& Path, std::error_code& Ec) { - uint64_t FileSize = ~0ull; +#if ZEN_PLATFORM_WINDOWS + void* Handle = ::CreateFile(Path.native().c_str(), + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + nullptr, + OPEN_EXISTING, + 0, + nullptr); + if (Handle == INVALID_HANDLE_VALUE) + { + DWORD LastError = GetLastError(); + Ec = MakeErrorCode(LastError); + return 0; + } + auto _ = MakeGuard([Handle]() { CloseHandle(Handle); }); + LARGE_INTEGER FileSize; + BOOL Success = GetFileSizeEx(Handle, &FileSize); + if (!Success) + { + Ec = MakeErrorCodeFromLastError(); + return 0; + } + return FileSize.QuadPart; +#else + return std::filesystem::file_size(Path, Ec); +#endif // ZEN_PLATFORM_WINDOWS +} +uint64_t +FileSizeFromHandle(void* NativeHandle, std::error_code& Ec) +{ #if ZEN_PLATFORM_WINDOWS BY_HANDLE_FILE_INFORMATION Bhfh = {}; if (GetFileInformationByHandle(NativeHandle, &Bhfh)) { - FileSize = uint64_t(Bhfh.nFileSizeHigh) << 32 | Bhfh.nFileSizeLow; + return uint64_t(Bhfh.nFileSizeHigh) << 32 | Bhfh.nFileSizeLow; + } + else + { + Ec = MakeErrorCodeFromLastError(); + return 0; } #else - int Fd = int(intptr_t(NativeHandle)); + int Fd = int(intptr_t(NativeHandle)); + static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files"); struct stat Stat; - fstat(Fd, &Stat); - FileSize = size_t(Stat.st_size); + if (fstat(Fd, &Stat) == -1) + { + Ec = MakeErrorCodeFromLastError(); + return 0; + } + return uint64_t(Stat.st_size); #endif +} +uint64_t +FileSizeFromHandle(void* NativeHandle) +{ + std::error_code Ec; + uint64_t FileSize = FileSizeFromHandle(NativeHandle, Ec); + if (Ec) + { + return ~0ull; + } return FileSize; } +uint64_t +GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + FILETIME LastWriteTime; + BOOL OK = GetFileTime((HANDLE)NativeHandle, NULL, NULL, &LastWriteTime); + if (OK) + { + return ((uint64_t(LastWriteTime.dwHighDateTime) << 32) | LastWriteTime.dwLowDateTime); + } +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Fd = int(uintptr_t(NativeHandle)); + struct stat Stat; + if (0 == fstat(Fd, &Stat)) + { + return gsl::narrow<uint64_t>(Stat.st_mtime); + } +#endif + Ec = MakeErrorCodeFromLastError(); + return 0; +} + +uint64_t +GetModificationTickFromPath(const std::filesystem::path& Filename) +{ + // PathFromHandle + void* Handle; +#if ZEN_PLATFORM_WINDOWS + Handle = CreateFileW(Filename.c_str(), + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + nullptr, + OPEN_EXISTING, + 0, + nullptr); + if (Handle == INVALID_HANDLE_VALUE) + { + ThrowLastError(fmt::format("Failed to open file {} to check modification tick.", Filename)); + } + auto _ = MakeGuard([Handle]() { CloseHandle(Handle); }); + std::error_code Ec; + uint64_t ModificatonTick = GetModificationTickFromHandle(Handle, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to get modification tick for path '{}'", Filename.string())); + } + return ModificatonTick; +#else + struct stat Stat; + int err = stat(Filename.native().c_str(), &Stat); + if (err) + { + ThrowLastError(fmt::format("Failed to get mode of file {}", Filename)); + } + return gsl::narrow<uint64_t>(Stat.st_mtime); +#endif +} + +bool +TryGetFileProperties(const std::filesystem::path& Path, + uint64_t& OutSize, + uint64_t& OutModificationTick, + uint32_t& OutNativeModeOrAttributes) +{ +#if ZEN_PLATFORM_WINDOWS + const std::filesystem::path::value_type* NativePath = Path.native().c_str(); + { + void* Handle = CreateFileW(NativePath, + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + nullptr, + OPEN_EXISTING, + 0, + nullptr); + if (Handle == INVALID_HANDLE_VALUE) + { + return false; + } + auto _ = MakeGuard([Handle]() { CloseHandle(Handle); }); + + BY_HANDLE_FILE_INFORMATION Bhfh = {}; + if (!GetFileInformationByHandle(Handle, &Bhfh)) + { + return false; + } + OutSize = uint64_t(Bhfh.nFileSizeHigh) << 32 | Bhfh.nFileSizeLow; + OutModificationTick = ((uint64_t(Bhfh.ftLastWriteTime.dwHighDateTime) << 32) | Bhfh.ftLastWriteTime.dwLowDateTime); + OutNativeModeOrAttributes = Bhfh.dwFileAttributes; + return true; + } +#else + struct stat Stat; + int err = stat(Path.native().c_str(), &Stat); + if (err) + { + return false; + } + OutModificationTick = gsl::narrow<uint64_t>(Stat.st_mtime); + OutSize = size_t(Stat.st_size); + OutNativeModeOrAttributes = (uint32_t)Stat.st_mode; + return true; +#endif +} + +void +RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath) +{ + std::error_code Ec; + RenameFile(SourcePath, TargetPath, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to rename path from '{}' to '{}'", SourcePath.string(), TargetPath.string())); + } +} + +void +RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + BOOL Success = ::MoveFileEx(SourcePath.native().c_str(), TargetPath.native().c_str(), MOVEFILE_REPLACE_EXISTING); + if (!Success) + { + Ec = MakeErrorCodeFromLastError(); + } +#else + return std::filesystem::rename(SourcePath, TargetPath, Ec); +#endif // ZEN_PLATFORM_WINDOWS +} + +void +RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath) +{ + std::error_code Ec; + RenameDirectory(SourcePath, TargetPath, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to rename directory from '{}' to '{}'", SourcePath.string(), TargetPath.string())); + } +} + +void +RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + BOOL Success = ::MoveFile(SourcePath.native().c_str(), TargetPath.native().c_str()); + if (!Success) + { + Ec = MakeErrorCodeFromLastError(); + } +#else + return std::filesystem::rename(SourcePath, TargetPath, Ec); +#endif // ZEN_PLATFORM_WINDOWS +} + std::filesystem::path GetRunningExecutablePath() { @@ -1475,40 +2295,218 @@ MaximizeOpenFileCount() #endif } +bool +PrepareFileForScatteredWrite(void* FileHandle, uint64_t FinalSize) +{ + bool Result = true; +#if ZEN_PLATFORM_WINDOWS + + BY_HANDLE_FILE_INFORMATION Information; + if (GetFileInformationByHandle(FileHandle, &Information)) + { + if ((Information.dwFileAttributes & FILE_ATTRIBUTE_SPARSE_FILE) == 0) + { + DWORD _ = 0; + BOOL Ok = DeviceIoControl(FileHandle, FSCTL_SET_SPARSE, nullptr, 0, nullptr, 0, &_, nullptr); + if (!Ok) + { + std::error_code DummyEc; + ZEN_DEBUG("Unable to set sparse mode for file '{}'", PathFromHandle(FileHandle, DummyEc)); + Result = false; + } + } + } + + FILE_ALLOCATION_INFO AllocationInfo = {}; + AllocationInfo.AllocationSize.QuadPart = LONGLONG(FinalSize); + if (!SetFileInformationByHandle(FileHandle, FileAllocationInfo, &AllocationInfo, DWORD(sizeof(AllocationInfo)))) + { + std::error_code DummyEc; + ZEN_DEBUG("Unable to set file allocation size to {} for file '{}'", FinalSize, PathFromHandle(FileHandle, DummyEc)); + Result = false; + } + +#else // ZEN_PLATFORM_WINDOWS + ZEN_UNUSED(FileHandle, FinalSize); +#endif // ZEN_PLATFORM_WINDOWS + return Result; +} + void -GetDirectoryContent(const std::filesystem::path& RootDir, uint8_t Flags, DirectoryContent& OutContent) +GetDirectoryContent(const std::filesystem::path& RootDir, DirectoryContentFlags Flags, DirectoryContent& OutContent) { + ZEN_ASSERT(EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::IncludeDirs)); + ZEN_ASSERT(EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFiles) + ? true + : (!EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFileSizes))); + FileSystemTraversal Traversal; struct Visitor : public FileSystemTraversal::TreeVisitor { - Visitor(uint8_t Flags, DirectoryContent& OutContent) : Flags(Flags), Content(OutContent) {} + Visitor(DirectoryContentFlags Flags, DirectoryContent& OutContent) : Flags(Flags), Content(OutContent) {} - virtual void VisitFile([[maybe_unused]] const std::filesystem::path& Parent, - [[maybe_unused]] const path_view& File, - [[maybe_unused]] uint64_t FileSize) override + virtual void VisitFile(const std::filesystem::path& Parent, + const path_view& File, + uint64_t FileSize, + uint32_t NativeModeOrAttributes, + uint64_t NativeModificationTick) override { - if (Flags & DirectoryContent::IncludeFilesFlag) + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFiles)) { Content.Files.push_back(Parent / File); + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFileSizes)) + { + Content.FileSizes.push_back(FileSize); + } + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeAttributes)) + { + Content.FileAttributes.push_back(NativeModeOrAttributes); + } + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeModificationTick)) + { + Content.FileModificationTicks.push_back(NativeModificationTick); + } } } - virtual bool VisitDirectory([[maybe_unused]] const std::filesystem::path& Parent, const path_view& DirectoryName) override + virtual bool VisitDirectory([[maybe_unused]] const std::filesystem::path& Parent, + const path_view& DirectoryName, + uint32_t NativeModeOrAttributes) override { - if (Flags & DirectoryContent::IncludeDirsFlag) + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeDirs)) { Content.Directories.push_back(Parent / DirectoryName); + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeAttributes)) + { + Content.DirectoryAttributes.push_back(NativeModeOrAttributes); + } } - return (Flags & DirectoryContent::RecursiveFlag) != 0; + return EnumHasAnyFlags(Flags, DirectoryContentFlags::Recursive); } - const uint8_t Flags; - DirectoryContent& Content; + const DirectoryContentFlags Flags; + DirectoryContent& Content; } Visit(Flags, OutContent); Traversal.TraverseFileSystem(RootDir, Visit); } +void +GetDirectoryContent(const std::filesystem::path& RootDir, + DirectoryContentFlags Flags, + GetDirectoryContentVisitor& Visitor, + WorkerThreadPool& WorkerPool, + Latch& PendingWorkCount) +{ + ZEN_ASSERT(EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::IncludeDirs)); + ZEN_ASSERT(EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFiles) + ? true + : (!EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFileSizes))); + + struct MultithreadedVisitor : public FileSystemTraversal::TreeVisitor + { + MultithreadedVisitor(WorkerThreadPool& InWorkerPool, + Latch& InOutPendingWorkCount, + const std::filesystem::path& InRelativeRoot, + DirectoryContentFlags InFlags, + GetDirectoryContentVisitor* InVisitor) + : WorkerPool(InWorkerPool) + , PendingWorkCount(InOutPendingWorkCount) + , RelativeRoot(InRelativeRoot) + , Flags(InFlags) + , Visitor(InVisitor) + { + } + + virtual void VisitFile(const std::filesystem::path&, + const path_view& File, + uint64_t FileSize, + uint32_t NativeModeOrAttributes, + uint64_t NativeModificationTick) override + { + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFiles)) + { + Content.FileNames.push_back(File); + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeFileSizes)) + { + Content.FileSizes.push_back(FileSize); + } + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeAttributes)) + { + Content.FileAttributes.push_back(NativeModeOrAttributes); + } + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeModificationTick)) + { + Content.FileModificationTicks.push_back(NativeModificationTick); + } + } + } + + virtual bool VisitDirectory(const std::filesystem::path& Parent, + const path_view& DirectoryName, + uint32_t NativeModeOrAttributes) override + { + std::filesystem::path Path = Parent / DirectoryName; + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeDirs)) + { + Content.DirectoryNames.push_back(DirectoryName); + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::IncludeAttributes)) + { + Content.DirectoryAttributes.push_back(NativeModeOrAttributes); + } + } + if (EnumHasAnyFlags(Flags, DirectoryContentFlags::Recursive)) + { + PendingWorkCount.AddCount(1); + try + { + WorkerPool.ScheduleWork( + [WorkerPool = &WorkerPool, + PendingWorkCount = &PendingWorkCount, + Visitor = Visitor, + Flags = Flags, + Path = std::move(Path), + RelativeRoot = RelativeRoot / DirectoryName]() { + ZEN_ASSERT(Visitor); + auto _ = MakeGuard([&]() { PendingWorkCount->CountDown(); }); + try + { + MultithreadedVisitor SubVisitor(*WorkerPool, *PendingWorkCount, RelativeRoot, Flags, Visitor); + FileSystemTraversal Traversal; + Traversal.TraverseFileSystem(Path, SubVisitor); + Visitor->AsyncVisitDirectory(SubVisitor.RelativeRoot, std::move(SubVisitor.Content)); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed scheduling work to scan subfolder '{}'. Reason: '{}'", Path / RelativeRoot, Ex.what()); + } + }, + WorkerThreadPool::EMode::DisableBacklog); + } + catch (const std::exception Ex) + { + ZEN_ERROR("Failed scheduling work to scan folder '{}'. Reason: '{}'", Path, Ex.what()); + PendingWorkCount.CountDown(); + } + } + return false; + } + + WorkerThreadPool& WorkerPool; + Latch& PendingWorkCount; + + const std::filesystem::path RelativeRoot; + const DirectoryContentFlags Flags; + + GetDirectoryContentVisitor::DirectoryContent Content; + GetDirectoryContentVisitor* Visitor; + } WrapperVisitor(WorkerPool, PendingWorkCount, {}, Flags, &Visitor); + + FileSystemTraversal Traversal; + Traversal.TraverseFileSystem(RootDir, WrapperVisitor); + Visitor.AsyncVisitDirectory(WrapperVisitor.RelativeRoot, std::move(WrapperVisitor.Content)); +} + std::string GetEnvVariable(std::string_view VariableName) { @@ -1564,7 +2562,7 @@ RotateFiles(const std::filesystem::path& Filename, std::size_t MaxFiles) }; auto IsEmpty = [](const std::filesystem::path& Path, std::error_code& Ec) -> bool { - bool Exists = std::filesystem::exists(Path, Ec); + bool Exists = IsFile(Path, Ec); if (Ec) { return false; @@ -1573,7 +2571,7 @@ RotateFiles(const std::filesystem::path& Filename, std::size_t MaxFiles) { return true; } - uintmax_t Size = std::filesystem::file_size(Path, Ec); + uintmax_t Size = FileSizeFromPath(Path, Ec); if (Ec) { return false; @@ -1592,17 +2590,17 @@ RotateFiles(const std::filesystem::path& Filename, std::size_t MaxFiles) for (auto i = MaxFiles; i > 0; i--) { std::filesystem::path src = GetFileName(i - 1); - if (!std::filesystem::exists(src)) + if (!IsFile(src)) { continue; } std::error_code DummyEc; std::filesystem::path target = GetFileName(i); - if (std::filesystem::exists(target, DummyEc)) + if (IsFile(target, DummyEc)) { - std::filesystem::remove(target, DummyEc); + RemoveFile(target, DummyEc); } - std::filesystem::rename(src, target, DummyEc); + RenameFile(src, target, DummyEc); } } @@ -1639,16 +2637,16 @@ RotateDirectories(const std::filesystem::path& DirectoryName, std::size_t MaxDir { const std::filesystem::path SourcePath = GetPathForIndex(i - 1); - if (std::filesystem::exists(SourcePath)) + if (IsDir(SourcePath)) { std::filesystem::path TargetPath = GetPathForIndex(i); std::error_code DummyEc; - if (std::filesystem::exists(TargetPath, DummyEc)) + if (IsDir(TargetPath, DummyEc)) { - std::filesystem::remove_all(TargetPath, DummyEc); + DeleteDirectories(TargetPath, DummyEc); } - std::filesystem::rename(SourcePath, TargetPath, DummyEc); + RenameDirectory(SourcePath, TargetPath, DummyEc); } } @@ -1679,6 +2677,392 @@ SearchPathForExecutable(std::string_view ExecutableName) #endif } +std::filesystem::path +PickDefaultSystemRootDirectory() +{ +#if ZEN_PLATFORM_WINDOWS + // Pick sensible default + PWSTR ProgramDataDir = nullptr; + HRESULT hRes = SHGetKnownFolderPath(FOLDERID_ProgramData, 0, NULL, &ProgramDataDir); + + if (SUCCEEDED(hRes)) + { + std::filesystem::path FinalPath(ProgramDataDir); + FinalPath /= L"Epic\\Zen"; + ::CoTaskMemFree(ProgramDataDir); + + return FinalPath; + } + + return L""; +#else // ZEN_PLATFORM_WINDOWS + int UserId = getuid(); + const passwd* Passwd = getpwuid(UserId); + return std::filesystem::path(Passwd->pw_dir) / ".zen"; +#endif // ZEN_PLATFORM_WINDOWS +} + +#if ZEN_PLATFORM_WINDOWS + +uint32_t +GetFileAttributes(const std::filesystem::path& Filename, std::error_code& Ec) +{ + return WinGetFileAttributes(Filename, Ec); +} + +uint32_t +GetFileAttributes(const std::filesystem::path& Filename) +{ + std::error_code Ec; + uint32_t Result = zen::GetFileAttributes(Filename, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to get attributes of file '{}'", Filename.string())); + } + return Result; +} + +void +SetFileAttributes(const std::filesystem::path& Filename, uint32_t Attributes, std::error_code& Ec) +{ + if (::SetFileAttributes(Filename.native().c_str(), Attributes) == 0) + { + Ec = MakeErrorCodeFromLastError(); + } +} + +void +SetFileAttributes(const std::filesystem::path& Filename, uint32_t Attributes) +{ + std::error_code Ec; + zen::SetFileAttributes(Filename, Attributes, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to set attributes of file {}", Filename.string())); + } +} + +#endif // ZEN_PLATFORM_WINDOWS + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + +uint32_t +GetFileMode(const std::filesystem::path& Filename) +{ + std::error_code Ec; + uint32_t Result = GetFileMode(Filename, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to get mode of file {}", Filename)); + } + return Result; +} + +uint32_t +GetFileMode(const std::filesystem::path& Filename, std::error_code& Ec) +{ + struct stat Stat; + int err = stat(Filename.native().c_str(), &Stat); + if (err) + { + Ec = MakeErrorCodeFromLastError(); + return 0; + } + return (uint32_t)Stat.st_mode; +} + +void +SetFileMode(const std::filesystem::path& Filename, uint32_t Mode) +{ + std::error_code Ec; + SetFileMode(Filename, Mode, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to set mode of file {}", Filename)); + } +} + +void +SetFileMode(const std::filesystem::path& Filename, uint32_t Mode, std::error_code& Ec) +{ + int err = chmod(Filename.native().c_str(), (mode_t)Mode); + if (err) + { + Ec = MakeErrorCodeFromLastError(); + } +} + +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + +bool +SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + uint32_t CurrentAttributes = GetFileAttributes(Filename, Ec); + if (Ec) + { + return false; + } + if (CurrentAttributes == INVALID_FILE_ATTRIBUTES) + { + Ec = MakeErrorCode(ERROR_FILE_NOT_FOUND); + return false; + } + uint32_t NewAttributes = MakeFileAttributeReadOnly(CurrentAttributes, ReadOnly); + if (CurrentAttributes != NewAttributes) + { + SetFileAttributes(Filename, NewAttributes, Ec); + if (Ec) + { + return false; + } + return true; + } +#endif // ZEN_PLATFORM_WINDOWS +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + uint32_t CurrentMode = GetFileMode(Filename, Ec); + if (Ec) + { + return false; + } + uint32_t NewMode = MakeFileModeReadOnly(CurrentMode, ReadOnly); + if (CurrentMode != NewMode) + { + SetFileMode(Filename, NewMode, Ec); + if (Ec) + { + return false; + } + return true; + } +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + return false; +} + +bool +SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly) +{ + std::error_code Ec; + bool Result = SetFileReadOnly(Filename, ReadOnly, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to set read only mode of file '{}'", Filename.string())); + } + return Result; +} + +class SharedMemoryImpl : public SharedMemory +{ +public: + struct Data + { + void* Handle = nullptr; + void* DataPtr = nullptr; + size_t Size = 0; + std::string Name; + }; + + static Data Open(std::string_view Name, size_t Size, bool SystemGlobal) + { +#if ZEN_PLATFORM_WINDOWS + std::wstring InstanceMapName = Utf8ToWide(fmt::format("{}\\{}", SystemGlobal ? "Global" : "Local", Name)); + + HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, InstanceMapName.c_str()); + if (hMap == NULL) + { + return {}; + } + void* pBuf = MapViewOfFile(hMap, // handle to map object + FILE_MAP_ALL_ACCESS, // read/write permission + 0, // offset high + 0, // offset low + DWORD(Size)); // size + + if (pBuf == NULL) + { + CloseHandle(hMap); + } + return Data{.Handle = hMap, .DataPtr = pBuf, .Size = Size, .Name = std::string(Name)}; +#endif // ZEN_PLATFORM_WINDOWS +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + ZEN_UNUSED(SystemGlobal); + std::string InstanceMapName = fmt::format("/{}", Name); + + int Fd = shm_open(InstanceMapName.c_str(), O_RDWR, 0666); + if (Fd < 0) + { + return {}; + } + void* hMap = (void*)intptr_t(Fd); + + struct stat Stat; + fstat(Fd, &Stat); + + if (size_t(Stat.st_size) < Size) + { + close(Fd); + return {}; + } + + void* pBuf = mmap(nullptr, Size, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0); + if (pBuf == MAP_FAILED) + { + close(Fd); + return {}; + } + return Data{.Handle = hMap, .DataPtr = pBuf, .Size = Size, .Name = std::string(Name)}; +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + } + + static Data Create(std::string_view Name, size_t Size, bool SystemGlobal) + { +#if ZEN_PLATFORM_WINDOWS + std::wstring InstanceMapName = Utf8ToWide(fmt::format("{}\\{}", SystemGlobal ? "Global" : "Local", Name)); + + SECURITY_ATTRIBUTES m_Attributes{}; + SECURITY_DESCRIPTOR m_Sd{}; + + m_Attributes.nLength = sizeof m_Attributes; + m_Attributes.bInheritHandle = false; // Disable inheritance + + const BOOL Success = InitializeSecurityDescriptor(&m_Sd, SECURITY_DESCRIPTOR_REVISION); + + if (Success) + { + if (!SetSecurityDescriptorDacl(&m_Sd, TRUE, (PACL)NULL, FALSE)) + { + ThrowLastError("SetSecurityDescriptorDacl failed"); + } + + m_Attributes.lpSecurityDescriptor = &m_Sd; + } + + HANDLE hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file + &m_Attributes, // allow anyone to access + PAGE_READWRITE, // read/write access + 0, // maximum object size (high-order DWORD) + DWORD(Size), // maximum object size (low-order DWORD) + InstanceMapName.c_str()); + if (hMap == NULL) + { + return {}; + } + void* pBuf = MapViewOfFile(hMap, // handle to map object + FILE_MAP_ALL_ACCESS, // read/write permission + 0, // offset high + 0, // offset low + DWORD(Size)); // size + + if (pBuf == NULL) + { + CloseHandle(hMap); + return {}; + } + return Data{.Handle = hMap, .DataPtr = pBuf, .Size = Size, .Name = std::string(Name)}; +#endif // ZEN_PLATFORM_WINDOWS +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + ZEN_UNUSED(SystemGlobal); + std::string InstanceMapName = fmt::format("/{}", Name); + + int Fd = shm_open(InstanceMapName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666); + if (Fd < 0) + { + return {}; + } + fchmod(Fd, 0666); + void* hMap = (void*)intptr_t(Fd); + + int Result = ftruncate(Fd, Size); + ZEN_UNUSED(Result); + + void* pBuf = mmap(nullptr, Size, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0); + if (pBuf == MAP_FAILED) + { + close(Fd); + return {}; + } + return Data{.Handle = hMap, .DataPtr = pBuf, .Size = Size, .Name = std::string(Name)}; +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + } + + static void Close(Data&& MemMap, bool Delete) + { +#if ZEN_PLATFORM_WINDOWS + ZEN_UNUSED(Delete); + if (MemMap.DataPtr != nullptr) + { + UnmapViewOfFile(MemMap.DataPtr); + MemMap.DataPtr = nullptr; + } + if (MemMap.Handle != nullptr) + { + CloseHandle(MemMap.Handle); + MemMap.Handle = nullptr; + } +#endif // ZEN_PLATFORM_WINDOWS +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + if (MemMap.DataPtr != nullptr) + { + munmap(MemMap.DataPtr, MemMap.Size); + MemMap.DataPtr = nullptr; + } + + if (MemMap.Handle != nullptr) + { + int Fd = int(intptr_t(MemMap.Handle)); + close(Fd); + MemMap.Handle = nullptr; + } + if (Delete) + { + std::string InstanceMapName = fmt::format("/{}", MemMap.Name); + shm_unlink(InstanceMapName.c_str()); + } +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + } + + SharedMemoryImpl(Data&& MemMap, bool IsOwned) : m_MemMap(std::move(MemMap)), m_IsOwned(IsOwned) {} + virtual ~SharedMemoryImpl() + { + try + { + Close(std::move(m_MemMap), /*Delete*/ m_IsOwned); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("SharedMemoryImpl::~SharedMemoryImpl threw exception: {}", Ex.what()); + } + } + + virtual void* GetData() override { return m_MemMap.DataPtr; } + +private: + Data m_MemMap; + const bool m_IsOwned = false; +}; + +std::unique_ptr<SharedMemory> +OpenSharedMemory(std::string_view Name, size_t Size, bool SystemGlobal) +{ + SharedMemoryImpl::Data MemMap = SharedMemoryImpl::Open(Name, Size, SystemGlobal); + if (MemMap.DataPtr) + { + return std::make_unique<SharedMemoryImpl>(std::move(MemMap), /*IsOwned*/ false); + } + return {}; +} + +std::unique_ptr<SharedMemory> +CreateSharedMemory(std::string_view Name, size_t Size, bool SystemGlobal) +{ + SharedMemoryImpl::Data MemMap = SharedMemoryImpl::Create(Name, Size, SystemGlobal); + if (MemMap.DataPtr) + { + return std::make_unique<SharedMemoryImpl>(std::move(MemMap), /*IsOwned*/ true); + } + return {}; +} + ////////////////////////////////////////////////////////////////////////// // // Testing related code follows... @@ -1699,7 +3083,7 @@ TEST_CASE("filesystem") path BinPath = GetRunningExecutablePath(); const bool ExpectedExe = PathToUtf8(BinPath.stem().native()).ends_with("-test"sv) || BinPath.stem() == "zenserver"; CHECK(ExpectedExe); - CHECK(is_regular_file(BinPath)); + CHECK(IsFile(BinPath)); // PathFromHandle void* Handle; @@ -1712,7 +3096,9 @@ TEST_CASE("filesystem") Handle = (void*)uintptr_t(Fd); # endif - auto FromHandle = PathFromHandle((void*)uintptr_t(Handle)); + std::error_code Ec; + auto FromHandle = PathFromHandle((void*)uintptr_t(Handle), Ec); + CHECK(!Ec); CHECK(equivalent(FromHandle, BinPath)); # if ZEN_PLATFORM_WINDOWS @@ -1724,17 +3110,24 @@ TEST_CASE("filesystem") // Traversal struct : public FileSystemTraversal::TreeVisitor { - virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t) override + virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t, uint32_t, uint64_t) override { - bFoundExpected |= std::filesystem::equivalent(Parent / File, Expected); + // std::filesystem::equivalent is *very* expensive on Windows, filter out unlikely candidates + if (ExpectedFilename == ToLower(std::filesystem::path(File).string())) + { + bFoundExpected |= std::filesystem::equivalent(Parent / File, Expected); + } } - virtual bool VisitDirectory(const std::filesystem::path&, const path_view&) override { return true; } + virtual bool VisitDirectory(const std::filesystem::path&, const path_view&, uint32_t) override { return true; } - bool bFoundExpected = false; + bool bFoundExpected = false; + + std::string ExpectedFilename; std::filesystem::path Expected; } Visitor; - Visitor.Expected = BinPath; + Visitor.ExpectedFilename = ToLower(BinPath.filename().string()); + Visitor.Expected = BinPath; FileSystemTraversal().TraverseFileSystem(BinPath.parent_path().parent_path(), Visitor); CHECK(Visitor.bFoundExpected); @@ -1750,6 +3143,80 @@ TEST_CASE("filesystem") CHECK_EQ(BinScan.size(), BinRead.Data[0].GetSize()); } +TEST_CASE("Filesystem.Basics") +{ + std::filesystem::path TestBaseDir = GetRunningExecutablePath().parent_path() / ".test"; + CleanDirectory(TestBaseDir, true); + DeleteDirectories(TestBaseDir); + CHECK(!IsDir(TestBaseDir)); + CHECK(CleanDirectory(TestBaseDir, false)); + CHECK(IsDir(TestBaseDir)); + CHECK(!CleanDirectory(TestBaseDir, false)); + CHECK(!IsDir(TestBaseDir / "no_such_thing")); + CHECK(!IsDir("hgjda/cev_/q12")); + CHECK(!IsFile(TestBaseDir)); + CHECK(!IsFile(TestBaseDir / "no_such_thing")); + CHECK(!IsFile("hgjda/cev_/q12")); + CHECK_THROWS(FileSizeFromPath(TestBaseDir) == 0); + CHECK_THROWS(FileSizeFromPath(TestBaseDir / "no_such_file")); + CHECK(!CreateDirectories(TestBaseDir)); + CHECK(CreateDirectories(TestBaseDir / "nested" / "a" / "bit" / "deep")); + CHECK(!CreateDirectories(TestBaseDir / "nested" / "a" / "bit" / "deep")); + CHECK(IsDir(TestBaseDir / "nested" / "a" / "bit" / "deep")); + CHECK(IsDir(TestBaseDir / "nested" / "a" / "bit")); + CHECK(!IsDir(TestBaseDir / "nested" / "a" / "bit" / "deep" / "no")); + CHECK_THROWS(WriteFile(TestBaseDir / "nested" / "a", IoBuffer(20))); + CHECK_NOTHROW(WriteFile(TestBaseDir / "nested" / "a" / "yo", IoBuffer(20))); + CHECK(IsFile(TestBaseDir / "nested" / "a" / "yo")); + CHECK(FileSizeFromPath(TestBaseDir / "nested" / "a" / "yo") == 20); + CHECK(!IsFile(TestBaseDir / "nested" / "a")); + CHECK(DeleteDirectories(TestBaseDir / "nested" / "a" / "bit")); + CHECK(IsFile(TestBaseDir / "nested" / "a" / "yo")); + CHECK(!IsDir(TestBaseDir / "nested" / "a" / "bit")); + CHECK(!DeleteDirectories(TestBaseDir / "nested" / "a" / "bit")); + CHECK(IsDir(TestBaseDir / "nested" / "a")); + CHECK(DeleteDirectories(TestBaseDir / "nested")); + CHECK(!IsFile(TestBaseDir / "nested" / "a" / "yo")); + CHECK(CreateDirectories(TestBaseDir / "nested" / "deeper")); + CHECK_NOTHROW(WriteFile(TestBaseDir / "nested" / "deeper" / "yo", IoBuffer(20))); + CHECK_NOTHROW(RenameDirectory(TestBaseDir / "nested" / "deeper", TestBaseDir / "new_place")); + CHECK(IsFile(TestBaseDir / "new_place" / "yo")); + CHECK(FileSizeFromPath(TestBaseDir / "new_place" / "yo") == 20); + CHECK(IsDir(TestBaseDir / "new_place")); + CHECK(!IsFile(TestBaseDir / "new_place")); + CHECK_THROWS(RenameDirectory(TestBaseDir / "nested" / "deeper", TestBaseDir / "new_place")); + CHECK(!RemoveDir(TestBaseDir / "nested" / "deeper")); + CHECK(RemoveFile(TestBaseDir / "new_place" / "yo")); + CHECK(!IsFile(TestBaseDir / "new_place" / "yo")); + CHECK_THROWS(FileSizeFromPath(TestBaseDir / "new_place" / "yo")); + CHECK(!RemoveFile(TestBaseDir / "new_place" / "yo")); + CHECK_THROWS(RemoveFile(TestBaseDir / "nested")); + CHECK_THROWS(RemoveDir(TestBaseDir)); + CHECK_NOTHROW(WriteFile(TestBaseDir / "yo", IoBuffer(20))); + CHECK_NOTHROW(RenameFile(TestBaseDir / "yo", TestBaseDir / "new_place" / "yo")); + CHECK(!IsFile(TestBaseDir / "yo")); + CHECK(IsFile(TestBaseDir / "new_place" / "yo")); + CHECK(FileSizeFromPath(TestBaseDir / "new_place" / "yo") == 20); + CHECK_THROWS(RemoveDir(TestBaseDir / "new_place" / "yo")); + CHECK(DeleteDirectories(TestBaseDir)); + CHECK(!IsFile(TestBaseDir / "new_place" / "yo")); + CHECK(!IsDir(TestBaseDir)); + CHECK(!IsDir(TestBaseDir / "nested")); + CHECK(CreateDirectories(TestBaseDir / "nested")); + CHECK_NOTHROW(WriteFile(TestBaseDir / "nested" / "readonly", IoBuffer(20))); + CHECK(SetFileReadOnly(TestBaseDir / "nested" / "readonly", true)); + CHECK_THROWS(RemoveFile(TestBaseDir / "nested" / "readonly")); + CHECK_THROWS(CleanDirectory(TestBaseDir, false)); + CHECK(SetFileReadOnly(TestBaseDir / "nested" / "readonly", false)); + CHECK(RemoveFile(TestBaseDir / "nested" / "readonly")); + CHECK(!CleanDirectory(TestBaseDir, false)); + CHECK_NOTHROW(WriteFile(TestBaseDir / "nested" / "readonly", IoBuffer(20))); + CHECK(SetFileReadOnly(TestBaseDir / "nested" / "readonly", true)); + CHECK(!CleanDirectory(TestBaseDir / "nested", true)); + CHECK(!CleanDirectory(TestBaseDir, false)); + CHECK(RemoveDir(TestBaseDir)); +} + TEST_CASE("WriteFile") { std::filesystem::path TempFile = GetRunningExecutablePath().parent_path(); @@ -1784,7 +3251,7 @@ TEST_CASE("WriteFile") CHECK_EQ(memcmp(MagicTest.Data, MagicsReadback.Data[0].Data(), MagicTest.Size), 0); } - std::filesystem::remove(TempFile); + RemoveFile(TempFile); } TEST_CASE("DiskSpaceInfo") @@ -1841,7 +3308,7 @@ TEST_CASE("PathBuilder") TEST_CASE("RotateDirectories") { std::filesystem::path TestBaseDir = GetRunningExecutablePath().parent_path() / ".test"; - CleanDirectory(TestBaseDir); + CleanDirectory(TestBaseDir, false); std::filesystem::path RotateDir = TestBaseDir / "rotate_dir" / "dir_to_rotate"; IoBuffer DummyFileData = IoBufferBuilder::MakeCloneFromMemory("blubb", 5); @@ -1855,16 +3322,16 @@ TEST_CASE("RotateDirectories") const int RotateMax = 10; NewDir(); - CHECK(std::filesystem::exists(RotateDir)); + CHECK(IsDir(RotateDir)); RotateDirectories(RotateDir, RotateMax); - CHECK(!std::filesystem::exists(RotateDir)); - CHECK(std::filesystem::exists(DirWithSuffix(1))); + CHECK(!IsDir(RotateDir)); + CHECK(IsDir(DirWithSuffix(1))); NewDir(); - CHECK(std::filesystem::exists(RotateDir)); + CHECK(IsDir(RotateDir)); RotateDirectories(RotateDir, RotateMax); - CHECK(!std::filesystem::exists(RotateDir)); - CHECK(std::filesystem::exists(DirWithSuffix(1))); - CHECK(std::filesystem::exists(DirWithSuffix(2))); + CHECK(!IsDir(RotateDir)); + CHECK(IsDir(DirWithSuffix(1))); + CHECK(IsDir(DirWithSuffix(2))); for (int i = 0; i < RotateMax; ++i) { @@ -1874,17 +3341,33 @@ TEST_CASE("RotateDirectories") CHECK_EQ(IsError, false); } - CHECK(!std::filesystem::exists(RotateDir)); + CHECK(!IsDir(RotateDir)); for (int i = 0; i < RotateMax; ++i) { - CHECK(std::filesystem::exists(DirWithSuffix(i + 1))); + CHECK(IsDir(DirWithSuffix(i + 1))); } for (int i = RotateMax; i < RotateMax + 5; ++i) { - CHECK(!std::filesystem::exists(DirWithSuffix(RotateMax + i + 1))); + CHECK(!IsDir(DirWithSuffix(RotateMax + i + 1))); + } +} + +TEST_CASE("SharedMemory") +{ + CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false)); + CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, true)); + + { + auto Mem0 = CreateSharedMemory("SharedMemoryTest0", 482, false); + CHECK(Mem0); + strcpy((char*)Mem0->GetData(), "this is the string we are looking for"); + auto Mem1 = OpenSharedMemory("SharedMemoryTest0", 482, false); + CHECK_EQ(std::string((char*)Mem0->GetData()), std::string((char*)Mem1->GetData())); } + + CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false)); } #endif diff --git a/src/zencore/include/zencore/basicfile.h b/src/zencore/include/zencore/basicfile.h new file mode 100644 index 000000000..465499d2b --- /dev/null +++ b/src/zencore/include/zencore/basicfile.h @@ -0,0 +1,195 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <zencore/compositebuffer.h> +#include <zencore/enumflags.h> +#include <zencore/iobuffer.h> + +#include <filesystem> +#include <functional> + +namespace zen { + +class CbObject; + +/** + * Probably the most basic file abstraction in the universe + * + * One thing of note is that there is no notion of a "current file position" + * in this API -- all reads and writes are done from explicit offsets in + * the file. This avoids concurrency issues which can occur otherwise. + * + */ + +class BasicFile +{ +public: + BasicFile() = default; + ~BasicFile(); + + BasicFile(const BasicFile&) = delete; + BasicFile& operator=(const BasicFile&) = delete; + + enum class Mode : uint32_t + { + kRead = 0, // Opens a existing file for read only + kWrite = 1, // Opens (or creates) a file for read and write + kTruncate = 2, // Opens (or creates) a file for read and write and sets the size to zero + kDelete = 3, // Opens (or creates) a file for read and write allowing .DeleteFile file disposition to be set + kTruncateDelete = + 4, // Opens (or creates) a file for read and write and sets the size to zero allowing .DeleteFile file disposition to be set + kModeMask = 0x0007, + kPreventDelete = 0x1000'0000, // Do not open with delete sharing mode (prevent other processes from deleting file while open) + kPreventWrite = 0x2000'0000, // Do not open with write sharing mode (prevent other processes from writing to file while open) + }; + + BasicFile(const std::filesystem::path& FileName, Mode Mode); + BasicFile(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec); + BasicFile(const std::filesystem::path& FileName, Mode Mode, std::function<bool(std::error_code& Ec)>&& RetryCallback); + + void Open(const std::filesystem::path& FileName, Mode Mode); + void Open(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec); + void Open(const std::filesystem::path& FileName, Mode Mode, std::function<bool(std::error_code& Ec)>&& RetryCallback); + void Close(); + void Read(void* Data, uint64_t Size, uint64_t FileOffset); + IoBuffer ReadRange(uint64_t FileOffset, uint64_t ByteCount); + void StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun); + void StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun); + void Write(MemoryView Data, uint64_t FileOffset); + void Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec); + uint64_t Write(const CompositeBuffer& Data, uint64_t FileOffset); + uint64_t Write(const CompositeBuffer& Data, uint64_t FileOffset, std::error_code& Ec); + void Write(const void* Data, uint64_t Size, uint64_t FileOffset); + void Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec); + void Flush(); + [[nodiscard]] uint64_t FileSize() const; + [[nodiscard]] uint64_t FileSize(std::error_code& Ec) const; + void SetFileSize(uint64_t FileSize); + IoBuffer ReadAll(); + void WriteAll(IoBuffer Data, std::error_code& Ec); + void Attach(void* Handle); + void* Detach(); + + inline void* Handle() { return m_FileHandle; } + bool IsOpen() const { return m_FileHandle != nullptr; } + +protected: + void* m_FileHandle = nullptr; // This is either null or valid +private: +}; + +ENUM_CLASS_FLAGS(BasicFile::Mode); + +/** + * Simple abstraction for a temporary file + * + * Works like a regular BasicFile but implements a simple mechanism to allow creating + * a temporary file for writing in a directory which may later be moved atomically + * into the intended location after it has been fully written to. + * + */ + +class TemporaryFile : public BasicFile +{ +public: + TemporaryFile() = default; + ~TemporaryFile(); + + TemporaryFile(const TemporaryFile&) = delete; + TemporaryFile& operator=(const TemporaryFile&) = delete; + + void CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec); + void MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec); + const std::filesystem::path& GetPath() const { return m_TempPath; } + + static void SafeWriteFile(const std::filesystem::path& Path, MemoryView Data); + static void SafeWriteFile(const std::filesystem::path& Path, MemoryView Data, std::error_code& OutEc); + +private: + void Close(); + std::filesystem::path m_TempPath; + + using BasicFile::Open; +}; + +/** Lock file abstraction + + */ + +class LockFile : protected BasicFile +{ +public: + LockFile(); + ~LockFile(); + + void Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec); + void Update(CbObject Payload, std::error_code& Ec); + +private: +}; + +/** Adds a layer of buffered reading to a BasicFile + +This class is not intended for concurrent access, it is not thread safe. + +*/ + +class BasicFileBuffer +{ +public: + BasicFileBuffer(BasicFile& Base, uint64_t BufferSize); + ~BasicFileBuffer(); + + void Read(void* Data, uint64_t Size, uint64_t FileOffset); + MemoryView MakeView(uint64_t Size, uint64_t FileOffset); + + template<typename T> + const T* MakeView(uint64_t FileOffset) + { + MemoryView View = MakeView(sizeof(T), FileOffset); + return reinterpret_cast<const T*>(View.GetData()); + } + +private: + BasicFile& m_Base; + uint8_t* m_Buffer; + const uint64_t m_BufferSize; + uint64_t m_Size; + uint64_t m_BufferStart; + uint64_t m_BufferEnd; +}; + +/** Adds a layer of buffered writing to a BasicFile + +This class is not intended for concurrent access, it is not thread safe. + +*/ + +class BasicFileWriter +{ +public: + BasicFileWriter(BasicFile& Base, uint64_t BufferSize); + ~BasicFileWriter(); + + void Write(const void* Data, uint64_t Size, uint64_t FileOffset); + void Write(const CompositeBuffer& Data, uint64_t FileOffset); + void AddPadding(uint64_t Padding); + uint64_t AlignTo(uint64_t Alignment); + void Flush(); + +private: + BasicFile& m_Base; + uint8_t* m_Buffer; + const uint64_t m_BufferSize; + uint64_t m_BufferStart; + uint64_t m_BufferEnd; +}; + +IoBuffer WriteToTempFile(CompositeBuffer&& Buffer, const std::filesystem::path& Path); + +ZENCORE_API void basicfile_forcelink(); + +} // namespace zen diff --git a/src/zencore/include/zencore/blake3.h b/src/zencore/include/zencore/blake3.h index b31b710a7..f01e45266 100644 --- a/src/zencore/include/zencore/blake3.h +++ b/src/zencore/include/zencore/blake3.h @@ -6,11 +6,12 @@ #include <compare> #include <cstring> -#include <zencore/memory.h> +#include <zencore/memoryview.h> namespace zen { class CompositeBuffer; +class IoBuffer; class StringBuilderBase; /** @@ -23,6 +24,7 @@ struct BLAKE3 inline auto operator<=>(const BLAKE3& Rhs) const = default; static BLAKE3 HashBuffer(const CompositeBuffer& Buffer); + static BLAKE3 HashBuffer(const IoBuffer& Buffer); static BLAKE3 HashMemory(const void* Data, size_t ByteCount); static BLAKE3 FromHexString(const char* String); const char* ToHexString(char* OutString /* 40 characters + NUL terminator */) const; @@ -51,6 +53,7 @@ struct BLAKE3Stream void Reset(); // Begin streaming hash compute (not needed on freshly constructed instance) BLAKE3Stream& Append(const void* data, size_t byteCount); // Append another chunk BLAKE3Stream& Append(MemoryView DataView) { return Append(DataView.GetData(), DataView.GetSize()); } // Append another chunk + BLAKE3Stream& Append(const IoBuffer& Buffer); // Append another chunk BLAKE3 GetHash(); // Obtain final hash. If you wish to reuse the instance call reset() private: diff --git a/src/zencore/include/zencore/callstack.h b/src/zencore/include/zencore/callstack.h new file mode 100644 index 000000000..ca8171435 --- /dev/null +++ b/src/zencore/include/zencore/callstack.h @@ -0,0 +1,49 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <zencore/string.h> + +#include <string> +#include <vector> + +namespace zen { + +struct CallstackFrames +{ + uint32_t FrameCount; + void** Frames; +}; + +CallstackFrames* CreateCallstack(uint32_t FrameCount, void** Frames) noexcept; +CallstackFrames* CloneCallstack(const CallstackFrames* Callstack) noexcept; +void FreeCallstack(CallstackFrames* Callstack) noexcept; + +uint32_t GetCallstack(int FramesToSkip, int FramesToCapture, void* OutAddresses[]); +std::vector<std::string> GetFrameSymbols(uint32_t FrameCount, void** Frames); +inline std::vector<std::string> +GetFrameSymbols(const CallstackFrames* Callstack) +{ + return GetFrameSymbols(Callstack ? Callstack->FrameCount : 0, Callstack ? Callstack->Frames : nullptr); +} + +void FormatCallstack(const CallstackFrames* Callstack, StringBuilderBase& SB, std::string_view Prefix); +std::string CallstackToString(const CallstackFrames* Callstack, std::string_view Prefix = {}); + +typedef void (*CallstackRawCallback)(void* UserData, uint32_t FrameIndex, const char* FrameText); + +constexpr size_t +CallstackRawMemorySize(int FramesToSkip, int FramesToCapture) +{ + return sizeof(CallstackFrames) + sizeof(void*) * (FramesToSkip + FramesToCapture); +} + +void CallstackToStringRaw(const CallstackFrames* Callstack, void* CallbackUserData, CallstackRawCallback Callback); + +CallstackFrames* GetCallstackRaw(void* CaptureBuffer, int FramesToSkip, int FramesToCapture); + +void callstack_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/commandline.h b/src/zencore/include/zencore/commandline.h new file mode 100644 index 000000000..a4ce6b27d --- /dev/null +++ b/src/zencore/include/zencore/commandline.h @@ -0,0 +1,39 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> + +#include <functional> +#include <string_view> + +namespace zen { + +void IterateCommandlineArgs(std::function<void(const std::string_view& Arg)>& ProcessArg); + +template<typename Func> +void +IterateCommaSeparatedValue(std::string_view OptionArgs, Func&& ProcessArg) +{ + while (OptionArgs.size()) + { + const auto CommaPos = OptionArgs.find_first_of(','); + std::string_view OptionArg; + + if (CommaPos == std::string_view::npos) + { + // No comma or final argument + OptionArg = OptionArgs; + OptionArgs = {}; + } + else + { + OptionArg = OptionArgs.substr(0, CommaPos); + OptionArgs = OptionArgs.substr(CommaPos + 1); + } + + ProcessArg(OptionArg); + } +} + +} // namespace zen diff --git a/src/zencore/include/zencore/compactbinary.h b/src/zencore/include/zencore/compactbinary.h index 675e2a8d4..82ca055ab 100644 --- a/src/zencore/include/zencore/compactbinary.h +++ b/src/zencore/include/zencore/compactbinary.h @@ -9,7 +9,7 @@ #include <zencore/intmath.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> -#include <zencore/memory.h> +#include <zencore/memoryview.h> #include <zencore/meta.h> #include <zencore/sharedbuffer.h> #include <zencore/uid.h> @@ -514,7 +514,7 @@ public: ZENCORE_API std::string_view AsString(std::string_view Default = std::string_view()); ZENCORE_API std::u8string_view AsU8String(std::u8string_view Default = std::u8string_view()); - ZENCORE_API void IterateAttachments(std::function<void(CbFieldView)> Visitor) const; + ZENCORE_API void IterateAttachments(const std::function<void(CbFieldView)>& Visitor) const; /** Access the field as an int8. Returns the provided default on error. */ inline int8_t AsInt8(int8_t Default = 0) { return AsInteger<int8_t>(Default); } @@ -796,7 +796,7 @@ public: ZENCORE_API void CopyRangeTo(MutableMemoryView Buffer) const; /** Invoke the visitor for every attachment in the field range. */ - ZENCORE_API void IterateRangeAttachments(std::function<void(CbFieldView)> Visitor) const; + ZENCORE_API void IterateRangeAttachments(const std::function<void(CbFieldView)>& Visitor) const; /** Create a view of every field in the range. */ inline MemoryView GetRangeView() const { return MemoryView(FieldType::GetView().GetData(), FieldsEnd); } @@ -895,6 +895,11 @@ private: ZENCORE_API void CompactBinaryToJson(const CbArrayView& Object, StringBuilderBase& Builder); /** + * Serialize a compact binary array to YAML. + */ +ZENCORE_API void CompactBinaryToYaml(const CbArrayView& Object, StringBuilderBase& Builder); + +/** * Array of CbField that have no names. * * Accessing a field of the array requires iteration. Access by index is not provided because the @@ -960,7 +965,7 @@ public: ZENCORE_API void CopyTo(BinaryWriter& Ar) const; ///** Invoke the visitor for every attachment in the array. */ - inline void IterateAttachments(std::function<void(CbFieldView)> Visitor) const + inline void IterateAttachments(const std::function<void(CbFieldView)>& Visitor) const { CreateViewIterator().IterateRangeAttachments(Visitor); } @@ -974,6 +979,12 @@ public: return Builder; } + StringBuilderBase& ToYaml(StringBuilderBase& Builder) const + { + CompactBinaryToYaml(*this, Builder); + return Builder; + } + private: friend inline CbFieldViewIterator begin(const CbArrayView& Array) { return Array.CreateViewIterator(); } friend inline CbFieldViewIterator end(const CbArrayView&) { return CbFieldViewIterator(); } @@ -985,7 +996,11 @@ private: /** * Serialize a compact binary object to JSON. */ -ZENCORE_API void CompactBinaryToJson(const CbObjectView& Object, StringBuilderBase& Builder); +ZENCORE_API void CompactBinaryToJson(const CbObjectView& Object, StringBuilderBase& Builder, bool AddTypeComment = false); +/** + * Serialize a compact binary object to YAML. + */ +ZENCORE_API void CompactBinaryToYaml(const CbObjectView& Object, StringBuilderBase& Builder); class CbObjectView : protected CbFieldView { @@ -1058,7 +1073,7 @@ public: ZENCORE_API void CopyTo(BinaryWriter& Ar) const; ///** Invoke the visitor for every attachment in the object. */ - inline void IterateAttachments(std::function<void(CbFieldView)> Visitor) const + inline void IterateAttachments(const std::function<void(CbFieldView)>& Visitor) const { CreateViewIterator().IterateRangeAttachments(Visitor); } @@ -1075,6 +1090,12 @@ public: return Builder; } + StringBuilderBase& ToYaml(StringBuilderBase& Builder) const + { + CompactBinaryToYaml(*this, Builder); + return Builder; + } + private: friend inline CbFieldViewIterator begin(const CbObjectView& Object) { return Object.CreateViewIterator(); } friend inline CbFieldViewIterator end(const CbObjectView&) { return CbFieldViewIterator(); } @@ -1497,12 +1518,19 @@ end(CbFieldView&) } /** - * Serialize serialized compact binary blob to jaons. It must be 0 to n fields with including type for each field + * Serialize serialized compact binary blob to JSON. It must be 0 to n fields with including type for each field + */ +ZENCORE_API void CompactBinaryToJson(MemoryView Data, StringBuilderBase& InBuilder, bool AddTypeComment = false); + +/** + * Serialize serialized compact binary blob to YAML. It must be 0 to n fields with including type for each field */ -ZENCORE_API void CompactBinaryToJson(MemoryView Data, StringBuilderBase& InBuilder); +ZENCORE_API void CompactBinaryToYaml(MemoryView Data, StringBuilderBase& InBuilder); ZENCORE_API std::vector<CbFieldView> ReadCompactBinaryStream(MemoryView Data); -void uson_forcelink(); // internal +void uson_forcelink(); // internal +void cbjson_forcelink(); // internal +void cbyaml_forcelink(); // internal } // namespace zen diff --git a/src/zencore/include/zencore/compactbinarybuilder.h b/src/zencore/include/zencore/compactbinarybuilder.h index 9c81cf490..f11717453 100644 --- a/src/zencore/include/zencore/compactbinarybuilder.h +++ b/src/zencore/include/zencore/compactbinarybuilder.h @@ -18,6 +18,8 @@ #include <type_traits> #include <vector> +#include <EASTL/fixed_vector.h> + #include <gsl/gsl-lite.hpp> namespace zen { @@ -367,6 +369,8 @@ public: /** Private flags that are public to work with ENUM_CLASS_FLAGS. */ enum class StateFlags : uint8_t; + typedef eastl::fixed_vector<uint8_t, 2048> CbWriterData_t; + protected: /** Reserve the specified size up front until the format is optimized. */ ZENCORE_API explicit CbWriter(int64_t InitialSize); @@ -409,8 +413,8 @@ private: // provided externally, such as on the stack. That format will store the offsets that require // object or array sizes to be inserted and field types to be removed, and will perform those // operations only when saving to a buffer. - std::vector<uint8_t> Data; - std::vector<WriterState> States; + eastl::fixed_vector<WriterState, 4> States; + CbWriterData_t Data; }; /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -669,7 +673,7 @@ ToDateTime(std::chrono::system_clock::time_point TimePoint) { time_t Time = std::chrono::system_clock::to_time_t(TimePoint); tm UTCTime = *gmtime(&Time); - return DateTime(1900 + UTCTime.tm_year, UTCTime.tm_mon, UTCTime.tm_mday, UTCTime.tm_hour, UTCTime.tm_min, UTCTime.tm_sec); + return DateTime(1900 + UTCTime.tm_year, UTCTime.tm_mon + 1, UTCTime.tm_mday, UTCTime.tm_hour, UTCTime.tm_min, UTCTime.tm_sec); } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/zencore/include/zencore/compactbinaryfmt.h b/src/zencore/include/zencore/compactbinaryfmt.h new file mode 100644 index 000000000..b03683db4 --- /dev/null +++ b/src/zencore/include/zencore/compactbinaryfmt.h @@ -0,0 +1,24 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <string_view> + +template<typename T> +requires DerivedFrom<T, zen::CbObjectView> +struct fmt::formatter<T> : fmt::formatter<std::string_view> +{ + template<typename FormatContext> + auto format(const zen::CbObject& a, FormatContext& ctx) const + { + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToJson(a, ObjStr); + return fmt::formatter<std::string_view>::format(ObjStr.ToView(), ctx); + } +}; diff --git a/src/zencore/include/zencore/compactbinarypackage.h b/src/zencore/include/zencore/compactbinarypackage.h index 16f723edc..9ec12cb0f 100644 --- a/src/zencore/include/zencore/compactbinarypackage.h +++ b/src/zencore/include/zencore/compactbinarypackage.h @@ -12,6 +12,8 @@ #include <span> #include <variant> +#include <EASTL/fixed_vector.h> + #ifdef GetObject # error "windows.h pollution" # undef GetObject @@ -46,26 +48,26 @@ public: inline explicit CbAttachment(const CbObject& InValue) : CbAttachment(InValue, nullptr) {} /** Construct a compact binary attachment. Value is cloned if not owned. Hash must match Value. */ - inline explicit CbAttachment(const CbObject& InValue, const IoHash& Hash) : CbAttachment(InValue, &Hash) {} + inline CbAttachment(const CbObject& InValue, const IoHash& Hash) : CbAttachment(InValue, &Hash) {} /** Construct a raw binary attachment. Value is cloned if not owned. */ - ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue); + ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue) : CbAttachment(CompositeBuffer(InValue)) {} /** Construct a raw binary attachment. Value is cloned if not owned. Hash must match Value. */ - ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue, const IoHash& Hash); - - /** Construct a raw binary attachment. Value is cloned if not owned. */ - ZENCORE_API explicit CbAttachment(const CompositeBuffer& InValue); + ZENCORE_API CbAttachment(const SharedBuffer& InValue, const IoHash& Hash) : CbAttachment(CompositeBuffer(InValue), Hash) {} /** Construct a raw binary attachment. Value is cloned if not owned. */ - ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue); + ZENCORE_API explicit CbAttachment(SharedBuffer&& InValue) : CbAttachment(CompositeBuffer(std::move(InValue))) {} - /** Construct a raw binary attachment. Value is cloned if not owned. */ - ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue, const IoHash& Hash); + /** Construct a raw binary attachment. Value is cloned if not owned. Hash must match Value. */ + ZENCORE_API CbAttachment(SharedBuffer&& InValue, const IoHash& Hash) : CbAttachment(CompositeBuffer(std::move(InValue)), Hash) {} /** Construct a compressed binary attachment. Value is cloned if not owned. */ - ZENCORE_API explicit CbAttachment(const CompressedBuffer& InValue, const IoHash& Hash); - ZENCORE_API explicit CbAttachment(CompressedBuffer&& InValue, const IoHash& Hash); + ZENCORE_API CbAttachment(const CompressedBuffer& InValue, const IoHash& Hash); + ZENCORE_API CbAttachment(CompressedBuffer&& InValue, const IoHash& Hash); + + /** Construct a binary attachment. Value is cloned if not owned. */ + ZENCORE_API CbAttachment(CompositeBuffer&& InValue, const IoHash& Hash); /** Reset this to a null attachment. */ inline void Reset() { *this = CbAttachment(); } @@ -80,10 +82,10 @@ public: ZENCORE_API [[nodiscard]] SharedBuffer AsBinary() const; /** Access the attachment as raw binary. Defaults to a null buffer on error. */ - ZENCORE_API [[nodiscard]] CompositeBuffer AsCompositeBinary() const; + ZENCORE_API [[nodiscard]] const CompositeBuffer& AsCompositeBinary() const; /** Access the attachment as compressed binary. Defaults to a null buffer if the attachment is null. */ - ZENCORE_API [[nodiscard]] CompressedBuffer AsCompressedBinary() const; + ZENCORE_API [[nodiscard]] const CompressedBuffer& AsCompressedBinary() const; /** Access the attachment as compact binary. Defaults to a field iterator with no value on error. */ ZENCORE_API [[nodiscard]] CbObject AsObject() const; @@ -132,6 +134,7 @@ public: private: ZENCORE_API CbAttachment(const CbObject& Value, const IoHash* Hash); + ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue); IoHash Hash; std::variant<std::nullptr_t, CbObject, CompositeBuffer, CompressedBuffer> Value; @@ -173,14 +176,14 @@ public: using AttachmentResolver = std::function<SharedBuffer(const IoHash& Hash)>; /** Construct a null package. */ - CbPackage() = default; + CbPackage(); /** * Construct a package from a root object without gathering attachments. * * @param InObject The root object, which will be cloned unless it is owned. */ - inline explicit CbPackage(CbObject InObject) { SetObject(std::move(InObject)); } + inline explicit CbPackage(CbObject InObject) : CbPackage() { SetObject(std::move(InObject)); } /** * Construct a package from a root object and gather attachments using the resolver. @@ -188,7 +191,10 @@ public: * @param InObject The root object, which will be cloned unless it is owned. * @param InResolver A function that is invoked for every reference and binary reference field. */ - inline explicit CbPackage(CbObject InObject, AttachmentResolver InResolver) { SetObject(std::move(InObject), InResolver); } + inline explicit CbPackage(CbObject InObject, AttachmentResolver InResolver) : CbPackage() + { + SetObject(std::move(InObject), InResolver); + } /** * Construct a package from a root object without gathering attachments. @@ -196,7 +202,7 @@ public: * @param InObject The root object, which will be cloned unless it is owned. * @param InObjectHash The hash of the object, which must match to avoid validation errors. */ - inline explicit CbPackage(CbObject InObject, const IoHash& InObjectHash) { SetObject(std::move(InObject), InObjectHash); } + inline explicit CbPackage(CbObject InObject, const IoHash& InObjectHash) : CbPackage() { SetObject(std::move(InObject), InObjectHash); } /** * Construct a package from a root object and gather attachments using the resolver. @@ -205,7 +211,7 @@ public: * @param InObjectHash The hash of the object, which must match to avoid validation errors. * @param InResolver A function that is invoked for every reference and binary reference field. */ - inline explicit CbPackage(CbObject InObject, const IoHash& InObjectHash, AttachmentResolver InResolver) + inline explicit CbPackage(CbObject InObject, const IoHash& InObjectHash, AttachmentResolver InResolver) : CbPackage() { SetObject(std::move(InObject), InObjectHash, InResolver); } @@ -261,7 +267,10 @@ public: } /** Returns the attachments in this package. */ - inline std::span<const CbAttachment> GetAttachments() const { return Attachments; } + inline std::span<const CbAttachment> GetAttachments() const + { + return std::span<const CbAttachment>(begin(Attachments), end(Attachments)); + } /** * Find an attachment by its hash. @@ -282,6 +291,8 @@ public: void AddAttachments(std::span<const CbAttachment> Attachments); + void ReserveAttachments(size_t Count); + /** * Remove an attachment by hash. * @@ -320,9 +331,9 @@ private: void GatherAttachments(const CbObject& Object, AttachmentResolver Resolver); /** Attachments ordered by their hash. */ - std::vector<CbAttachment> Attachments; - CbObject Object; - IoHash ObjectHash; + eastl::fixed_vector<CbAttachment, 32> Attachments; + CbObject Object; + IoHash ObjectHash; }; namespace legacy { diff --git a/src/zencore/include/zencore/compactbinaryutil.h b/src/zencore/include/zencore/compactbinaryutil.h index 9524d1fc4..d750c6492 100644 --- a/src/zencore/include/zencore/compactbinaryutil.h +++ b/src/zencore/include/zencore/compactbinaryutil.h @@ -6,6 +6,7 @@ #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> namespace zen { @@ -43,4 +44,12 @@ RewriteCbObject(CbObjectView InObj, Invocable<CbObjectWriter&, CbFieldView&> aut return Writer.Save(); } +CbObject ValidateAndReadCompactBinaryObject(const SharedBuffer&& Payload, CbValidateError& OutError); +inline CbObject +ValidateAndReadCompactBinaryObject(const IoBuffer&& Payload, CbValidateError& OutError) +{ + return ValidateAndReadCompactBinaryObject(SharedBuffer(std::move(Payload)), OutError); +} +CbObject ValidateAndReadCompactBinaryObject(const CompressedBuffer&& Payload, CbValidateError& OutError); + } // namespace zen diff --git a/src/zencore/include/zencore/compactbinaryvalue.h b/src/zencore/include/zencore/compactbinaryvalue.h index 0124a8983..aa2d2821d 100644 --- a/src/zencore/include/zencore/compactbinaryvalue.h +++ b/src/zencore/include/zencore/compactbinaryvalue.h @@ -6,7 +6,7 @@ #include <zencore/endian.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> -#include <zencore/memory.h> +#include <zencore/memoryview.h> namespace zen { diff --git a/src/zencore/include/zencore/compositebuffer.h b/src/zencore/include/zencore/compositebuffer.h index cc03dd156..1e1611de9 100644 --- a/src/zencore/include/zencore/compositebuffer.h +++ b/src/zencore/include/zencore/compositebuffer.h @@ -2,6 +2,7 @@ #pragma once +#include <zencore/eastlutil.h> #include <zencore/sharedbuffer.h> #include <zencore/zencore.h> @@ -9,6 +10,8 @@ #include <span> #include <vector> +#include <EASTL/fixed_vector.h> + namespace zen { /** @@ -35,7 +38,7 @@ public: { m_Segments.reserve((GetBufferCount(std::forward<BufferTypes>(Buffers)) + ...)); (AppendBuffers(std::forward<BufferTypes>(Buffers)), ...); - std::erase_if(m_Segments, [](const SharedBuffer& It) { return It.IsNull(); }); + erase_if(m_Segments, [](const SharedBuffer& It) { return It.IsNull(); }); } } @@ -46,7 +49,10 @@ public: [[nodiscard]] ZENCORE_API uint64_t GetSize() const; /** Returns the segments that the buffer is composed from. */ - [[nodiscard]] inline std::span<const SharedBuffer> GetSegments() const { return std::span<const SharedBuffer>{m_Segments}; } + [[nodiscard]] inline std::span<const SharedBuffer> GetSegments() const + { + return std::span<const SharedBuffer>{begin(m_Segments), end(m_Segments)}; + } /** Returns true if the composite buffer is not null. */ [[nodiscard]] inline explicit operator bool() const { return !IsNull(); } @@ -120,29 +126,62 @@ public: static const CompositeBuffer Null; private: + typedef eastl::fixed_vector<SharedBuffer, 4> SharedBufferVector_t; + static inline size_t GetBufferCount(const CompositeBuffer& Buffer) { return Buffer.m_Segments.size(); } inline void AppendBuffers(const CompositeBuffer& Buffer) { + m_Segments.reserve(m_Segments.size() + Buffer.m_Segments.size()); m_Segments.insert(m_Segments.end(), begin(Buffer.m_Segments), end(Buffer.m_Segments)); } - inline void AppendBuffers(CompositeBuffer&& Buffer) - { - // TODO: this operates just like the by-reference version above - m_Segments.insert(m_Segments.end(), begin(Buffer.m_Segments), end(Buffer.m_Segments)); - } + inline void AppendBuffers(CompositeBuffer&& Buffer) { AppendBuffers(std::move(Buffer.m_Segments)); } static inline size_t GetBufferCount(const SharedBuffer&) { return 1; } + static inline size_t GetBufferCount(const IoBuffer&) { return 1; } inline void AppendBuffers(const SharedBuffer& Buffer) { m_Segments.push_back(Buffer); } inline void AppendBuffers(SharedBuffer&& Buffer) { m_Segments.push_back(std::move(Buffer)); } + inline void AppendBuffers(IoBuffer&& Buffer) { m_Segments.push_back(SharedBuffer(std::move(Buffer))); } + + static inline size_t GetBufferCount(std::span<IoBuffer>&& Container) { return Container.size(); } + inline void AppendBuffers(std::span<IoBuffer>&& Container) + { + m_Segments.reserve(m_Segments.size() + Container.size()); + for (IoBuffer& Buffer : Container) + { + m_Segments.emplace_back(SharedBuffer(std::move(Buffer))); + } + } static inline size_t GetBufferCount(std::vector<SharedBuffer>&& Container) { return Container.size(); } + static inline size_t GetBufferCount(std::vector<IoBuffer>&& Container) { return Container.size(); } inline void AppendBuffers(std::vector<SharedBuffer>&& Container) { - m_Segments.insert(m_Segments.end(), begin(Container), end(Container)); + m_Segments.reserve(m_Segments.size() + Container.size()); + for (SharedBuffer& Buffer : Container) + { + m_Segments.emplace_back(std::move(Buffer)); + } + } + inline void AppendBuffers(std::vector<IoBuffer>&& Container) + { + m_Segments.reserve(m_Segments.size() + Container.size()); + for (IoBuffer& Buffer : Container) + { + m_Segments.emplace_back(SharedBuffer(std::move(Buffer))); + } + } + + inline void AppendBuffers(SharedBufferVector_t&& Container) + { + m_Segments.reserve(m_Segments.size() + Container.size()); + for (SharedBuffer& Buffer : Container) + { + m_Segments.emplace_back(std::move(Buffer)); + } } private: - std::vector<SharedBuffer> m_Segments; + SharedBufferVector_t m_Segments; }; void compositebuffer_forcelink(); // internal diff --git a/src/zencore/include/zencore/compress.h b/src/zencore/include/zencore/compress.h index 44431f299..09fa6249d 100644 --- a/src/zencore/include/zencore/compress.h +++ b/src/zencore/include/zencore/compress.h @@ -74,6 +74,12 @@ public: OodleCompressor Compressor = OodleCompressor::Mermaid, OodleCompressionLevel CompressionLevel = OodleCompressionLevel::VeryFast, uint64_t BlockSize = 0); + [[nodiscard]] ZENCORE_API static bool CompressToStream( + const CompositeBuffer& RawData, + std::function<void(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback, + OodleCompressor Compressor = OodleCompressor::Mermaid, + OodleCompressionLevel CompressionLevel = OodleCompressionLevel::VeryFast, + uint64_t BlockSize = 0); /** * Construct from a compressed buffer previously created by Compress(). @@ -94,10 +100,12 @@ public: uint64_t& OutRawSize); [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressedNoValidate(IoBuffer&& CompressedData); [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressedNoValidate(CompositeBuffer&& CompressedData); - [[nodiscard]] ZENCORE_API static bool ValidateCompressedHeader(IoBuffer&& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize); - [[nodiscard]] ZENCORE_API static bool ValidateCompressedHeader(const IoBuffer& CompressedData, - IoHash& OutRawHash, - uint64_t& OutRawSize); + [[nodiscard]] ZENCORE_API static bool ValidateCompressedHeader(IoBuffer&& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize); + [[nodiscard]] ZENCORE_API static bool ValidateCompressedHeader(const IoBuffer& CompressedData, + IoHash& OutRawHash, + uint64_t& OutRawSize); + [[nodiscard]] ZENCORE_API static size_t GetHeaderSizeForNoneEncoder(); + [[nodiscard]] ZENCORE_API static UniqueBuffer CreateHeaderForNoneEncoder(uint64_t RawSize, const BLAKE3& RawHash); /** Reset this to null. */ inline void Reset() { CompressedData.Reset(); } @@ -128,9 +136,41 @@ public: /** Returns the hash of the raw data. Zero on error or if this is null. */ [[nodiscard]] ZENCORE_API IoHash DecodeRawHash() const; + /** + * Returns a block aligned range of a compressed buffer. + * + * This extracts a sub-range from the compressed buffer, if the buffer is block-compressed + * it will align start and end to end up on block boundaries. + * + * The resulting segments in the CompressedBuffer will are allocated and the data is copied + * from the source buffers. + * + * A new header will be allocated and generated. + * + * The RawHash field of the header will be zero as we do not calculate the raw hash for the sub-range + * + * @return A sub-range from the compressed buffer that encompasses RawOffset and RawSize + */ [[nodiscard]] ZENCORE_API CompressedBuffer CopyRange(uint64_t RawOffset, uint64_t RawSize = ~uint64_t(0)) const; /** + * Returns a block aligned range of a compressed buffer. + * + * This extracts a sub-range from the compressed buffer, if the buffer is block-compressed + * it will align start and end to end up on block boundaries. + * + * The resulting segments in the CompressedBuffer will reference the source buffers so it won't + * allocate memory and copy data for the compressed data blocks. + * + * A new header will be allocated and generated. + * + * The RawHash field of the header will be zero as we do not calculate the raw hash for the sub-range + * + * @return A sub-range from the compressed buffer that encompasses RawOffset and RawSize + */ + [[nodiscard]] ZENCORE_API CompressedBuffer GetRange(uint64_t RawOffset, uint64_t RawSize = ~uint64_t(0)) const; + + /** * Returns the compressor and compression level used by this buffer. * * The compressor and compression level may differ from those specified when creating the buffer @@ -162,6 +202,17 @@ public: */ [[nodiscard]] ZENCORE_API CompositeBuffer DecompressToComposite() const; + /** + * Decompress into and call callback for ranges of decompressed data. + * The buffer in the callback will be overwritten when the callback returns. + * + * @return True if the buffer is valid and can be decompressed. + */ + [[nodiscard]] ZENCORE_API bool DecompressToStream( + uint64_t RawOffset, + uint64_t RawSize, + std::function<bool(uint64_t SourceOffset, uint64_t SourceSize, uint64_t Offset, const CompositeBuffer& Range)>&& Callback) const; + /** A null compressed buffer. */ static const CompressedBuffer Null; diff --git a/src/zencore/include/zencore/crypto.h b/src/zencore/include/zencore/crypto.h index 83d416b0f..88d156879 100644 --- a/src/zencore/include/zencore/crypto.h +++ b/src/zencore/include/zencore/crypto.h @@ -3,7 +3,7 @@ #pragma once -#include <zencore/memory.h> +#include <zencore/memoryview.h> #include <zencore/zencore.h> #include <memory> diff --git a/src/zencore/include/zencore/eastlutil.h b/src/zencore/include/zencore/eastlutil.h new file mode 100644 index 000000000..642321dae --- /dev/null +++ b/src/zencore/include/zencore/eastlutil.h @@ -0,0 +1,20 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <algorithm> + +namespace zen { + +size_t +erase_if(auto& _Cont, auto Predicate) +{ + auto _First = _Cont.begin(); + const auto _Last = _Cont.end(); + const auto _Old_size = _Cont.size(); + _First = std::remove_if(_First, _Last, Predicate); + _Cont.erase(_First, _Last); + return _Old_size - _Cont.size(); +} + +} // namespace zen diff --git a/src/zencore/include/zencore/except.h b/src/zencore/include/zencore/except.h index 6810e6ea9..c933adfd8 100644 --- a/src/zencore/include/zencore/except.h +++ b/src/zencore/include/zencore/except.h @@ -63,7 +63,9 @@ MakeErrorCodeFromLastError() noexcept class OptionParseException : public std::runtime_error { public: - inline explicit OptionParseException(const std::string& Message) : std::runtime_error(Message) {} + // inline explicit OptionParseException(const std::string& Message) : std::runtime_error(Message) {} + inline OptionParseException(const std::string& Message, const std::string& Help) : std::runtime_error(Message), m_Help(Help) {} + const std::string m_Help; }; bool IsOOM(const std::system_error& SystemError); diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h index 233941479..36d4d1b68 100644 --- a/src/zencore/include/zencore/filesystem.h +++ b/src/zencore/include/zencore/filesystem.h @@ -4,6 +4,7 @@ #include "zencore.h" +#include <zencore/enumflags.h> #include <zencore/iobuffer.h> #include <zencore/string.h> @@ -12,30 +13,42 @@ namespace zen { -class IoBuffer; class CompositeBuffer; +class IoBuffer; +class Latch; +class WorkerThreadPool; + +/** Delete directory (after deleting any contents) + */ +ZENCORE_API bool DeleteDirectories(const std::filesystem::path& Path); /** Delete directory (after deleting any contents) */ -ZENCORE_API bool DeleteDirectories(const std::filesystem::path& dir); +ZENCORE_API bool DeleteDirectories(const std::filesystem::path& Path, std::error_code& Ec); + +/** Ensure directory exists. + + Will also create any required parent direCleanDirectoryctories + */ +ZENCORE_API bool CreateDirectories(const std::filesystem::path& Path); /** Ensure directory exists. Will also create any required parent directories */ -ZENCORE_API bool CreateDirectories(const std::filesystem::path& dir); +ZENCORE_API bool CreateDirectories(const std::filesystem::path& Path, std::error_code& Ec); /** Ensure directory exists and delete contents (if any) before returning */ -ZENCORE_API bool CleanDirectory(const std::filesystem::path& dir); +ZENCORE_API bool CleanDirectory(const std::filesystem::path& Path, bool ForceRemoveReadOnlyFiles); /** Ensure directory exists and delete contents (if any) before returning */ -ZENCORE_API bool CleanDirectoryExceptDotFiles(const std::filesystem::path& dir); +ZENCORE_API bool CleanDirectory(const std::filesystem::path& Path, bool ForceRemoveReadOnlyFiles, std::error_code& Ec); -/** Map native file handle to a path +/** Ensure directory exists and delete contents (if any) before returning */ -ZENCORE_API std::filesystem::path PathFromHandle(void* NativeHandle); +ZENCORE_API bool CleanDirectoryExceptDotFiles(const std::filesystem::path& Path); /** Map native file handle to a path */ @@ -45,16 +58,91 @@ ZENCORE_API std::filesystem::path PathFromHandle(void* NativeHandle, std::error_ */ ZENCORE_API std::filesystem::path CanonicalPath(std::filesystem::path InPath, std::error_code& Ec); +/** Query file size + */ +ZENCORE_API bool IsFile(const std::filesystem::path& Path); + +/** Query file size + */ +ZENCORE_API bool IsFile(const std::filesystem::path& Path, std::error_code& Ec); + +/** Query file size + */ +ZENCORE_API bool IsDir(const std::filesystem::path& Path); + +/** Query file size + */ +ZENCORE_API bool IsDir(const std::filesystem::path& Path, std::error_code& Ec); + +/** Query file size + */ +ZENCORE_API bool RemoveFile(const std::filesystem::path& Path); + +/** Query file size + */ +ZENCORE_API bool RemoveFile(const std::filesystem::path& Path, std::error_code& Ec); + +/** Query file size + */ +ZENCORE_API bool RemoveDir(const std::filesystem::path& Path); + +/** Query file size + */ +ZENCORE_API bool RemoveDir(const std::filesystem::path& Path, std::error_code& Ec); + +/** Query file size + */ +ZENCORE_API uint64_t FileSizeFromPath(const std::filesystem::path& Path); + +/** Query file size + */ +ZENCORE_API uint64_t FileSizeFromPath(const std::filesystem::path& Path, std::error_code& Ec); + /** Query file size from native file handle */ ZENCORE_API uint64_t FileSizeFromHandle(void* NativeHandle); +/** Query file size from native file handle + */ +ZENCORE_API uint64_t FileSizeFromHandle(void* NativeHandle, std::error_code& Ec); + +/** Get a native time tick of last modification time + */ +ZENCORE_API uint64_t GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec); + +/** Get a native time tick of last modification time + */ +ZENCORE_API uint64_t GetModificationTickFromPath(const std::filesystem::path& Filename); + +ZENCORE_API bool TryGetFileProperties(const std::filesystem::path& Path, + uint64_t& OutSize, + uint64_t& OutModificationTick, + uint32_t& OutNativeModeOrAttributes); + +/** Move a file, if the files are not on the same drive the function will fail + */ +ZENCORE_API void RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); + +/** Move a file, if the files are not on the same drive the function will fail + */ +ZENCORE_API void RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec); + +/** Move a directory, if the files are not on the same drive the function will fail + */ +ZENCORE_API void RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); + +/** Move a directory, if the files are not on the same drive the function will fail + */ +ZENCORE_API void RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec); + ZENCORE_API std::filesystem::path GetRunningExecutablePath(); /** Set the max open file handle count to max allowed for the current process on Linux and MacOS */ ZENCORE_API void MaximizeOpenFileCount(); +ZENCORE_API bool PrepareFileForScatteredWrite(void* FileHandle, uint64_t FinalSize); + struct FileContents { std::vector<IoBuffer> Data; @@ -78,6 +166,18 @@ ZENCORE_API void WriteFile(std::filesystem::path Path, const IoBuffer* const* Da ZENCORE_API void WriteFile(std::filesystem::path Path, IoBuffer Data); ZENCORE_API void WriteFile(std::filesystem::path Path, CompositeBuffer Data); ZENCORE_API bool MoveToFile(std::filesystem::path Path, IoBuffer Data); +ZENCORE_API void ScanFile(void* NativeHandle, + uint64_t Offset, + uint64_t Size, + uint64_t ChunkSize, + std::function<void(const void* Data, size_t Size)>&& ProcessFunc); +ZENCORE_API void WriteFile(void* NativeHandle, + const void* Data, + uint64_t Size, + uint64_t FileOffset, + uint64_t ChunkSize, + std::error_code& Ec); +ZENCORE_API void ReadFile(void* NativeHandle, void* Data, uint64_t Size, uint64_t FileOffset, uint64_t ChunkSize, std::error_code& Ec); struct CopyFileOptions { @@ -85,11 +185,11 @@ struct CopyFileOptions bool MustClone = false; }; -ZENCORE_API bool CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const CopyFileOptions& Options); -ZENCORE_API void CopyFile(std::filesystem::path FromPath, - std::filesystem::path ToPath, - const CopyFileOptions& Options, - std::error_code& OutError); +ZENCORE_API bool CopyFile(const std::filesystem::path& FromPath, const std::filesystem::path& ToPath, const CopyFileOptions& Options); +ZENCORE_API void CopyFile(const std::filesystem::path& FromPath, + const std::filesystem::path& ToPath, + const CopyFileOptions& Options, + std::error_code& OutError); ZENCORE_API void CopyTree(std::filesystem::path FromPath, std::filesystem::path ToPath, const CopyFileOptions& Options); ZENCORE_API bool SupportsBlockRefCounting(std::filesystem::path Path); @@ -190,28 +290,69 @@ class FileSystemTraversal public: struct TreeVisitor { - using path_view = std::basic_string_view<std::filesystem::path::value_type>; - using path_string = std::filesystem::path::string_type; + using path_view = std::basic_string_view<std::filesystem::path::value_type>; - virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t FileSize) = 0; + virtual void VisitFile(const std::filesystem::path& Parent, + const path_view& File, + uint64_t FileSize, + uint32_t NativeModeOrAttributes, + uint64_t NativeModificationTick) = 0; // This should return true if we should recurse into the directory - virtual bool VisitDirectory(const std::filesystem::path& Parent, const path_view& DirectoryName) = 0; + virtual bool VisitDirectory(const std::filesystem::path& Parent, + const path_view& DirectoryName, + uint32_t NativeModeOrAttributes) = 0; }; void TraverseFileSystem(const std::filesystem::path& RootDir, TreeVisitor& Visitor); }; +enum class DirectoryContentFlags : uint8_t +{ + None = 0, + IncludeDirs = 1u << 0, + IncludeFiles = 1u << 1, + Recursive = 1u << 2, + IncludeFileSizes = 1u << 3, + IncludeAttributes = 1u << 4, + IncludeModificationTick = 1u << 5, + IncludeAllEntries = IncludeDirs | IncludeFiles | Recursive +}; + +ENUM_CLASS_FLAGS(DirectoryContentFlags) + struct DirectoryContent { - static const uint8_t IncludeDirsFlag = 1u << 0; - static const uint8_t IncludeFilesFlag = 1u << 1; - static const uint8_t RecursiveFlag = 1u << 2; std::vector<std::filesystem::path> Files; + std::vector<uint64_t> FileSizes; + std::vector<uint32_t> FileAttributes; + std::vector<uint64_t> FileModificationTicks; std::vector<std::filesystem::path> Directories; + std::vector<uint32_t> DirectoryAttributes; +}; + +void GetDirectoryContent(const std::filesystem::path& RootDir, DirectoryContentFlags Flags, DirectoryContent& OutContent); + +struct GetDirectoryContentVisitor +{ +public: + struct DirectoryContent + { + std::vector<std::filesystem::path> FileNames; + std::vector<uint64_t> FileSizes; + std::vector<uint32_t> FileAttributes; + std::vector<uint64_t> FileModificationTicks; + std::vector<std::filesystem::path> DirectoryNames; + std::vector<uint32_t> DirectoryAttributes; + }; + virtual void AsyncVisitDirectory(const std::filesystem::path& RelativeRoot, DirectoryContent&& Content) = 0; }; -void GetDirectoryContent(const std::filesystem::path& RootDir, uint8_t Flags, DirectoryContent& OutContent); +void GetDirectoryContent(const std::filesystem::path& RootDir, + DirectoryContentFlags Flags, + GetDirectoryContentVisitor& Visitor, + WorkerThreadPool& WorkerPool, + Latch& PendingWorkCount); std::string GetEnvVariable(std::string_view VariableName); @@ -220,6 +361,40 @@ std::filesystem::path SearchPathForExecutable(std::string_view ExecutableName); std::error_code RotateFiles(const std::filesystem::path& Filename, std::size_t MaxFiles); std::error_code RotateDirectories(const std::filesystem::path& DirectoryName, std::size_t MaxDirectories); +std::filesystem::path PickDefaultSystemRootDirectory(); + +#if ZEN_PLATFORM_WINDOWS +uint32_t GetFileAttributes(const std::filesystem::path& Filename); +uint32_t GetFileAttributes(const std::filesystem::path& Filename, std::error_code& Ec); +void SetFileAttributes(const std::filesystem::path& Filename, uint32_t Attributes); +void SetFileAttributes(const std::filesystem::path& Filename, uint32_t Attributes, std::error_code& Ec); +#endif // ZEN_PLATFORM_WINDOWS + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +uint32_t GetFileMode(const std::filesystem::path& Filename); +uint32_t GetFileMode(const std::filesystem::path& Filename, std::error_code& Ec); +void SetFileMode(const std::filesystem::path& Filename, uint32_t Mode); +void SetFileMode(const std::filesystem::path& Filename, uint32_t Mode, std::error_code& Ec); +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + +bool IsFileAttributeReadOnly(uint32_t FileAttributes); +bool IsFileModeReadOnly(uint32_t FileMode); +uint32_t MakeFileAttributeReadOnly(uint32_t FileAttributes, bool ReadOnly); +uint32_t MakeFileModeReadOnly(uint32_t FileMode, bool ReadOnly); + +bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly, std::error_code& Ec); +bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly); + +class SharedMemory +{ +public: + virtual ~SharedMemory() {} + virtual void* GetData() = 0; +}; + +std::unique_ptr<SharedMemory> OpenSharedMemory(std::string_view Name, size_t Size, bool SystemGlobal); +std::unique_ptr<SharedMemory> CreateSharedMemory(std::string_view Name, size_t Size, bool SystemGlobal); + ////////////////////////////////////////////////////////////////////////// void filesystem_forcelink(); // internal diff --git a/src/zencore/include/zencore/fmtutils.h b/src/zencore/include/zencore/fmtutils.h index 9c5bdae66..404e570fd 100644 --- a/src/zencore/include/zencore/fmtutils.h +++ b/src/zencore/include/zencore/fmtutils.h @@ -12,6 +12,7 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <fmt/format.h> ZEN_THIRD_PARTY_INCLUDES_END +#include <chrono> #include <string_view> // Custom formatting for some zencore types @@ -20,7 +21,8 @@ template<typename T> requires DerivedFrom<T, zen::StringBuilderBase> struct fmt::formatter<T> : fmt::formatter<std::string_view> { - auto format(const zen::StringBuilderBase& a, format_context& ctx) const + template<typename FormatContext> + auto format(const zen::StringBuilderBase& a, FormatContext& ctx) const { return fmt::formatter<std::string_view>::format(a.ToView(), ctx); } @@ -30,7 +32,8 @@ template<typename T> requires DerivedFrom<T, zen::NiceBase> struct fmt::formatter<T> : fmt::formatter<std::string_view> { - auto format(const zen::NiceBase& a, format_context& ctx) const + template<typename FormatContext> + auto format(const zen::NiceBase& a, FormatContext& ctx) const { return fmt::formatter<std::string_view>::format(std::string_view(a), ctx); } @@ -40,7 +43,7 @@ template<> struct fmt::formatter<zen::IoHash> : formatter<string_view> { template<typename FormatContext> - auto format(const zen::IoHash& Hash, FormatContext& ctx) + auto format(const zen::IoHash& Hash, FormatContext& ctx) const { zen::IoHash::String_t String; Hash.ToHexString(String); @@ -52,7 +55,7 @@ template<> struct fmt::formatter<zen::Oid> : formatter<string_view> { template<typename FormatContext> - auto format(const zen::Oid& Id, FormatContext& ctx) + auto format(const zen::Oid& Id, FormatContext& ctx) const { zen::StringBuilder<32> String; Id.ToString(String); @@ -64,7 +67,7 @@ template<> struct fmt::formatter<zen::Guid> : formatter<string_view> { template<typename FormatContext> - auto format(const zen::Guid& Id, FormatContext& ctx) + auto format(const zen::Guid& Id, FormatContext& ctx) const { zen::StringBuilder<48> String; Id.ToString(String); @@ -76,7 +79,7 @@ template<> struct fmt::formatter<std::filesystem::path> : formatter<string_view> { template<typename FormatContext> - auto format(const std::filesystem::path& Path, FormatContext& ctx) + auto format(const std::filesystem::path& Path, FormatContext& ctx) const { using namespace std::literals; @@ -97,8 +100,22 @@ template<typename T> requires DerivedFrom<T, zen::PathBuilderBase> struct fmt::formatter<T> : fmt::formatter<std::string_view> { - auto format(const zen::PathBuilderBase& a, format_context& ctx) const + template<typename FormatContext> + auto format(const zen::PathBuilderBase& a, FormatContext& ctx) const { return fmt::formatter<std::string_view>::format(a.ToView(), ctx); } }; + +template<> +struct fmt::formatter<std::chrono::system_clock::time_point> : formatter<string_view> +{ + template<typename FormatContext> + auto format(const std::chrono::system_clock::time_point& TimePoint, FormatContext& ctx) const + { + std::time_t Time = std::chrono::system_clock::to_time_t(TimePoint); + char TimeString[std::size("yyyy-mm-ddThh:mm:ss")]; + std::strftime(std::data(TimeString), std::size(TimeString), "%FT%T", std::localtime(&Time)); + return fmt::format_to(ctx.out(), "{}", TimeString); + } +}; diff --git a/src/zencore/include/zencore/guardvalue.h b/src/zencore/include/zencore/guardvalue.h new file mode 100644 index 000000000..5419e882a --- /dev/null +++ b/src/zencore/include/zencore/guardvalue.h @@ -0,0 +1,40 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +namespace zen { + +/** + * exception-safe guard around saving/restoring a value. + * Commonly used to make sure a value is restored + * even if the code early outs in the future. + * Usage: + * TGuardValue<bool> GuardSomeBool(bSomeBool, false); // Sets bSomeBool to false, and restores it in dtor. + */ +template<typename RefType, typename AssignedType = RefType> +struct TGuardValue +{ + [[nodiscard]] TGuardValue(RefType& ReferenceValue, const AssignedType& NewValue) + : RefValue(ReferenceValue) + , OriginalValue(ReferenceValue) + { + RefValue = NewValue; + } + ~TGuardValue() { RefValue = OriginalValue; } + + /** + * Provides read-only access to the original value of the data being tracked by this struct + * + * @return a const reference to the original data value + */ + const AssignedType& GetOriginalValue() const { return OriginalValue; } + + TGuardValue& operator=(const TGuardValue&) = delete; + TGuardValue(const TGuardValue&) = delete; + +private: + RefType& RefValue; + AssignedType OriginalValue; +}; + +} // namespace zen diff --git a/src/zencore/include/zencore/iobuffer.h b/src/zencore/include/zencore/iobuffer.h index b9e503354..63779407e 100644 --- a/src/zencore/include/zencore/iobuffer.h +++ b/src/zencore/include/zencore/iobuffer.h @@ -4,7 +4,7 @@ #include <memory.h> #include <zenbase/refcount.h> -#include <zencore/memory.h> +#include <zencore/memoryview.h> #include <atomic> #include "zencore.h" @@ -99,6 +99,11 @@ public: ZENCORE_API IoBufferCore(size_t SizeBytes, size_t Alignment); ZENCORE_API ~IoBufferCore(); + void* operator new(size_t Size); + void operator delete(void* Ptr); + void* operator new[](size_t Size) = delete; + void operator delete[](void* Ptr) = delete; + // Reference counting inline uint32_t AddRef() const { return AtomicIncrement(const_cast<IoBufferCore*>(this)->m_RefCount); } @@ -243,9 +248,7 @@ protected: kIsMutable = 1 << 1, kIsExtended = 1 << 2, // Is actually a SharedBufferExtendedCore kIsMaterialized = 1 << 3, // Data pointers are valid - kLowLevelAlloc = 1 << 4, // Using direct memory allocation kIsWholeFile = 1 << 5, // References an entire file - kIoBufferAlloc = 1 << 6, // Using IoBuffer allocator kIsOwnedByThis = 1 << 7, // Note that we have some extended flags defined below @@ -338,11 +341,7 @@ public: }; inline IoBuffer() = default; - inline IoBuffer(IoBuffer&& Rhs) noexcept - { - m_Core.Swap(Rhs.m_Core); - Rhs.m_Core = NullBufferCore; - } + inline IoBuffer(IoBuffer&& Rhs) noexcept : m_Core(std::move(Rhs.m_Core)) { Rhs.m_Core = NullBufferCore; } inline IoBuffer(const IoBuffer& Rhs) = default; inline IoBuffer& operator=(const IoBuffer& Rhs) = default; inline IoBuffer& operator =(IoBuffer&& Rhs) noexcept @@ -379,7 +378,7 @@ public: inline explicit operator bool() const { return !m_Core->IsNull(); } inline operator MemoryView() const& { return MemoryView(m_Core->DataPointer(), m_Core->DataBytes()); } - inline void MakeOwned() { return m_Core->MakeOwned(); } + inline void MakeOwned() const { return m_Core->MakeOwned(); } [[nodiscard]] inline bool IsOwned() const { return m_Core->IsOwned(); } [[nodiscard]] inline bool IsWholeFile() const { return m_Core->IsWholeFile(); } [[nodiscard]] void* MutableData() const { return m_Core->MutableDataPointer(); } @@ -442,8 +441,6 @@ public: inline static IoBuffer MakeCloneFromMemory(MemoryView Memory) { return MakeCloneFromMemory(Memory.GetData(), Memory.GetSize()); } }; -IoHash HashBuffer(IoBuffer& Buffer); - void iobuffer_forcelink(); } // namespace zen diff --git a/src/zencore/include/zencore/iohash.h b/src/zencore/include/zencore/iohash.h index 79ed8ea1c..a619b0053 100644 --- a/src/zencore/include/zencore/iohash.h +++ b/src/zencore/include/zencore/iohash.h @@ -1,12 +1,12 @@ // Copyright Epic Games, Inc. All Rights Reserved. -// Copyright Epic Games, Inc. All Rights Reserved. #pragma once #include "zencore.h" #include <zencore/blake3.h> -#include <zencore/memory.h> +#include <zencore/memcmp.h> +#include <zencore/memoryview.h> #include <compare> #include <string_view> @@ -47,21 +47,37 @@ struct IoHash static IoHash HashBuffer(const void* data, size_t byteCount); static IoHash HashBuffer(MemoryView Data) { return HashBuffer(Data.GetData(), Data.GetSize()); } - static IoHash HashBuffer(const CompositeBuffer& Buffer); + static IoHash HashBuffer(const CompositeBuffer& Buffer, std::atomic<uint64_t>* ProcessedBytes = nullptr); + static IoHash HashBuffer(const IoBuffer& Buffer, std::atomic<uint64_t>* ProcessedBytes = nullptr); static IoHash FromHexString(const char* string); static IoHash FromHexString(const std::string_view string); + static bool TryParse(std::string_view Str, IoHash& Hash); const char* ToHexString(char* outString /* 40 characters + NUL terminator */) const; StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const; std::string ToHexString() const; - static const int StringLength = 40; - typedef char String_t[StringLength + 1]; + static constexpr int StringLength = 40; + typedef char String_t[StringLength + 1]; static const IoHash Zero; // Initialized to all zeros + static const IoHash Max; // Initialized to all ones inline auto operator<=>(const IoHash& rhs) const = default; - inline bool operator==(const IoHash& rhs) const { return memcmp(Hash, rhs.Hash, sizeof Hash) == 0; } - inline bool operator<(const IoHash& rhs) const { return memcmp(Hash, rhs.Hash, sizeof Hash) < 0; } + inline bool operator==(const IoHash& rhs) const + { + const uint32_t* LhsHash = reinterpret_cast<const uint32_t*>(Hash); + const uint32_t* RhsHash = reinterpret_cast<const uint32_t*>(rhs.Hash); + return LhsHash[0] == RhsHash[0] && LhsHash[1] == RhsHash[1] && LhsHash[2] == RhsHash[2] && LhsHash[3] == RhsHash[3] && + LhsHash[4] == RhsHash[4]; + } + inline bool operator!=(const IoHash& rhs) const + { + const uint32_t* LhsHash = reinterpret_cast<const uint32_t*>(Hash); + const uint32_t* RhsHash = reinterpret_cast<const uint32_t*>(rhs.Hash); + return LhsHash[0] != RhsHash[0] || LhsHash[1] != RhsHash[1] || LhsHash[2] != RhsHash[2] || LhsHash[3] != RhsHash[3] || + LhsHash[4] != RhsHash[4]; + } + inline bool operator<(const IoHash& rhs) const { return MemCmpFixed<sizeof Hash, std::uint32_t>(Hash, rhs.Hash) < 0; } struct Hasher { @@ -86,6 +102,12 @@ struct IoHashStream return *this; } + IoHashStream& Append(const IoBuffer& Buffer) + { + m_Blake3Stream.Append(Buffer); + return *this; + } + /// Append another chunk IoHashStream& Append(MemoryView Data) { diff --git a/src/zencore/include/zencore/jobqueue.h b/src/zencore/include/zencore/jobqueue.h index 91ca24b34..470ed3fc6 100644 --- a/src/zencore/include/zencore/jobqueue.h +++ b/src/zencore/include/zencore/jobqueue.h @@ -22,9 +22,20 @@ class JobQueue; class JobContext { public: - virtual bool IsCancelled() const = 0; - virtual void ReportMessage(std::string_view Message) = 0; - virtual void ReportProgress(std::string_view CurrentOp, uint32_t CurrentOpPercentComplete) = 0; + virtual bool IsCancelled() const = 0; + virtual void ReportMessage(std::string_view Message) = 0; + // virtual void ReportProgress(std::string_view CurrentOp, uint32_t CurrentOpPercentComplete) = 0; + virtual void ReportProgress(std::string_view CurrentOp, std::string_view Details, ptrdiff_t TotalCount, ptrdiff_t RemainingCount) = 0; +}; + +class JobError : public std::runtime_error +{ +public: + using _Mybase = runtime_error; + + JobError(const std::string& Message, int ReturnCode) : _Mybase(Message), m_ReturnCode(ReturnCode) {} + + const int m_ReturnCode = 0; }; class JobQueue @@ -48,8 +59,11 @@ public: struct State { std::string CurrentOp; - uint32_t CurrentOpPercentComplete = 0; + std::string CurrentOpDetails; + ptrdiff_t TotalCount; + ptrdiff_t RemainingCount; std::vector<std::string> Messages; + std::string AbortReason; }; struct JobInfo @@ -69,6 +83,7 @@ public: std::chrono::system_clock::time_point StartTime; std::chrono::system_clock::time_point EndTime; int WorkerThreadId; + int ReturnCode; }; // Will only respond once when status is Complete or Aborted diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h index 6d44e31df..afbbbd3ee 100644 --- a/src/zencore/include/zencore/logging.h +++ b/src/zencore/include/zencore/logging.h @@ -26,6 +26,7 @@ namespace zen::logging { void InitializeLogging(); void ShutdownLogging(); bool EnableVTMode(); +void FlushLogging(); LoggerRef Default(); void SetDefault(std::string_view NewDefaultLoggerId); @@ -58,7 +59,6 @@ struct LogCategory zen::LoggerRef LoggerRef; }; -void EmitConsoleLogMessage(int LogLevel, std::string_view Message); void EmitConsoleLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args); void EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Message); void EmitLogMessage(LoggerRef& Logger, const SourceLocation& Location, int LogLevel, std::string_view Message); diff --git a/src/zencore/include/zencore/memcmp.h b/src/zencore/include/zencore/memcmp.h new file mode 100644 index 000000000..5608f10f0 --- /dev/null +++ b/src/zencore/include/zencore/memcmp.h @@ -0,0 +1,47 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> +#include <zencore/zencore.h> + +#include <cstddef> + +namespace zen { + +template<std::size_t SIZE> +inline int +MemCmpFixed(const void* a1, const void* a2) +{ + auto const s1 = reinterpret_cast<const unsigned char*>(a1); + auto const s2 = reinterpret_cast<const unsigned char*>(a2); + auto const diff = *s1 - *s2; + return diff ? diff : MemCmpFixed<SIZE - 1>(s1 + 1, s2 + 1); +} + +template<> +inline int +MemCmpFixed<0>(const void*, const void*) +{ + return 0; +} + +template<std::size_t SIZE, typename EQTYPE> +inline int +MemCmpFixed(const void* a1, const void* a2) +{ + ZEN_ASSERT_SLOW((uintptr_t(a1) & (sizeof(EQTYPE) - 1)) == 0); + ZEN_ASSERT_SLOW((uintptr_t(a2) & (sizeof(EQTYPE) - 1)) == 0); + auto const s1 = reinterpret_cast<const EQTYPE*>(a1); + auto const s2 = reinterpret_cast<const EQTYPE*>(a2); + return (*s1 != *s2) ? MemCmpFixed<sizeof(EQTYPE)>(s1, s2) : MemCmpFixed<SIZE - sizeof(EQTYPE), EQTYPE>(s1 + 1, s2 + 1); +} + +template<> +inline int +MemCmpFixed<0, uint32_t>(const void*, const void*) +{ + return 0; +} + +} // namespace zen diff --git a/src/zencore/include/zencore/memory/align.h b/src/zencore/include/zencore/memory/align.h new file mode 100644 index 000000000..9d4101fab --- /dev/null +++ b/src/zencore/include/zencore/memory/align.h @@ -0,0 +1,69 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> + +namespace zen { + +/** + * Aligns a value to the nearest higher multiple of 'Alignment', which must be a power of two. + * + * @param Val The value to align. + * @param Alignment The alignment value, must be a power of two. + * + * @return The value aligned up to the specified alignment. + */ +template<typename T> +constexpr T +Align(T Val, uint64_t Alignment) +{ + return (T)(((uint64_t)Val + Alignment - 1) & ~(Alignment - 1)); +} + +/** + * Aligns a value to the nearest lower multiple of 'Alignment', which must be a power of two. + * + * @param Val The value to align. + * @param Alignment The alignment value, must be a power of two. + * + * @return The value aligned down to the specified alignment. + */ +template<typename T> +constexpr T +AlignDown(T Val, uint64_t Alignment) +{ + return (T)(((uint64_t)Val) & ~(Alignment - 1)); +} + +/** + * Checks if a pointer is aligned to the specified alignment. + * + * @param Val The value to align. + * @param Alignment The alignment value, must be a power of two. + * + * @return true if the pointer is aligned to the specified alignment, false otherwise. + */ +template<typename T> +constexpr bool +IsAligned(T Val, uint64_t Alignment) +{ + return !((uint64_t)Val & (Alignment - 1)); +} + +/** + * Aligns a value to the nearest higher multiple of 'Alignment'. + * + * @param Val The value to align. + * @param Alignment The alignment value, can be any arbitrary value. + * + * @return The value aligned up to the specified alignment. + */ +template<typename T> +constexpr T +AlignArbitrary(T Val, uint64_t Alignment) +{ + return (T)((((uint64_t)Val + Alignment - 1) / Alignment) * Alignment); +} + +} // namespace zen diff --git a/src/zencore/include/zencore/memory/fmalloc.h b/src/zencore/include/zencore/memory/fmalloc.h new file mode 100644 index 000000000..5b476429e --- /dev/null +++ b/src/zencore/include/zencore/memory/fmalloc.h @@ -0,0 +1,105 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <cstddef> + +#include <zenbase/zenbase.h> + +namespace zen { + +enum +{ + DEFAULT_ALIGNMENT = 0 +}; + +/** + * Inherit from FUseSystemMallocForNew if you want your objects to be placed in memory + * alloced by the system malloc routines, bypassing GMalloc. This is e.g. used by FMalloc + * itself. + */ +class FUseSystemMallocForNew +{ +public: + void* operator new(size_t Size); + void operator delete(void* Ptr); + void* operator new[](size_t Size); + void operator delete[](void* Ptr); +}; + +/** Memory allocator abstraction + */ + +class FMalloc : public FUseSystemMallocForNew +{ +public: + /** + * Malloc + */ + virtual void* Malloc(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT) = 0; + + /** + * TryMalloc - like Malloc(), but may return a nullptr result if the allocation + * request cannot be satisfied. + */ + virtual void* TryMalloc(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT); + + /** + * Realloc + */ + virtual void* Realloc(void* Original, size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT) = 0; + + /** + * TryRealloc - like Realloc(), but may return a nullptr if the allocation + * request cannot be satisfied. Note that in this case the memory + * pointed to by Original will still be valid + */ + virtual void* TryRealloc(void* Original, size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT); + + /** + * Free + */ + virtual void Free(void* Original) = 0; + + /** + * Malloc zeroed memory + */ + virtual void* MallocZeroed(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT); + + /** + * TryMallocZeroed - like MallocZeroed(), but may return a nullptr result if the allocation + * request cannot be satisfied. + */ + virtual void* TryMallocZeroed(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT); + + /** + * For some allocators this will return the actual size that should be requested to eliminate + * internal fragmentation. The return value will always be >= Count. This can be used to grow + * and shrink containers to optimal sizes. + * This call is always fast and threadsafe with no locking. + */ + virtual size_t QuantizeSize(size_t Count, uint32_t Alignment); + + /** + * If possible determine the size of the memory allocated at the given address + * + * @param Original - Pointer to memory we are checking the size of + * @param SizeOut - If possible, this value is set to the size of the passed in pointer + * @return true if succeeded + */ + virtual bool GetAllocationSize(void* Original, size_t& SizeOut); + + /** + * Notifies the malloc implementation that initialization of all allocators in GMalloc is complete, so it's safe to initialize any extra + * features that require "regular" allocations + */ + virtual void OnMallocInitialized(); + + virtual void Trim(bool bTrimThreadCaches); + + virtual void OutOfMemory(size_t Size, uint32_t Alignment); +}; + +extern FMalloc* GMalloc; /* Memory allocator */ + +} // namespace zen diff --git a/src/zencore/include/zencore/memory/llm.h b/src/zencore/include/zencore/memory/llm.h new file mode 100644 index 000000000..ea7f68cc6 --- /dev/null +++ b/src/zencore/include/zencore/memory/llm.h @@ -0,0 +1,47 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> +#include <zencore/memory/tagtrace.h> + +namespace zen { + +// clang-format off +#define LLM_ENUM_GENERIC_TAGS(macro) \ + macro(Untagged, "Untagged", -1) \ + macro(ProgramSize, "ProgramSize", -1) \ + macro(Metrics, "Metrics", -1) \ + macro(Logging, "Logging", -1) \ + macro(IoBuffer, "IoBuffer", -1) \ + macro(IoBufferMemory, "IoMemory", ELLMTag::IoBuffer) \ + macro(IoBufferCore, "IoCore", ELLMTag::IoBuffer) + +// clang-format on + +enum class ELLMTag : uint8_t +{ +#define LLM_ENUM(Enum, Str, Parent) Enum, + LLM_ENUM_GENERIC_TAGS(LLM_ENUM) +#undef LLM_ENUM + + GenericTagCount +}; + +struct FLLMTag +{ +public: + FLLMTag(const char* TagName); + FLLMTag(const char* TagName, const FLLMTag& ParentTag); + + inline int32_t GetTag() const { return m_Tag; } + inline int32_t GetParentTag() const { return m_ParentTag; } + +private: + int32_t m_Tag = -1; + int32_t m_ParentTag = -1; + + void AssignAndAnnounceNewTag(const char* TagName); +}; + +} // namespace zen diff --git a/src/zencore/include/zencore/memory/mallocansi.h b/src/zencore/include/zencore/memory/mallocansi.h new file mode 100644 index 000000000..510695c8c --- /dev/null +++ b/src/zencore/include/zencore/memory/mallocansi.h @@ -0,0 +1,31 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "fmalloc.h" +#include "memory.h" + +namespace zen { + +void* AnsiMalloc(size_t Size, uint32_t Alignment); +void* AnsiRealloc(void* Ptr, size_t NewSize, uint32_t Alignment); +void AnsiFree(void* Ptr); + +// +// ANSI C memory allocator. +// + +class FMallocAnsi final : public FMalloc +{ +public: + FMallocAnsi(); + + virtual void* Malloc(size_t Size, uint32_t Alignment) override; + virtual void* TryMalloc(size_t Size, uint32_t Alignment) override; + virtual void* Realloc(void* Ptr, size_t NewSize, uint32_t Alignment) override; + virtual void* TryRealloc(void* Ptr, size_t NewSize, uint32_t Alignment) override; + virtual void Free(void* Ptr) override; + virtual bool GetAllocationSize(void* Original, size_t& SizeOut) override; +}; + +} // namespace zen diff --git a/src/zencore/include/zencore/memory/mallocmimalloc.h b/src/zencore/include/zencore/memory/mallocmimalloc.h new file mode 100644 index 000000000..759eeb4a6 --- /dev/null +++ b/src/zencore/include/zencore/memory/mallocmimalloc.h @@ -0,0 +1,36 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/memory/fmalloc.h> + +#if ZEN_USE_MIMALLOC +# define ZEN_MIMALLOC_ENABLED 1 +#endif + +#if !defined(ZEN_MIMALLOC_ENABLED) +# define ZEN_MIMALLOC_ENABLED 0 +#endif + +#if ZEN_MIMALLOC_ENABLED + +namespace zen { + +class FMallocMimalloc final : public FMalloc +{ +public: + FMallocMimalloc(); + virtual void* Malloc(size_t Size, uint32_t Alignment) override; + virtual void* TryMalloc(size_t Size, uint32_t Alignment) override; + virtual void* Realloc(void* Ptr, size_t NewSize, uint32_t Alignment) override; + virtual void* TryRealloc(void* Ptr, size_t NewSize, uint32_t Alignment) override; + virtual void Free(void* Ptr) override; + virtual void* MallocZeroed(size_t Count, uint32_t Alignment) override; + virtual void* TryMallocZeroed(size_t Count, uint32_t Alignment) override; + virtual bool GetAllocationSize(void* Original, size_t& SizeOut) override; + virtual void Trim(bool bTrimThreadCaches) override; +}; + +} // namespace zen + +#endif diff --git a/src/zencore/include/zencore/memory/mallocrpmalloc.h b/src/zencore/include/zencore/memory/mallocrpmalloc.h new file mode 100644 index 000000000..be2627b2d --- /dev/null +++ b/src/zencore/include/zencore/memory/mallocrpmalloc.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/memory/fmalloc.h> + +#if ZEN_USE_RPMALLOC +# define ZEN_RPMALLOC_ENABLED 1 +#endif + +#if !defined(ZEN_RPMALLOC_ENABLED) +# define ZEN_RPMALLOC_ENABLED 0 +#endif + +#if ZEN_RPMALLOC_ENABLED + +namespace zen { + +class FMallocRpmalloc final : public FMalloc +{ +public: + FMallocRpmalloc(); + ~FMallocRpmalloc(); + virtual void* Malloc(size_t Size, uint32_t Alignment) override; + virtual void* TryMalloc(size_t Size, uint32_t Alignment) override; + virtual void* Realloc(void* Ptr, size_t NewSize, uint32_t Alignment) override; + virtual void* TryRealloc(void* Ptr, size_t NewSize, uint32_t Alignment) override; + virtual void Free(void* Ptr) override; + virtual void* MallocZeroed(size_t Count, uint32_t Alignment) override; + virtual void* TryMallocZeroed(size_t Count, uint32_t Alignment) override; + virtual bool GetAllocationSize(void* Original, size_t& SizeOut) override; + virtual void Trim(bool bTrimThreadCaches) override; +}; + +} // namespace zen + +#endif diff --git a/src/zencore/include/zencore/memory/mallocstomp.h b/src/zencore/include/zencore/memory/mallocstomp.h new file mode 100644 index 000000000..5d83868bb --- /dev/null +++ b/src/zencore/include/zencore/memory/mallocstomp.h @@ -0,0 +1,100 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> + +#if ZEN_PLATFORM_WINDOWS +# define ZEN_WITH_MALLOC_STOMP 1 +#endif + +#ifndef ZEN_WITH_MALLOC_STOMP +# define ZEN_WITH_MALLOC_STOMP 0 +#endif + +/** + * Stomp memory allocator support should be enabled in Core.Build.cs. + * Run-time validation should be enabled using '-stompmalloc' command line argument. + */ + +#if ZEN_WITH_MALLOC_STOMP + +# include <zencore/memory/fmalloc.h> +# include <zencore/thread.h> + +namespace zen { + +/** + * Stomp memory allocator. It helps find the following errors: + * - Read or writes off the end of an allocation. + * - Read or writes off the beginning of an allocation. + * - Read or writes after freeing an allocation. + */ +class FMallocStomp final : public FMalloc +{ + struct FAllocationData; + + const size_t PageSize; + + /** If it is set to true, instead of focusing on overruns the allocator will focus on underruns. */ + const bool bUseUnderrunMode; + RwLock Lock; + + uintptr_t VirtualAddressCursor = 0; + size_t VirtualAddressMax = 0; + static constexpr size_t VirtualAddressBlockSize = 1 * 1024 * 1024 * 1024; // 1 GB blocks + +public: + // FMalloc interface. + explicit FMallocStomp(const bool InUseUnderrunMode = false); + + /** + * Allocates a block of a given number of bytes of memory with the required alignment. + * In the process it allocates as many pages as necessary plus one that will be protected + * making it unaccessible and causing an exception. The actual allocation will be pushed + * to the end of the last valid unprotected page. To deal with underrun errors a sentinel + * is added right before the allocation in page which is checked on free. + * + * @param Size Size in bytes of the memory block to allocate. + * @param Alignment Alignment in bytes of the memory block to allocate. + * @return A pointer to the beginning of the memory block. + */ + virtual void* Malloc(size_t Size, uint32_t Alignment) override; + + virtual void* TryMalloc(size_t Size, uint32_t Alignment) override; + + /** + * Changes the size of the memory block pointed to by OldPtr. + * The function may move the memory block to a new location. + * + * @param OldPtr Pointer to a memory block previously allocated with Malloc. + * @param NewSize New size in bytes for the memory block. + * @param Alignment Alignment in bytes for the reallocation. + * @return A pointer to the reallocated memory block, which may be either the same as ptr or a new location. + */ + virtual void* Realloc(void* InPtr, size_t NewSize, uint32_t Alignment) override; + + virtual void* TryRealloc(void* InPtr, size_t NewSize, uint32_t Alignment) override; + + /** + * Frees a memory allocation and verifies the sentinel in the process. + * + * @param InPtr Pointer of the data to free. + */ + virtual void Free(void* InPtr) override; + + /** + * If possible determine the size of the memory allocated at the given address. + * This will included all the pages that were allocated so it will be far more + * than what's set on the FAllocationData. + * + * @param Original - Pointer to memory we are checking the size of + * @param SizeOut - If possible, this value is set to the size of the passed in pointer + * @return true if succeeded + */ + virtual bool GetAllocationSize(void* Original, size_t& SizeOut) override; +}; + +} // namespace zen + +#endif // WITH_MALLOC_STOMP diff --git a/src/zencore/include/zencore/memory/memory.h b/src/zencore/include/zencore/memory/memory.h new file mode 100644 index 000000000..2fc20def6 --- /dev/null +++ b/src/zencore/include/zencore/memory/memory.h @@ -0,0 +1,78 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <stdlib.h> +#include <zencore/memory/fmalloc.h> + +#define UE_ALLOCATION_FUNCTION(...) + +namespace zen { + +/** + * Corresponds to UE-side FMemory implementation + */ + +class Memory +{ +public: + static void Initialize(); + + // + // C style memory allocation stubs that fall back to C runtime + // + UE_ALLOCATION_FUNCTION(1) static void* SystemMalloc(size_t Size); + static void SystemFree(void* Ptr); + + // + // C style memory allocation stubs. + // + + static inline void* Alloc(size_t Size, size_t Alignment = sizeof(void*)) { return Malloc(Size, uint32_t(Alignment)); } + + UE_ALLOCATION_FUNCTION(1, 2) static inline void* Malloc(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT); + UE_ALLOCATION_FUNCTION(2, 3) static inline void* Realloc(void* Original, size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT); + static inline void Free(void* Original); + static inline size_t GetAllocSize(void* Original); + + UE_ALLOCATION_FUNCTION(1, 2) static inline void* MallocZeroed(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT); + +private: + static void GCreateMalloc(); +}; + +inline void* +Memory::Malloc(size_t Count, uint32_t Alignment) +{ + return GMalloc->TryMalloc(Count, Alignment); +} + +inline void* +Memory::Realloc(void* Original, size_t Count, uint32_t Alignment) +{ + return GMalloc->TryRealloc(Original, Count, Alignment); +} + +inline void +Memory::Free(void* Original) +{ + if (Original) + { + GMalloc->Free(Original); + } +} + +inline size_t +Memory::GetAllocSize(void* Original) +{ + size_t Size = 0; + return GMalloc->GetAllocationSize(Original, Size) ? Size : 0; +} + +inline void* +Memory::MallocZeroed(size_t Count, uint32_t Alignment) +{ + return GMalloc->TryMallocZeroed(Count, Alignment); +} + +} // namespace zen diff --git a/src/zencore/include/zencore/memory/memorytrace.h b/src/zencore/include/zencore/memory/memorytrace.h new file mode 100644 index 000000000..6be7adb89 --- /dev/null +++ b/src/zencore/include/zencore/memory/memorytrace.h @@ -0,0 +1,251 @@ +// Copyright Epic Games, Inc. All Rights Reserved. +#pragma once + +#include <zencore/enumflags.h> +#include <zencore/trace.h> + +#if !defined(UE_MEMORY_TRACE_AVAILABLE) +# define UE_MEMORY_TRACE_AVAILABLE 0 +#endif + +#if !defined(UE_MEMORY_TRACE_LATE_INIT) +# define UE_MEMORY_TRACE_LATE_INIT 0 +#endif + +#if !defined(PLATFORM_USES_FIXED_GMalloc_CLASS) +# define PLATFORM_USES_FIXED_GMalloc_CLASS 0 +#endif + +#if !defined(UE_MEMORY_TRACE_ENABLED) && UE_TRACE_ENABLED +# if UE_MEMORY_TRACE_AVAILABLE +# define UE_MEMORY_TRACE_ENABLED ZEN_WITH_MEMTRACK +# endif +#endif + +#if !defined(UE_MEMORY_TRACE_ENABLED) +# define UE_MEMORY_TRACE_ENABLED 0 +#endif + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// +typedef uint32_t HeapId; + +//////////////////////////////////////////////////////////////////////////////// +enum EMemoryTraceRootHeap : uint8_t +{ + SystemMemory, // RAM + VideoMemory, // VRAM + EndHardcoded = VideoMemory, + EndReserved = 15 +}; + +//////////////////////////////////////////////////////////////////////////////// +// These values are traced. Do not modify existing values in order to maintain +// compatibility. +enum class EMemoryTraceHeapFlags : uint16_t +{ + None = 0, + Root = 1 << 0, + NeverFrees = 1 << 1, // The heap doesn't free (e.g. linear allocator) +}; +ENUM_CLASS_FLAGS(EMemoryTraceHeapFlags); + +//////////////////////////////////////////////////////////////////////////////// +// These values are traced. Do not modify existing values in order to maintain +// compatibility. +enum class EMemoryTraceHeapAllocationFlags : uint8_t +{ + None = 0, + Heap = 1 << 0, // Is a heap, can be used to unmark alloc as heap. + Swap = 2 << 0, // Is a swap page +}; +ENUM_CLASS_FLAGS(EMemoryTraceHeapAllocationFlags); + +//////////////////////////////////////////////////////////////////////////////// +enum class EMemoryTraceSwapOperation : uint8 +{ + PageOut = 0, // Paged out to swap + PageIn = 1, // Read from swap via page fault + FreeInSwap = 2, // Freed while being paged out in swap +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Internal options for early initialization of memory tracing systems. Exposed +// here due to visibility in platform implementations. +enum class EMemoryTraceInit : uint8 +{ + Disabled = 0, + AllocEvents = 1 << 0, + Callstacks = 1 << 1, + Tags = 1 << 2, + Full = AllocEvents | Callstacks | Tags, + Light = AllocEvents | Tags, +}; + +ENUM_CLASS_FLAGS(EMemoryTraceInit); + +//////////////////////////////////////////////////////////////////////////////// +#if UE_MEMORY_TRACE_ENABLED + +# define UE_MEMORY_TRACE(x) x + +UE_TRACE_CHANNEL_EXTERN(MemAllocChannel); + +//////////////////////////////////////////////////////////////////////////////// +class FMalloc* MemoryTrace_Create(class FMalloc* InMalloc, const TraceOptions& Options); +void MemoryTrace_Initialize(); +void MemoryTrace_Shutdown(); + +/** + * Register a new heap specification (name). Use the returned value when marking heaps. + * @param ParentId Heap id of parent heap. + * @param Name Descriptive name of the heap. + * @param Flags Properties of this heap. See \ref EMemoryTraceHeapFlags + * @return Heap id to use when allocating memory + */ +HeapId MemoryTrace_HeapSpec(HeapId ParentId, const char16_t* Name, EMemoryTraceHeapFlags Flags = EMemoryTraceHeapFlags::None); + +/** + * Register a new root heap specification (name). Use the returned value as parent to other heaps. + * @param Name Descriptive name of the root heap. + * @param Flags Properties of the this root heap. See \ref EMemoryTraceHeapFlags + * @return Heap id to use when allocating memory + */ +HeapId MemoryTrace_RootHeapSpec(const char16_t* Name, EMemoryTraceHeapFlags Flags = EMemoryTraceHeapFlags::None); + +/** + * Mark a traced allocation as being a heap. + * @param Address Address of the allocation + * @param Heap Heap id, see /ref MemoryTrace_HeapSpec. If no specific heap spec has been created the correct root heap needs to be given. + * @param Flags Additional properties of the heap allocation. Note that \ref EMemoryTraceHeapAllocationFlags::Heap is implicit. + * @param ExternalCallstackId CallstackId to use, if 0 will use current callstack id. + */ +void MemoryTrace_MarkAllocAsHeap(uint64 Address, + HeapId Heap, + EMemoryTraceHeapAllocationFlags Flags = EMemoryTraceHeapAllocationFlags::None, + uint32 ExternalCallstackId = 0); + +/** + * Unmark an allocation as a heap. When an allocation that has previously been used as a heap is reused as a regular + * allocation. + * @param Address Address of the allocation + * @param Heap Heap id + * @param ExternalCallstackId CallstackId to use, if 0 will use current callstack id. + */ +void MemoryTrace_UnmarkAllocAsHeap(uint64 Address, HeapId Heap, uint32 ExternalCallstackId = 0); + +/** + * Trace an allocation event. + * @param Address Address of allocation + * @param Size Size of allocation + * @param Alignment Alignment of the allocation + * @param RootHeap Which root heap this belongs to (system memory, video memory etc) + * @param ExternalCallstackId CallstackId to use, if 0 will use current callstack id. + */ +void MemoryTrace_Alloc(uint64 Address, + uint64 Size, + uint32 Alignment, + HeapId RootHeap = EMemoryTraceRootHeap::SystemMemory, + uint32 ExternalCallstackId = 0); + +/** + * Trace a free event. + * @param Address Address of the allocation being freed + * @param RootHeap Which root heap this belongs to (system memory, video memory etc) + * @param ExternalCallstackId CallstackId to use, if 0 will use current callstack id. + */ +void MemoryTrace_Free(uint64 Address, HeapId RootHeap = EMemoryTraceRootHeap::SystemMemory, uint32 ExternalCallstackId = 0); + +/** + * Trace a free related to a reallocation event. + * @param Address Address of the allocation being freed + * @param RootHeap Which root heap this belongs to (system memory, video memory etc) + * @param ExternalCallstackId CallstackId to use, if 0 will use current callstack id. + */ +void MemoryTrace_ReallocFree(uint64 Address, HeapId RootHeap = EMemoryTraceRootHeap::SystemMemory, uint32 ExternalCallstackId = 0); + +/** Trace an allocation related to a reallocation event. + * @param Address Address of allocation + * @param NewSize Size of allocation + * @param Alignment Alignment of the allocation + * @param RootHeap Which root heap this belongs to (system memory, video memory etc) + * @param ExternalCallstackId CallstackId to use, if 0 will use current callstack id. + */ +void MemoryTrace_ReallocAlloc(uint64 Address, + uint64 NewSize, + uint32 Alignment, + HeapId RootHeap = EMemoryTraceRootHeap::SystemMemory, + uint32 ExternalCallstackId = 0); + +/** Trace a swap operation. Only available for system memory root heap (EMemoryTraceRootHeap::SystemMemory). + * @param PageAddress Page address for operation, in case of PageIn can be address of the page fault (not aligned to page boundary). + * @param SwapOperation Which swap operation is happening to the address. + * @param CompressedSize Compressed size of the page for page out operation. + * @param CallstackId CallstackId to use, if 0 to ignore (will not use current callstack id). + */ +void MemoryTrace_SwapOp(uint64 PageAddress, EMemoryTraceSwapOperation SwapOperation, uint32 CompressedSize = 0, uint32 CallstackId = 0); + +//////////////////////////////////////////////////////////////////////////////// +#else // UE_MEMORY_TRACE_ENABLED + +# define UE_MEMORY_TRACE(x) +inline HeapId +MemoryTrace_RootHeapSpec(const char16_t* /*Name*/, EMemoryTraceHeapFlags /* Flags = EMemoryTraceHeapFlags::None */) +{ + return ~0u; +}; +inline HeapId +MemoryTrace_HeapSpec(HeapId /*ParentId*/, const char16_t* /*Name*/, EMemoryTraceHeapFlags /* Flags = EMemoryTraceHeapFlags::None */) +{ + return ~0u; +} +inline void +MemoryTrace_MarkAllocAsHeap(uint64_t /*Address*/, HeapId /*Heap*/) +{ +} +inline void +MemoryTrace_UnmarkAllocAsHeap(uint64_t /*Address*/, HeapId /*Heap*/) +{ +} +inline void +MemoryTrace_Alloc(uint64_t /*Address*/, + uint64_t /*Size*/, + uint32_t /*Alignment*/, + HeapId RootHeap = EMemoryTraceRootHeap::SystemMemory, + uint32_t ExternalCallstackId = 0) +{ + ZEN_UNUSED(RootHeap, ExternalCallstackId); +} +inline void +MemoryTrace_Free(uint64_t /*Address*/, HeapId RootHeap = EMemoryTraceRootHeap::SystemMemory, uint32_t ExternalCallstackId = 0) +{ + ZEN_UNUSED(RootHeap, ExternalCallstackId); +} +inline void +MemoryTrace_ReallocFree(uint64_t /*Address*/, HeapId RootHeap = EMemoryTraceRootHeap::SystemMemory, uint32_t ExternalCallstackId = 0) +{ + ZEN_UNUSED(RootHeap, ExternalCallstackId); +} +inline void +MemoryTrace_ReallocAlloc(uint64_t /*Address*/, + uint64_t /*NewSize*/, + uint32_t /*Alignment*/, + HeapId RootHeap = EMemoryTraceRootHeap::SystemMemory, + uint32_t ExternalCallstackId = 0) +{ + ZEN_UNUSED(RootHeap, ExternalCallstackId); +} +inline void +MemoryTrace_SwapOp(uint64_t /*PageAddress*/, + EMemoryTraceSwapOperation /*SwapOperation*/, + uint32_t CompressedSize = 0, + uint32_t CallstackId = 0) +{ + ZEN_UNUSED(CompressedSize, CallstackId); +} + +#endif // UE_MEMORY_TRACE_ENABLED + +} // namespace zen diff --git a/src/zencore/include/zencore/memory/newdelete.h b/src/zencore/include/zencore/memory/newdelete.h new file mode 100644 index 000000000..2ec92b91b --- /dev/null +++ b/src/zencore/include/zencore/memory/newdelete.h @@ -0,0 +1,168 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> +#include <new> + +#if defined(_MSC_VER) +# if (_MSC_VER >= 1900) && !defined(__EDG__) +# define ZEN_RESTRICT __declspec(allocator) __declspec(restrict) +# else +# define ZEN_RESTRICT __declspec(restrict) +# endif +#else +# define ZEN_RESTRICT +#endif + +////////////////////////////////////////////////////////////////////////// + +[[nodiscard]] ZEN_RESTRICT void* zen_new(size_t size); +[[nodiscard]] ZEN_RESTRICT void* zen_new_aligned(size_t size, size_t alignment); +[[nodiscard]] ZEN_RESTRICT void* zen_new_nothrow(size_t size) noexcept; +[[nodiscard]] ZEN_RESTRICT void* zen_new_aligned_nothrow(size_t size, size_t alignment) noexcept; + +void zen_free(void* p) noexcept; +void zen_free_size(void* p, size_t size) noexcept; +void zen_free_size_aligned(void* p, size_t size, size_t alignment) noexcept; +void zen_free_aligned(void* p, size_t alignment) noexcept; + +////////////////////////////////////////////////////////////////////////// + +#if defined(_MSC_VER) && defined(_Ret_notnull_) && defined(_Post_writable_byte_size_) +# define zen_decl_new(n) [[nodiscard]] _VCRT_ALLOCATOR _Ret_notnull_ _Post_writable_byte_size_(n) +# define zen_decl_new_nothrow(n) [[nodiscard]] _VCRT_ALLOCATOR _Ret_maybenull_ _Success_(return != NULL) _Post_writable_byte_size_(n) +#else +# define zen_decl_new(n) [[nodiscard]] +# define zen_decl_new_nothrow(n) [[nodiscard]] +#endif + +void +operator delete(void* p) noexcept +{ + zen_free(p); +} + +void +operator delete[](void* p) noexcept +{ + zen_free(p); +} + +void +operator delete(void* p, const std::nothrow_t&) noexcept +{ + zen_free(p); +} + +void +operator delete[](void* p, const std::nothrow_t&) noexcept +{ + zen_free(p); +} + +zen_decl_new(n) void* +operator new(std::size_t n) noexcept(false) +{ + return zen_new(n); +} + +zen_decl_new(n) void* +operator new[](std::size_t n) noexcept(false) +{ + return zen_new(n); +} + +zen_decl_new_nothrow(n) void* +operator new(std::size_t n, const std::nothrow_t& tag) noexcept +{ + (void)(tag); + return zen_new_nothrow(n); +} + +zen_decl_new_nothrow(n) void* +operator new[](std::size_t n, const std::nothrow_t& tag) noexcept +{ + (void)(tag); + return zen_new_nothrow(n); +} + +#if (__cplusplus >= 201402L || _MSC_VER >= 1916) +void +operator delete(void* p, std::size_t n) noexcept +{ + zen_free_size(p, n); +}; +void +operator delete[](void* p, std::size_t n) noexcept +{ + zen_free_size(p, n); +}; +#endif + +#if (__cplusplus > 201402L || defined(__cpp_aligned_new)) +void +operator delete(void* p, std::align_val_t al) noexcept +{ + zen_free_aligned(p, static_cast<size_t>(al)); +} +void +operator delete[](void* p, std::align_val_t al) noexcept +{ + zen_free_aligned(p, static_cast<size_t>(al)); +} +void +operator delete(void* p, std::size_t n, std::align_val_t al) noexcept +{ + zen_free_size_aligned(p, n, static_cast<size_t>(al)); +}; +void +operator delete[](void* p, std::size_t n, std::align_val_t al) noexcept +{ + zen_free_size_aligned(p, n, static_cast<size_t>(al)); +}; +void +operator delete(void* p, std::align_val_t al, const std::nothrow_t&) noexcept +{ + zen_free_aligned(p, static_cast<size_t>(al)); +} +void +operator delete[](void* p, std::align_val_t al, const std::nothrow_t&) noexcept +{ + zen_free_aligned(p, static_cast<size_t>(al)); +} + +void* +operator new(std::size_t n, std::align_val_t al) noexcept(false) +{ + return zen_new_aligned(n, static_cast<size_t>(al)); +} +void* +operator new[](std::size_t n, std::align_val_t al) noexcept(false) +{ + return zen_new_aligned(n, static_cast<size_t>(al)); +} +void* +operator new(std::size_t n, std::align_val_t al, const std::nothrow_t&) noexcept +{ + return zen_new_aligned_nothrow(n, static_cast<size_t>(al)); +} +void* +operator new[](std::size_t n, std::align_val_t al, const std::nothrow_t&) noexcept +{ + return zen_new_aligned_nothrow(n, static_cast<size_t>(al)); +} +#endif + +// EASTL operator new + +void* operator new[](size_t size, const char* pName, int flags, unsigned debugFlags, const char* file, int line); + +void* operator new[](size_t size, + size_t alignment, + size_t alignmentOffset, + const char* pName, + int flags, + unsigned debugFlags, + const char* file, + int line);
\ No newline at end of file diff --git a/src/zencore/include/zencore/memory/tagtrace.h b/src/zencore/include/zencore/memory/tagtrace.h new file mode 100644 index 000000000..8b5fc0e67 --- /dev/null +++ b/src/zencore/include/zencore/memory/tagtrace.h @@ -0,0 +1,94 @@ +// Copyright Epic Games, Inc. All Rights Reserved. +#pragma once + +#include <zenbase/zenbase.h> +#include <zencore/trace.h> + +//////////////////////////////////////////////////////////////////////////////// + +namespace zen { + +enum class ELLMTag : uint8_t; +struct FLLMTag; + +int32_t MemoryTrace_AnnounceCustomTag(int32_t Tag, int32_t ParentTag, const char* Display); +int32_t MemoryTrace_GetActiveTag(); + +inline constexpr int32_t TRACE_TAG = 257; + +} // namespace zen + +//////////////////////////////////////////////////////////////////////////////// +#if !defined(UE_MEMORY_TAGS_TRACE_ENABLED) +# define UE_MEMORY_TAGS_TRACE_ENABLED 1 +#endif + +#if UE_MEMORY_TAGS_TRACE_ENABLED && UE_TRACE_ENABLED + +namespace zen { +//////////////////////////////////////////////////////////////////////////////// + +/** + * Used to associate any allocation within this scope to a given tag. + * + * We need to be able to convert the three types of inputs to LLM scopes: + * - ELLMTag, an uint8 with fixed categories. There are three sub ranges + Generic tags, platform and project tags. + * - FName, free form string, for example a specific asset. + * - TagData, an opaque pointer from LLM. + * + */ +class FMemScope +{ +public: + FMemScope(); // Used with SetTagAndActivate + FMemScope(int32_t InTag, bool bShouldActivate = true); + FMemScope(FLLMTag InTag, bool bShouldActivate = true); + FMemScope(ELLMTag InTag, bool bShouldActivate = true); + ~FMemScope(); + +private: + void ActivateScope(int32_t InTag); + UE::Trace::Private::FScopedLogScope Inner; + int32_t PrevTag; +}; + +/** + * A scope that activates in case no existing scope is active. + */ +template<typename TagType> +class FDefaultMemScope : public FMemScope +{ +public: + FDefaultMemScope(TagType InTag) : FMemScope(InTag, MemoryTrace_GetActiveTag() == 0) {} +}; + +/** + * Used order to keep the tag for memory that is being reallocated. + */ +class FMemScopePtr +{ +public: + FMemScopePtr(uint64_t InPtr); + ~FMemScopePtr(); + +private: + UE::Trace::Private::FScopedLogScope Inner; +}; + +//////////////////////////////////////////////////////////////////////////////// +# define ZEN_MEMSCOPE(InTag) FMemScope PREPROCESSOR_JOIN(MemScope, __LINE__)(InTag); +# define ZEN_MEMSCOPE_PTR(InPtr) FMemScopePtr PREPROCESSOR_JOIN(MemPtrScope, __LINE__)((uint64)InPtr); +# define ZEN_MEMSCOPE_DEFAULT(InTag) FDefaultMemScope PREPROCESSOR_JOIN(MemScope, __LINE__)(InTag); +# define ZEN_MEMSCOPE_UNINITIALIZED(Line) FMemScope PREPROCESSOR_JOIN(MemScope, Line); + +#else // UE_MEMORY_TAGS_TRACE_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +# define ZEN_MEMSCOPE(...) +# define ZEN_MEMSCOPE_PTR(...) +# define ZEN_MEMSCOPE_DEFAULT(...) +# define ZEN_MEMSCOPE_UNINITIALIZED(...) + +#endif // UE_MEMORY_TAGS_TRACE_ENABLED +} diff --git a/src/zencore/include/zencore/memory.h b/src/zencore/include/zencore/memoryview.h index a1f48555e..92f5aea3f 100644 --- a/src/zencore/include/zencore/memory.h +++ b/src/zencore/include/zencore/memoryview.h @@ -22,58 +22,10 @@ template<typename T> concept ContiguousRange = true; #endif -struct MemoryView; - -class MemoryArena -{ -public: - ZENCORE_API MemoryArena(); - ZENCORE_API ~MemoryArena(); - - ZENCORE_API void* Alloc(size_t Size, size_t Alignment); - ZENCORE_API void Free(void* Ptr); - -private: -}; - -class Memory -{ -public: - ZENCORE_API static void* Alloc(size_t Size, size_t Alignment = sizeof(void*)); - ZENCORE_API static void Free(void* Ptr); -}; - -/** Allocator which claims fixed-size blocks from the underlying allocator. - - There is no way to free individual memory blocks. - - \note This is not thread-safe, you will need to provide synchronization yourself -*/ - -class ChunkingLinearAllocator -{ -public: - explicit ChunkingLinearAllocator(uint64_t ChunkSize, uint64_t ChunkAlignment = sizeof(std::max_align_t)); - ~ChunkingLinearAllocator(); - - ZENCORE_API void Reset(); - - ZENCORE_API void* Alloc(size_t Size, size_t Alignment = sizeof(void*)); - inline void Free(void* Ptr) { ZEN_UNUSED(Ptr); /* no-op */ } - - ChunkingLinearAllocator(const ChunkingLinearAllocator&) = delete; - ChunkingLinearAllocator& operator=(const ChunkingLinearAllocator&) = delete; - -private: - uint8_t* m_ChunkCursor = nullptr; - uint64_t m_ChunkBytesRemain = 0; - const uint64_t m_ChunkSize = 0; - const uint64_t m_ChunkAlignment = 0; - std::vector<void*> m_ChunkList; -}; - ////////////////////////////////////////////////////////////////////////// +struct MemoryView; + struct MutableMemoryView { MutableMemoryView() = default; @@ -141,7 +93,10 @@ struct MutableMemoryView inline void MidInline(uint64_t InOffset, uint64_t InSize = ~uint64_t(0)) { RightChopInline(InOffset); - LeftInline(InSize); + if (InSize != ~uint64_t(0)) + { + LeftInline(InSize); + } } /** Returns the middle part of the view by taking up to the given number of bytes from the given position. */ @@ -271,7 +226,10 @@ struct MemoryView inline void MidInline(uint64_t InOffset, uint64_t InSize = ~uint64_t(0)) { RightChopInline(InOffset); - LeftInline(InSize); + if (InSize != ~uint64_t(0)) + { + LeftInline(InSize); + } } /** Returns the middle part of the view by taking up to the given number of bytes from the given position. */ @@ -350,7 +308,7 @@ template<ContiguousRange R> MakeMemoryView(const R& Container) { std::span Span = Container; - return MemoryView(Span.data(), Span.size() * sizeof(typename decltype(Span)::element_type)); + return MemoryView(Span.data(), Span.size_bytes()); } /** Make a non-owning const view starting at Data and ending at DataEnd. */ @@ -385,7 +343,7 @@ template<ContiguousRange R> MakeMutableMemoryView(R& Container) { std::span Span = Container; - return MutableMemoryView(Span.data(), Span.size() * sizeof(typename decltype(Span)::element_type)); + return MutableMemoryView(Span.data(), Span.size_bytes()); } /** Make a non-owning mutable view starting at Data and ending at DataEnd. */ diff --git a/src/zencore/include/zencore/parallelwork.h b/src/zencore/include/zencore/parallelwork.h new file mode 100644 index 000000000..05146d644 --- /dev/null +++ b/src/zencore/include/zencore/parallelwork.h @@ -0,0 +1,80 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/scopeguard.h> +#include <zencore/thread.h> +#include <zencore/workthreadpool.h> + +#include <atomic> + +namespace zen { + +class ParallelWork +{ +public: + ParallelWork(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, WorkerThreadPool::EMode Mode); + + ~ParallelWork(); + + typedef std::function<void(std::atomic<bool>& AbortFlag)> WorkCallback; + typedef std::function<void(std::exception_ptr Ex, std::atomic<bool>& AbortFlag)> ExceptionCallback; + typedef std::function<void(bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork)> UpdateCallback; + + void ScheduleWork(WorkerThreadPool& WorkerPool, WorkCallback&& Work, ExceptionCallback&& OnError = {}) + { + m_PendingWork.AddCount(1); + try + { + WorkerPool.ScheduleWork( + [this, Work = std::move(Work), OnError = OnError ? std::move(OnError) : DefaultErrorFunction()] { + auto _ = MakeGuard([this]() { m_PendingWork.CountDown(); }); + try + { + while (m_PauseFlag && !m_AbortFlag) + { + Sleep(2000); + } + Work(m_AbortFlag); + } + catch (...) + { + OnError(std::current_exception(), m_AbortFlag); + } + }, + m_Mode); + } + catch (const std::exception&) + { + m_PendingWork.CountDown(); + throw; + } + } + + void Abort() { m_AbortFlag = true; } + + bool IsAborted() const { return m_AbortFlag.load(); } + + void Wait(int32_t UpdateIntervalMS, UpdateCallback&& UpdateCallback); + + void Wait(); + + Latch& PendingWork() { return m_PendingWork; } + +private: + ExceptionCallback DefaultErrorFunction(); + void RethrowErrors(); + + std::atomic<bool>& m_AbortFlag; + std::atomic<bool>& m_PauseFlag; + const WorkerThreadPool::EMode m_Mode; + bool m_DispatchComplete = false; + Latch m_PendingWork; + + RwLock m_ErrorLock; + std::vector<std::exception_ptr> m_Errors; +}; + +void parallellwork_forcelink(); + +} // namespace zen diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h index d90a32301..04b79a1e0 100644 --- a/src/zencore/include/zencore/process.h +++ b/src/zencore/include/zencore/process.h @@ -22,14 +22,18 @@ public: ZENCORE_API ~ProcessHandle(); ZENCORE_API void Initialize(int Pid); + ZENCORE_API void Initialize(int Pid, std::error_code& OutEc); ZENCORE_API void Initialize(void* ProcessHandle); /// Initialize with an existing handle - takes ownership of the handle ZENCORE_API [[nodiscard]] bool IsRunning() const; ZENCORE_API [[nodiscard]] bool IsValid() const; ZENCORE_API bool Wait(int TimeoutMs = -1); + ZENCORE_API bool Wait(int TimeoutMs, std::error_code& OutEc); ZENCORE_API int WaitExitCode(); - ZENCORE_API void Terminate(int ExitCode); + ZENCORE_API int GetExitCode(); + ZENCORE_API bool Terminate(int ExitCode); ZENCORE_API void Reset(); [[nodiscard]] inline int Pid() const { return m_Pid; } + [[nodiscard]] inline void* Handle() const { return m_ProcessHandle; } private: void* m_ProcessHandle = nullptr; @@ -48,6 +52,7 @@ struct CreateProcOptions Flag_NewConsole = 1 << 0, Flag_Elevated = 1 << 1, Flag_Unelevated = 1 << 2, + Flag_NoConsole = 1 << 3, }; const std::filesystem::path* WorkingDirectory = nullptr; @@ -90,9 +95,17 @@ private: }; ZENCORE_API bool IsProcessRunning(int pid); +ZENCORE_API bool IsProcessRunning(int pid, std::error_code& OutEc); ZENCORE_API int GetCurrentProcessId(); int GetProcessId(CreateProcResult ProcId); +std::filesystem::path GetProcessExecutablePath(int Pid, std::error_code& OutEc); +std::error_code FindProcess(const std::filesystem::path& ExecutableImage, ProcessHandle& OutHandle, bool IncludeSelf = true); + +#if ZEN_PLATFORM_LINUX +void IgnoreChildSignals(); +#endif + void process_forcelink(); // internal } // namespace zen diff --git a/src/zencore/include/zencore/scopeguard.h b/src/zencore/include/zencore/scopeguard.h index d04c8ed9c..3fd0564f6 100644 --- a/src/zencore/include/zencore/scopeguard.h +++ b/src/zencore/include/zencore/scopeguard.h @@ -21,7 +21,11 @@ public: { m_guardFunc(); } - catch (std::exception& Ex) + catch (const AssertException& Ex) + { + ZEN_ERROR("Assert exception in scope guard: {}", Ex.FullDescription()); + } + catch (const std::exception& Ex) { ZEN_ERROR("scope guard threw exception: '{}'", Ex.what()); } diff --git a/src/zencore/include/zencore/sentryintegration.h b/src/zencore/include/zencore/sentryintegration.h new file mode 100644 index 000000000..faf1238b7 --- /dev/null +++ b/src/zencore/include/zencore/sentryintegration.h @@ -0,0 +1,60 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/intmath.h> +#include <zencore/zencore.h> + +#if !defined(ZEN_USE_SENTRY) +# define ZEN_USE_SENTRY 1 +#endif + +#if ZEN_USE_SENTRY + +# include <memory> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <spdlog/logger.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace sentry { + +struct SentryAssertImpl; + +} // namespace sentry + +namespace zen { + +class SentryIntegration +{ +public: + SentryIntegration(); + ~SentryIntegration(); + + struct Config + { + std::string DatabasePath; + std::string AttachmentsPath; + std::string Dsn; + std::string Environment; + bool AllowPII = false; + bool Debug = false; + }; + + void Initialize(const Config& Conf, const std::string& CommandLine); + void LogStartupInformation(); + static void ClearCaches(); + +private: + int m_SentryErrorCode = 0; + bool m_IsInitialized = false; + bool m_AllowPII = false; + std::unique_ptr<sentry::SentryAssertImpl> m_SentryAssert; + std::string m_SentryUserName; + std::string m_SentryHostName; + std::string m_SentryId; + std::shared_ptr<spdlog::logger> m_SentryLogger; +}; + +} // namespace zen +#endif diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h index 618bd2937..7df5109cb 100644 --- a/src/zencore/include/zencore/sharedbuffer.h +++ b/src/zencore/include/zencore/sharedbuffer.h @@ -6,7 +6,7 @@ #include <zenbase/refcount.h> #include <zencore/iobuffer.h> -#include <zencore/memory.h> +#include <zencore/memoryview.h> #include <memory.h> @@ -114,6 +114,17 @@ public: [[nodiscard]] bool IsOwned() const { return !m_Buffer || m_Buffer->IsOwned(); } [[nodiscard]] inline bool IsNull() const { return !m_Buffer; } inline void Reset() { m_Buffer = nullptr; } + inline bool GetFileReference(IoBufferFileReference& OutRef) const + { + if (const IoBufferExtendedCore* Core = m_Buffer->ExtendedCore()) + { + return Core->GetFileReference(OutRef); + } + else + { + return false; + } + } [[nodiscard]] MemoryView GetView() const { diff --git a/src/zencore/include/zencore/stream.h b/src/zencore/include/zencore/stream.h index a9d35ef1b..77e763518 100644 --- a/src/zencore/include/zencore/stream.h +++ b/src/zencore/include/zencore/stream.h @@ -4,7 +4,7 @@ #include "zencore.h" -#include <zencore/memory.h> +#include <zencore/memoryview.h> #include <zencore/thread.h> #include <vector> @@ -34,8 +34,8 @@ public: inline const uint8_t* Data() const { return m_Buffer.data(); } inline const uint8_t* GetData() const { return m_Buffer.data(); } - inline uint64_t Size() const { return m_Buffer.size(); } - inline uint64_t GetSize() const { return m_Buffer.size(); } + inline uint64_t Size() const { return m_Offset; } + inline uint64_t GetSize() const { return m_Offset; } void Reset(); inline MemoryView GetView() const { return MemoryView(m_Buffer.data(), m_Offset); } @@ -70,14 +70,25 @@ public: inline void Read(void* DataPtr, size_t ByteCount) { + ZEN_ASSERT(m_Offset + ByteCount <= m_BufferSize); memcpy(DataPtr, m_BufferBase + m_Offset, ByteCount); m_Offset += ByteCount; } + inline MemoryView GetView(size_t ByteCount) const + { + ZEN_ASSERT(m_Offset + ByteCount <= m_BufferSize); + return MemoryView((const void*)(m_BufferBase + m_Offset), (const void*)(m_BufferBase + m_Offset + ByteCount)); + } inline uint64_t Size() const { return m_BufferSize; } inline uint64_t GetSize() const { return Size(); } inline uint64_t CurrentOffset() const { return m_Offset; } - inline void Skip(size_t ByteCount) { m_Offset += ByteCount; }; + inline uint64_t Remaining() const { return m_BufferSize - m_Offset; } + inline void Skip(size_t ByteCount) + { + ZEN_ASSERT(m_Offset + ByteCount <= m_BufferSize); + m_Offset += ByteCount; + }; protected: const uint8_t* m_BufferBase; diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h index da8deb425..ef7cd36ea 100644 --- a/src/zencore/include/zencore/string.h +++ b/src/zencore/include/zencore/string.h @@ -51,6 +51,36 @@ StringLength(const wchar_t* str) return wcslen(str); } +inline bool +StringCompare(const char16_t* s1, const char16_t* s2) +{ + char16_t c1, c2; + + while ((c1 = *s1) == (c2 = *s2)) + { + if (c1 == 0) + { + return 0; + } + + ++s1; + ++s2; + } + return uint16_t(c1) - uint16_t(c2); +} + +inline bool +StringEquals(const char16_t* s1, const char16_t* s2) +{ + return StringCompare(s1, s2) == 0; +} + +inline size_t +StringLength(const char16_t* str) +{ + return std::char_traits<char16_t>::length(str); +} + ////////////////////////////////////////////////////////////////////////// // File name helpers // @@ -492,6 +522,9 @@ public: ////////////////////////////////////////////////////////////////////////// +bool IsValidUtf8(const std::string_view& str); +std::string_view::const_iterator FindFirstInvalidUtf8Byte(const std::string_view& str); + void Utf8ToWide(const char8_t* str, WideStringBuilderBase& out); void Utf8ToWide(const std::u8string_view& wstr, WideStringBuilderBase& out); void Utf8ToWide(const std::string_view& wstr, WideStringBuilderBase& out); @@ -646,6 +679,9 @@ ParseHexNumber(const std::string_view HexString, UnsignedIntegral auto& OutValue return ParseHexNumber(HexString.data(), ExpectedCharacterCount, (uint8_t*)&OutValue); } +void UrlDecode(std::string_view InUrl, StringBuilderBase& OutUrl); +std::string UrlDecode(std::string_view InUrl); + ////////////////////////////////////////////////////////////////////////// // Format numbers for humans // diff --git a/src/zencore/include/zencore/testutils.h b/src/zencore/include/zencore/testutils.h index f5526945a..45fde4eda 100644 --- a/src/zencore/include/zencore/testutils.h +++ b/src/zencore/include/zencore/testutils.h @@ -18,7 +18,7 @@ public: ScopedTemporaryDirectory(); ~ScopedTemporaryDirectory(); - std::filesystem::path& Path() { return m_RootPath; } + const std::filesystem::path& Path() const { return m_RootPath; } private: std::filesystem::path m_RootPath; @@ -33,6 +33,16 @@ struct ScopedCurrentDirectoryChange }; IoBuffer CreateRandomBlob(uint64_t Size); +IoBuffer CreateSemiRandomBlob(uint64_t Size); + +struct FalseType +{ + static const bool Enabled = false; +}; +struct TrueType +{ + static const bool Enabled = true; +}; } // namespace zen diff --git a/src/zencore/include/zencore/thread.h b/src/zencore/include/zencore/thread.h index 2d0ef7396..d9fb5c023 100644 --- a/src/zencore/include/zencore/thread.h +++ b/src/zencore/include/zencore/thread.h @@ -35,6 +35,16 @@ public: struct SharedLockScope { + SharedLockScope(const SharedLockScope& Rhs) = delete; + SharedLockScope(SharedLockScope&& Rhs) : m_Lock(Rhs.m_Lock) { Rhs.m_Lock = nullptr; } + SharedLockScope& operator=(SharedLockScope&& Rhs) + { + ReleaseNow(); + m_Lock = Rhs.m_Lock; + Rhs.m_Lock = nullptr; + return *this; + } + SharedLockScope& operator=(const SharedLockScope& Rhs) = delete; SharedLockScope(RwLock& Lock) : m_Lock(&Lock) { Lock.AcquireShared(); } ~SharedLockScope() { ReleaseNow(); } @@ -128,8 +138,8 @@ public: ZENCORE_API explicit NamedEvent(std::string_view EventName); ZENCORE_API ~NamedEvent(); ZENCORE_API void Close(); - ZENCORE_API void Set(); - ZENCORE_API bool Wait(int TimeoutMs = -1); + ZENCORE_API std::error_code Set(); + ZENCORE_API bool Wait(int TimeoutMs = -1); NamedEvent(NamedEvent&& Rhs) noexcept : m_EventHandle(Rhs.m_EventHandle) { Rhs.m_EventHandle = nullptr; } @@ -173,6 +183,7 @@ public: void CountDown() { std::ptrdiff_t Old = Counter.fetch_sub(1); + ZEN_ASSERT(Old > 0); if (Old == 1) { Complete.Set(); @@ -187,8 +198,7 @@ public: void AddCount(std::ptrdiff_t Count) { std::atomic_ptrdiff_t Old = Counter.fetch_add(Count); - ZEN_UNUSED(Old); - ZEN_ASSERT_SLOW(Old > 0); + ZEN_ASSERT(Old > 0); } bool Wait(int TimeoutMs = -1) @@ -206,6 +216,23 @@ private: Event Complete; }; +inline void +SetAtomicMax(std::atomic_uint64_t& Max, uint64_t Value) +{ + while (true) + { + uint64_t CurrentMax = Max.load(); + if (Value <= CurrentMax) + { + return; + } + if (Max.compare_exchange_strong(CurrentMax, Value)) + { + return; + } + } +} + ZENCORE_API int GetCurrentThreadId(); ZENCORE_API void Sleep(int ms); diff --git a/src/zencore/include/zencore/timer.h b/src/zencore/include/zencore/timer.h index e4ddc3505..767dc4314 100644 --- a/src/zencore/include/zencore/timer.h +++ b/src/zencore/include/zencore/timer.h @@ -21,6 +21,10 @@ ZENCORE_API uint64_t GetHifreqTimerFrequency(); ZENCORE_API double GetHifreqTimerToSeconds(); ZENCORE_API uint64_t GetHifreqTimerFrequencySafe(); // May be used during static init +// Query time since process was spawned (returns time in ms) + +uint64_t GetTimeSinceProcessStart(); + class Stopwatch { public: diff --git a/src/zencore/include/zencore/trace.h b/src/zencore/include/zencore/trace.h index 89e4b76bf..99a565151 100644 --- a/src/zencore/include/zencore/trace.h +++ b/src/zencore/include/zencore/trace.h @@ -19,19 +19,25 @@ ZEN_THIRD_PARTY_INCLUDES_END #define ZEN_TRACE_CPU(x) TRACE_CPU_SCOPE(x) #define ZEN_TRACE_CPU_FLUSH(x) TRACE_CPU_SCOPE(x, trace::CpuScopeFlags::CpuFlush) -enum class TraceType +namespace zen { + +struct TraceOptions { - File, - Network, - None + std::string Host; + std::string File; + std::string Channels; }; void TraceInit(std::string_view ProgramName); void TraceShutdown(); bool IsTracing(); -void TraceStart(std::string_view ProgramName, const char* HostOrPath, TraceType Type); bool TraceStop(); +bool GetTraceOptionsFromCommandline(TraceOptions& OutOptions); +void TraceConfigure(const TraceOptions& Options); + +} + #else #define ZEN_TRACE_CPU(x) diff --git a/src/zencore/include/zencore/uid.h b/src/zencore/include/zencore/uid.h index 3abec9d16..0c1079444 100644 --- a/src/zencore/include/zencore/uid.h +++ b/src/zencore/include/zencore/uid.h @@ -2,6 +2,8 @@ #pragma once +#include <zencore/memcmp.h> +#include <zencore/memoryview.h> #include <zencore/zencore.h> #include <compare> @@ -53,19 +55,23 @@ class StringBuilderBase; struct Oid { - static const int StringLength = 24; - typedef char String_t[StringLength + 1]; + static constexpr int StringLength = 24; + typedef char String_t[StringLength + 1]; static void Initialize(); [[nodiscard]] static Oid NewOid(); const Oid& Generate(); [[nodiscard]] static Oid FromHexString(const std::string_view String); + [[nodiscard]] static Oid TryFromHexString(const std::string_view String, const Oid& Default = Oid::Zero); + static bool TryParse(std::string_view Str, Oid& Id); StringBuilderBase& ToString(StringBuilderBase& OutString) const; void ToString(char OutString[StringLength]) const; + std::string ToString() const; [[nodiscard]] static Oid FromMemory(const void* Ptr); - auto operator<=>(const Oid& rhs) const = default; + auto operator<=>(const Oid& rhs) const = default; + inline bool operator<(const Oid& rhs) const { return MemCmpFixed<sizeof OidBits, std::uint32_t>(OidBits, rhs.OidBits) < 0; } [[nodiscard]] inline explicit operator bool() const { return *this != Zero; } static const Oid Zero; // Min (can be used to signify a "null" value, or for open range queries) diff --git a/src/zencore/include/zencore/varint.h b/src/zencore/include/zencore/varint.h index c0d63d814..ae9aceed6 100644 --- a/src/zencore/include/zencore/varint.h +++ b/src/zencore/include/zencore/varint.h @@ -181,17 +181,17 @@ ReadVarInt(const void* InData, uint32_t& OutByteCount) } /** - * Write a variable-length unsigned integer. + * Write a pre-measured variable-length unsigned integer. * * @param InValue An unsigned integer to encode. + * @param ByteCount The number of bytes the integer requires to be encoded * @param OutData A buffer of at least 5 bytes to write the output to. * @return The number of bytes used in the output. */ -inline uint32_t -WriteVarUInt(uint32_t InValue, void* OutData) +inline void +WriteMeasuredVarUInt(uint32_t InValue, uint32_t ByteCount, void* OutData) { - const uint32_t ByteCount = MeasureVarUInt(InValue); - uint8_t* OutBytes = static_cast<uint8_t*>(OutData) + ByteCount - 1; + uint8_t* OutBytes = static_cast<uint8_t*>(OutData) + ByteCount - 1; switch (ByteCount - 1) { case 4: @@ -214,21 +214,35 @@ WriteVarUInt(uint32_t InValue, void* OutData) break; } *OutBytes = uint8_t(0xff << (9 - ByteCount)) | uint8_t(InValue); - return ByteCount; } /** * Write a variable-length unsigned integer. * * @param InValue An unsigned integer to encode. - * @param OutData A buffer of at least 9 bytes to write the output to. + * @param OutData A buffer of at least 5 bytes to write the output to. * @return The number of bytes used in the output. */ inline uint32_t -WriteVarUInt(uint64_t InValue, void* OutData) +WriteVarUInt(uint32_t InValue, void* OutData) { const uint32_t ByteCount = MeasureVarUInt(InValue); - uint8_t* OutBytes = static_cast<uint8_t*>(OutData) + ByteCount - 1; + WriteMeasuredVarUInt(InValue, MeasureVarUInt(InValue), OutData); + return ByteCount; +} + +/** + * Write a pre-measured variable-length unsigned integer. + * + * @param InValue An unsigned integer to encode. + * @param ByteCount The number of bytes the integer requires to be encoded + * @param OutData A buffer of at least 9 bytes to write the output to. + * @return The number of bytes used in the output. + */ +inline void +WriteMeasuredVarUInt(uint64_t InValue, uint32_t ByteCount, void* OutData) +{ + uint8_t* OutBytes = static_cast<uint8_t*>(OutData) + ByteCount - 1; switch (ByteCount - 1) { case 8: @@ -267,6 +281,20 @@ WriteVarUInt(uint64_t InValue, void* OutData) break; } *OutBytes = uint8_t(0xff << (9 - ByteCount)) | uint8_t(InValue); +} + +/** + * Write a variable-length unsigned integer. + * + * @param InValue An unsigned integer to encode. + * @param OutData A buffer of at least 9 bytes to write the output to. + * @return The number of bytes used in the output. + */ +inline uint32_t +WriteVarUInt(uint64_t InValue, void* OutData) +{ + const uint32_t ByteCount = MeasureVarUInt(InValue); + WriteMeasuredVarUInt(InValue, ByteCount, OutData); return ByteCount; } diff --git a/src/zencore/include/zencore/windows.h b/src/zencore/include/zencore/windows.h index 14026fef1..b2b220f8f 100644 --- a/src/zencore/include/zencore/windows.h +++ b/src/zencore/include/zencore/windows.h @@ -412,5 +412,7 @@ private: DWORD m_dwViewDesiredAccess; }; +bool IsRunningOnWine(); + } // namespace zen::windows #endif diff --git a/src/zencore/include/zencore/workthreadpool.h b/src/zencore/include/zencore/workthreadpool.h index 62356495c..4c38dd651 100644 --- a/src/zencore/include/zencore/workthreadpool.h +++ b/src/zencore/include/zencore/workthreadpool.h @@ -18,11 +18,7 @@ struct IWork : public RefCounted { virtual void Execute() = 0; - inline std::exception_ptr GetException() { return m_Exception; } - private: - std::exception_ptr m_Exception; - friend class WorkerThreadPool; }; @@ -35,13 +31,18 @@ public: WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName); ~WorkerThreadPool(); - void ScheduleWork(Ref<IWork> Work); - void ScheduleWork(std::function<void()>&& Work); + // Decides what to do if there are no free workers in the pool when the work is submitted + enum class EMode + { + EnableBacklog, // The work will be added to a backlog of work to do + DisableBacklog // The work will be executed synchronously in the caller thread + }; + + void ScheduleWork(Ref<IWork> Work, EMode Mode); + void ScheduleWork(std::function<void()>&& Work, EMode Mode); template<typename Func> - auto EnqueueTask(std::packaged_task<Func> Task); - - [[nodiscard]] size_t PendingWorkItemCount() const; + auto EnqueueTask(std::packaged_task<Func> Task, EMode Mode); private: struct Impl; @@ -54,7 +55,7 @@ private: template<typename Func> auto -WorkerThreadPool::EnqueueTask(std::packaged_task<Func> Task) +WorkerThreadPool::EnqueueTask(std::packaged_task<Func> Task, EMode Mode) { struct FutureWork : IWork { @@ -67,7 +68,7 @@ WorkerThreadPool::EnqueueTask(std::packaged_task<Func> Task) Ref<FutureWork> Work{new FutureWork(std::move(Task))}; auto Future = Work->m_Task.get_future(); - ScheduleWork(std::move(Work)); + ScheduleWork(std::move(Work), Mode); return Future; } diff --git a/src/zencore/include/zencore/xxhash.h b/src/zencore/include/zencore/xxhash.h index 04872f4c3..1616e5f93 100644 --- a/src/zencore/include/zencore/xxhash.h +++ b/src/zencore/include/zencore/xxhash.h @@ -4,7 +4,7 @@ #include "zencore.h" -#include <zencore/memory.h> +#include <zencore/memoryview.h> #include <xxh3.h> @@ -61,8 +61,10 @@ struct XXH3_128 struct XXH3_128Stream { + XXH3_128Stream() { Reset(); } + /// Begin streaming hash compute (not needed on freshly constructed instance) - void Reset() { memset(&m_State, 0, sizeof m_State); } + void Reset() { XXH3_128bits_reset(&m_State); } /// Append another chunk XXH3_128Stream& Append(const void* Data, size_t ByteCount) @@ -83,6 +85,33 @@ struct XXH3_128Stream } private: + XXH3_state_s m_State; +}; + +struct XXH3_128Stream_deprecated +{ + /// Begin streaming hash compute (not needed on freshly constructed instance) + void Reset() { memset(&m_State, 0, sizeof m_State); } + + /// Append another chunk + XXH3_128Stream_deprecated& Append(const void* Data, size_t ByteCount) + { + XXH3_128bits_update(&m_State, Data, ByteCount); + return *this; + } + + /// Append another chunk + XXH3_128Stream_deprecated& Append(MemoryView Data) { return Append(Data.GetData(), Data.GetSize()); } + + /// Obtain final hash. If you wish to reuse the instance call reset() + XXH3_128 GetHash() + { + XXH3_128 Hash; + XXH128_canonicalFromHash((XXH128_canonical_t*)Hash.Hash, XXH3_128bits_digest(&m_State)); + return Hash; + } + +private: XXH3_state_s m_State{}; }; diff --git a/src/zencore/include/zencore/zencore.h b/src/zencore/include/zencore/zencore.h index e8c734ba9..b5eb3e3e8 100644 --- a/src/zencore/include/zencore/zencore.h +++ b/src/zencore/include/zencore/zencore.h @@ -24,34 +24,48 @@ #endif namespace zen { + +struct CallstackFrames; + class AssertException : public std::logic_error { public: - inline explicit AssertException(const char* Msg) : std::logic_error(Msg) {} + using _Mybase = std::logic_error; + + virtual ~AssertException() noexcept; + + inline AssertException(const char* Msg, struct CallstackFrames* Callstack) noexcept : _Mybase(Msg), _Callstack(Callstack) {} + + AssertException(const AssertException& Rhs) noexcept; + + AssertException(AssertException&& Rhs) noexcept; + + AssertException& operator=(const AssertException& Rhs) noexcept; + + std::string FullDescription() const noexcept; + + struct CallstackFrames* _Callstack = nullptr; }; +struct CallstackFrames; + struct AssertImpl { + ZEN_FORCENOINLINE ZEN_DEBUG_SECTION AssertImpl(); + virtual ZEN_FORCENOINLINE ZEN_DEBUG_SECTION ~AssertImpl(); + static void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION ExecAssert - [[noreturn]] (const char* Filename, int LineNumber, const char* FunctionName, const char* Msg) - { - CurrentAssertImpl->OnAssert(Filename, LineNumber, FunctionName, Msg); - throw AssertException{Msg}; - } + [[noreturn]] (const char* Filename, int LineNumber, const char* FunctionName, const char* Msg); + + virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION + OnAssert(const char* Filename, int LineNumber, const char* FunctionName, const char* Msg, const CallstackFrames* Callstack); protected: - virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION OnAssert(const char* Filename, - int LineNumber, - const char* FunctionName, - const char* Msg) - { - (void(Filename)); - (void(LineNumber)); - (void(FunctionName)); - (void(Msg)); - } - static AssertImpl DefaultAssertImpl; + static void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION ThrowAssertException + [[noreturn]] (const char* Filename, int LineNumber, const char* FunctionName, const char* Msg, const CallstackFrames* Callstack); static AssertImpl* CurrentAssertImpl; + static AssertImpl DefaultAssertImpl; + AssertImpl* NextAssertImpl = nullptr; }; } // namespace zen @@ -84,7 +98,7 @@ protected: namespace zen { ZENCORE_API bool IsApplicationExitRequested(); -ZENCORE_API void RequestApplicationExit(int ExitCode); +ZENCORE_API bool RequestApplicationExit(int ExitCode); ZENCORE_API int ApplicationExitCode(); ZENCORE_API bool IsDebuggerPresent(); ZENCORE_API void SetIsInteractiveSession(bool Value); diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp index 80d0f4ee4..be9b39e7a 100644 --- a/src/zencore/iobuffer.cpp +++ b/src/zencore/iobuffer.cpp @@ -7,7 +7,9 @@ #include <zencore/fmtutils.h> #include <zencore/iohash.h> #include <zencore/logging.h> -#include <zencore/memory.h> +#include <zencore/memory/llm.h> +#include <zencore/memory/memory.h> +#include <zencore/memoryview.h> #include <zencore/testing.h> #include <zencore/thread.h> #include <zencore/trace.h> @@ -15,12 +17,6 @@ #include <memory.h> #include <system_error> -#if ZEN_USE_MIMALLOC -ZEN_THIRD_PARTY_INCLUDES_START -# include <mimalloc.h> -ZEN_THIRD_PARTY_INCLUDES_END -#endif - #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> #else @@ -30,6 +26,10 @@ ZEN_THIRD_PARTY_INCLUDES_END # include <unistd.h> #endif +#if ZEN_WITH_TESTS +# include <zencore/testutils.h> +#endif + #include <gsl/gsl-lite.hpp> namespace zen { @@ -39,61 +39,39 @@ namespace zen { void IoBufferCore::AllocateBuffer(size_t InSize, size_t Alignment) const { -#if ZEN_PLATFORM_WINDOWS - if (((InSize & 0xffFF) == 0) && (Alignment == 0x10000)) - { - m_Flags.fetch_or(kLowLevelAlloc, std::memory_order_relaxed); - void* Ptr = VirtualAlloc(nullptr, InSize, MEM_COMMIT, PAGE_READWRITE); - if (!Ptr) - { - ThrowLastError(fmt::format("VirtualAlloc failed for {:#x} bytes aligned to {:#x}", InSize, Alignment)); - } - m_DataPtr = Ptr; - return; - } -#endif // ZEN_PLATFORM_WINDOWS + ZEN_MEMSCOPE(ELLMTag::IoBufferMemory); -#if ZEN_USE_MIMALLOC - void* Ptr = mi_aligned_alloc(Alignment, RoundUp(InSize, Alignment)); - m_Flags.fetch_or(kIoBufferAlloc, std::memory_order_relaxed); -#else void* Ptr = Memory::Alloc(InSize, Alignment); -#endif if (!Ptr) { ThrowOutOfMemory(fmt::format("failed allocating {:#x} bytes aligned to {:#x}", InSize, Alignment)); } + m_DataPtr = Ptr; } void IoBufferCore::FreeBuffer() { - if (!m_DataPtr) + if (m_DataPtr) { - return; - } - - const uint32_t LocalFlags = m_Flags.load(std::memory_order_relaxed); -#if ZEN_PLATFORM_WINDOWS - if (LocalFlags & kLowLevelAlloc) - { - VirtualFree(const_cast<void*>(m_DataPtr), 0, MEM_DECOMMIT); - - return; + Memory::Free(const_cast<void*>(m_DataPtr)); + m_DataPtr = nullptr; } -#endif // ZEN_PLATFORM_WINDOWS +} -#if ZEN_USE_MIMALLOC - if (LocalFlags & kIoBufferAlloc) - { - return mi_free(const_cast<void*>(m_DataPtr)); - } -#endif +void* +IoBufferCore::operator new(size_t Size) +{ + ZEN_MEMSCOPE(ELLMTag::IoBufferCore); + return Memory::Malloc(Size); +} - ZEN_UNUSED(LocalFlags); - return Memory::Free(const_cast<void*>(m_DataPtr)); +void +IoBufferCore::operator delete(void* Ptr) +{ + Memory::Free(Ptr); } ////////////////////////////////////////////////////////////////////////// @@ -122,10 +100,9 @@ IoBufferCore::IoBufferCore(size_t InSize, size_t Alignment) IoBufferCore::~IoBufferCore() { - if (IsOwnedByThis() && m_DataPtr) + if (IsOwnedByThis()) { FreeBuffer(); - m_DataPtr = nullptr; } } @@ -212,6 +189,10 @@ IoBufferExtendedCore::~IoBufferExtendedCore() m_DataPtr = nullptr; // prevent any buffer deallocation attempts } + else if (m_DataPtr == reinterpret_cast<uint8_t*>(&m_MmapHandle)) + { + m_DataPtr = nullptr; + } const uint32_t LocalFlags = m_Flags.load(std::memory_order_relaxed); #if ZEN_PLATFORM_WINDOWS @@ -231,8 +212,18 @@ IoBufferExtendedCore::~IoBufferExtendedCore() SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); #else - std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle); - unlink(FilePath.c_str()); + std::error_code Ec; + std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle, Ec); + if (Ec) + { + ZEN_WARN("Error reported on file handle {}, get path for IoBufferExtendedCore destructor, reason '{}'", + m_FileHandle, + Ec.message()); + } + else + { + unlink(FilePath.c_str()); + } #endif } #if ZEN_PLATFORM_WINDOWS @@ -298,7 +289,7 @@ IoBufferExtendedCore::Materialize() const return; } - const size_t DisableMMapSizeLimit = 0x1000ull; + const size_t DisableMMapSizeLimit = 0x2000ull; if (m_DataBytes < DisableMMapSizeLimit) { @@ -306,53 +297,26 @@ IoBufferExtendedCore::Materialize() const AllocateBuffer(m_DataBytes, sizeof(void*)); NewFlags |= kIsOwnedByThis; - int32_t Error = 0; - size_t BytesRead = 0; - -#if ZEN_PLATFORM_WINDOWS - OVERLAPPED Ovl{}; - - Ovl.Offset = DWORD(m_FileOffset & 0xffff'ffffu); - Ovl.OffsetHigh = DWORD(m_FileOffset >> 32); - DWORD dwNumberOfBytesRead = 0; - BOOL Success = ::ReadFile(m_FileHandle, (void*)m_DataPtr, DWORD(m_DataBytes), &dwNumberOfBytesRead, &Ovl) == TRUE; - if (Success) - { - BytesRead = size_t(dwNumberOfBytesRead); - } - else - { - Error = zen::GetLastError(); - } -#else - static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); - int Fd = int(uintptr_t(m_FileHandle)); - int ReadResult = pread(Fd, (void*)m_DataPtr, m_DataBytes, m_FileOffset); - if (ReadResult != -1) - { - BytesRead = size_t(ReadResult); - } - else - { - Error = zen::GetLastError(); - } -#endif // ZEN_PLATFORM_WINDOWS - if (Error || (BytesRead != m_DataBytes)) + std::error_code Ec; + ReadFile(m_FileHandle, (void*)m_DataPtr, m_DataBytes, m_FileOffset, DisableMMapSizeLimit, Ec); + if (Ec) { std::error_code DummyEc; - ZEN_WARN("ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x}), {}", + ZEN_WARN("IoBufferExtendedCore::Materialize: ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x}), {} ({})", m_FileOffset, m_DataBytes, zen::PathFromHandle(m_FileHandle, DummyEc), zen::FileSizeFromHandle(m_FileHandle), - GetSystemErrorAsString(Error)); - throw std::system_error(std::error_code(Error, std::system_category()), - fmt::format("ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x})", - m_FileOffset, - m_DataBytes, - PathFromHandle(m_FileHandle, DummyEc), - FileSizeFromHandle(m_FileHandle))); + Ec.message(), + Ec.value()); + throw std::system_error( + Ec, + fmt::format("IoBufferExtendedCore::Materialize: ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x})", + m_FileOffset, + m_DataBytes, + PathFromHandle(m_FileHandle, DummyEc), + FileSizeFromHandle(m_FileHandle))); } m_Flags.fetch_or(NewFlags, std::memory_order_release); @@ -465,7 +429,25 @@ IoBufferExtendedCore::SetDeleteOnClose(bool DeleteOnClose) ////////////////////////////////////////////////////////////////////////// -RefPtr<IoBufferCore> IoBuffer::NullBufferCore(new IoBufferCore); +static IoBufferCore* +GetNullBufferCore() +{ + // This is safe from a threading standpoint since the first call is non-threaded (during static init) and for the following + // calls Core is never nullptr + // We do this workaround since we don't want to call new (IoBufferCore) at static initializers + // Calling new during static initialize causes problem with memtracing since the flags are not set up correctly yet + + static IoBufferCore NullBufferCore; + static IoBufferCore* Core = nullptr; + if (Core == nullptr) + { + Core = &NullBufferCore; + Core->AddRef(); // Make sure we never deallocate it as it is a static instance + } + return Core; +} + +RefPtr<IoBufferCore> IoBuffer::NullBufferCore(GetNullBufferCore()); IoBuffer::IoBuffer(size_t InSize) : m_Core(new IoBufferCore(InSize)) { @@ -547,37 +529,60 @@ IoBufferBuilder::ReadFromFileMaybe(const IoBuffer& InBuffer) { IoBuffer OutBuffer(FileRef.FileChunkSize); + int32_t Error = 0; + size_t BytesRead = 0; + + const uint64_t NumberOfBytesToRead = FileRef.FileChunkSize; + const uint64_t FileOffset = FileRef.FileChunkOffset; + #if ZEN_PLATFORM_WINDOWS OVERLAPPED Ovl{}; - const uint64_t NumberOfBytesToRead = FileRef.FileChunkSize; - const uint64_t& FileOffset = FileRef.FileChunkOffset; - Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); Ovl.OffsetHigh = DWORD(FileOffset >> 32); DWORD dwNumberOfBytesRead = 0; BOOL Success = ::ReadFile(FileRef.FileHandle, OutBuffer.MutableData(), DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl); + if (Success) + { + BytesRead = size_t(dwNumberOfBytesRead); + } + else + { + Error = zen::GetLastError(); + } #else - int Fd = int(intptr_t(FileRef.FileHandle)); - int Result = pread(Fd, OutBuffer.MutableData(), size_t(FileRef.FileChunkSize), off_t(FileRef.FileChunkOffset)); - bool Success = (Result >= 0); - - uint32_t dwNumberOfBytesRead = uint32_t(Result); + int Fd = int(intptr_t(FileRef.FileHandle)); + ssize_t ReadResult = pread(Fd, OutBuffer.MutableData(), size_t(NumberOfBytesToRead), off_t(FileOffset)); + if (ReadResult != -1) + { + BytesRead = size_t(ReadResult); + } + else + { + Error = zen::GetLastError(); + } #endif - if (!Success) + if (Error || (BytesRead != NumberOfBytesToRead)) { - ThrowLastError(fmt::format("file read failed in IoBufferBuilder::ReadFromFileMaybe (handle: {}, offset: {}, length: {})", - intptr_t(FileRef.FileHandle), - FileRef.FileChunkOffset, - FileRef.FileChunkSize)); + std::error_code DummyEc; + ZEN_WARN("IoBufferBuilder::ReadFromFileMaybe: ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x}), {}", + FileOffset, + NumberOfBytesToRead, + zen::PathFromHandle(FileRef.FileHandle, DummyEc), + zen::FileSizeFromHandle(FileRef.FileHandle), + GetSystemErrorAsString(Error)); + throw std::system_error( + std::error_code(Error, std::system_category()), + fmt::format("IoBufferBuilder::ReadFromFileMaybe: ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x})", + FileOffset, + NumberOfBytesToRead, + PathFromHandle(FileRef.FileHandle, DummyEc), + FileSizeFromHandle(FileRef.FileHandle))); } - ZEN_ASSERT(dwNumberOfBytesRead == FileRef.FileChunkSize); - OutBuffer.SetContentType(InBuffer.GetContentType()); - return OutBuffer; } else @@ -615,7 +620,7 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of DataFile.GetSize((ULONGLONG&)FileSize); #else int Flags = O_RDONLY | O_CLOEXEC; - int Fd = open(FileName.c_str(), Flags); + int Fd = open(FileName.c_str(), Flags); if (Fd < 0) { return {}; @@ -684,7 +689,7 @@ IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName) Handle = DataFile.Detach(); #else - int Fd = open(FileName.native().c_str(), O_RDONLY); + int Fd = open(FileName.native().c_str(), O_RDONLY); if (Fd < 0) { return {}; @@ -701,13 +706,6 @@ IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName) return IoBuffer(IoBuffer::File, Handle, 0, FileSize, /*IsWholeFile*/ true); } -IoHash -HashBuffer(IoBuffer& Buffer) -{ - // TODO: handle disk buffers with special path - return IoHash::HashBuffer(Buffer.Data(), Buffer.Size()); -} - ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS @@ -726,14 +724,16 @@ TEST_CASE("IoBuffer") TEST_CASE("IoBuffer.mmap") { + zen::ScopedTemporaryDirectory TempDir; + zen::IoBuffer Buffer1{65536}; uint8_t* Mutate = Buffer1.MutableData<uint8_t>(); memcpy(Mutate, "abc123", 6); - zen::WriteFile("test_file.data", Buffer1); + zen::WriteFile(TempDir.Path() / "test_file.data", Buffer1); SUBCASE("in-range") { - zen::IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile("test_file.data", 0, 65536); + zen::IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile(TempDir.Path() / "test_file.data", 0, 65536); const void* Data = FileBuffer.GetData(); CHECK(Data != nullptr); CHECK_EQ(memcmp(Data, "abc123", 6), 0); @@ -744,7 +744,7 @@ TEST_CASE("IoBuffer.mmap") # if ZEN_PLATFORM_WINDOWS SUBCASE("out-of-range") { - zen::IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile("test_file.data", 131072, 65536); + zen::IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile(TempDir.Path() / "test_file.data", 131072, 65536); const void* Data = nullptr; CHECK_THROWS(Data = FileBuffer.GetData()); CHECK(Data == nullptr); diff --git a/src/zencore/iohash.cpp b/src/zencore/iohash.cpp index 77076c133..3b2af0db4 100644 --- a/src/zencore/iohash.cpp +++ b/src/zencore/iohash.cpp @@ -4,6 +4,7 @@ #include <zencore/blake3.h> #include <zencore/compositebuffer.h> +#include <zencore/filesystem.h> #include <zencore/string.h> #include <zencore/testing.h> @@ -11,7 +12,11 @@ namespace zen { -const IoHash IoHash::Zero{}; // Initialized to all zeros +static const uint8_t MaxData[20] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +const IoHash IoHash::Max = IoHash::MakeFrom(MaxData); // Initialized to all 0xff +const IoHash IoHash::Zero{}; // Initialized to all zeros IoHash IoHash::HashBuffer(const void* data, size_t byteCount) @@ -25,13 +30,72 @@ IoHash::HashBuffer(const void* data, size_t byteCount) } IoHash -IoHash::HashBuffer(const CompositeBuffer& Buffer) +IoHash::HashBuffer(const CompositeBuffer& Buffer, std::atomic<uint64_t>* ProcessedBytes) { IoHashStream Hasher; for (const SharedBuffer& Segment : Buffer.GetSegments()) { - Hasher.Append(Segment.GetData(), Segment.GetSize()); + size_t SegmentSize = Segment.GetSize(); + static const uint64_t BufferingSize = 256u * 1024u; + + IoBufferFileReference FileRef; + if (SegmentSize >= (BufferingSize + BufferingSize / 2) && Segment.GetFileReference(FileRef)) + { + ScanFile(FileRef.FileHandle, + FileRef.FileChunkOffset, + FileRef.FileChunkSize, + BufferingSize, + [&Hasher, ProcessedBytes](const void* Data, size_t Size) { + Hasher.Append(Data, Size); + if (ProcessedBytes != nullptr) + { + ProcessedBytes->fetch_add(Size); + } + }); + } + else + { + Hasher.Append(Segment.GetData(), SegmentSize); + if (ProcessedBytes != nullptr) + { + ProcessedBytes->fetch_add(SegmentSize); + } + } + } + + return Hasher.GetHash(); +} + +IoHash +IoHash::HashBuffer(const IoBuffer& Buffer, std::atomic<uint64_t>* ProcessedBytes) +{ + IoHashStream Hasher; + + size_t BufferSize = Buffer.GetSize(); + static const uint64_t BufferingSize = 256u * 1024u; + IoBufferFileReference FileRef; + if (BufferSize >= (BufferingSize + BufferingSize / 2) && Buffer.GetFileReference(FileRef)) + { + ScanFile(FileRef.FileHandle, + FileRef.FileChunkOffset, + FileRef.FileChunkSize, + BufferingSize, + [&Hasher, ProcessedBytes](const void* Data, size_t Size) { + Hasher.Append(Data, Size); + if (ProcessedBytes != nullptr) + { + ProcessedBytes->fetch_add(Size); + } + }); + } + else + { + Hasher.Append(Buffer.GetData(), BufferSize); + if (ProcessedBytes != nullptr) + { + ProcessedBytes->fetch_add(BufferSize); + } } return Hasher.GetHash(); @@ -55,6 +119,24 @@ IoHash::FromHexString(std::string_view string) return io; } +bool +IoHash::TryParse(std::string_view Str, IoHash& Hash) +{ + using namespace std::literals; + + if (Str.size() == IoHash::StringLength) + { + return ParseHexBytes(Str.data(), Str.size(), Hash.Hash); + } + + if (Str.starts_with("0x"sv)) + { + return TryParse(Str.substr(2), Hash); + } + + return false; +} + const char* IoHash::ToHexString(char* outString /* 40 characters + NUL terminator */) const { diff --git a/src/zencore/jobqueue.cpp b/src/zencore/jobqueue.cpp index 4bcc5c885..bd391909d 100644 --- a/src/zencore/jobqueue.cpp +++ b/src/zencore/jobqueue.cpp @@ -11,6 +11,10 @@ # include <zencore/testing.h> #endif // ZEN_WITH_TESTS +ZEN_THIRD_PARTY_INCLUDES_START +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <deque> #include <thread> #include <unordered_map> @@ -46,12 +50,16 @@ public: JobClock::Tick StartTick; JobClock::Tick EndTick; int WorkerThreadId; + int ReturnCode; virtual bool IsCancelled() const override { return CancelFlag.load(); } virtual void ReportMessage(std::string_view Message) override { Queue->ReportMessage(Id, Message); } - virtual void ReportProgress(std::string_view CurrentOp, uint32_t CurrentOpPercentComplete) override + virtual void ReportProgress(std::string_view CurrentOp, + std::string_view Details, + ptrdiff_t TotalCount, + ptrdiff_t RemainingCount) override { - Queue->ReportProgress(Id, CurrentOp, CurrentOpPercentComplete); + Queue->ReportProgress(Id, CurrentOp, Details, TotalCount, RemainingCount); } }; @@ -69,7 +77,7 @@ public: Stop(); } } - catch (std::exception& Ex) + catch (const std::exception& Ex) { ZEN_WARN("Failed shutting down jobqueue. Reason: '{}'", Ex.what()); } @@ -94,19 +102,22 @@ public: NewJob->StartTick = JobClock::Never(); NewJob->EndTick = JobClock::Never(); NewJob->WorkerThreadId = 0; + NewJob->ReturnCode = -1; ZEN_DEBUG("Scheduling background job {}:'{}'", NewJob->Id.Id, NewJob->Name); QueueLock.WithExclusiveLock([&]() { QueuedJobs.emplace_back(std::move(NewJob)); }); WorkerCounter.AddCount(1); try { - WorkerPool.ScheduleWork([&]() { - auto _ = MakeGuard([&]() { WorkerCounter.CountDown(); }); - Worker(); - }); + WorkerPool.ScheduleWork( + [&]() { + auto _ = MakeGuard([&]() { WorkerCounter.CountDown(); }); + Worker(); + }, + WorkerThreadPool::EMode::EnableBacklog); return {.Id = NewJobId}; } - catch (std::exception& Ex) + catch (const std::exception& Ex) { WorkerCounter.CountDown(); QueueLock.WithExclusiveLock([&]() { @@ -254,15 +265,21 @@ public: virtual std::optional<JobDetails> Get(JobId Id) override { auto Convert = [](Status Status, Job& Job) -> JobDetails { - return JobDetails{.Name = Job.Name, - .Status = Status, - .State = {.CurrentOp = Job.State.CurrentOp, - .CurrentOpPercentComplete = Job.State.CurrentOpPercentComplete, - .Messages = std::move(Job.State.Messages)}, - .CreateTime = JobClock::TimePointFromTick(Job.CreateTick), - .StartTime = JobClock::TimePointFromTick(Job.StartTick), - .EndTime = JobClock::TimePointFromTick(Job.EndTick), - .WorkerThreadId = Job.WorkerThreadId}; + return JobDetails{ + .Name = Job.Name, + .Status = Status, + .State = {.CurrentOp = Job.State.CurrentOp, + .CurrentOpDetails = Job.State.CurrentOpDetails, + .TotalCount = Job.State.TotalCount, + .RemainingCount = Job.State.RemainingCount, + // .CurrentOpPercentComplete = Job.State.CurrentOpPercentComplete, + .Messages = std::move(Job.State.Messages), + .AbortReason = Job.State.AbortReason}, + .CreateTime = JobClock::TimePointFromTick(Job.CreateTick), + .StartTime = JobClock::TimePointFromTick(Job.StartTick), + .EndTime = JobClock::TimePointFromTick(Job.EndTick), + .WorkerThreadId = Job.WorkerThreadId, + .ReturnCode = Job.ReturnCode}; }; std::optional<JobDetails> Result; @@ -296,20 +313,22 @@ public: void ReportMessage(JobId Id, std::string_view Message) { - QueueLock.WithSharedLock([&]() { + QueueLock.WithExclusiveLock([&]() { auto It = RunningJobs.find(Id.Id); ZEN_ASSERT(It != RunningJobs.end()); It->second->State.Messages.push_back(std::string(Message)); }); } - void ReportProgress(JobId Id, std::string_view CurrentOp, uint32_t CurrentOpPercentComplete) + void ReportProgress(JobId Id, std::string_view CurrentOp, std::string_view Details, ptrdiff_t TotalCount, ptrdiff_t RemainingCount) { - QueueLock.WithSharedLock([&]() { + QueueLock.WithExclusiveLock([&]() { auto It = RunningJobs.find(Id.Id); ZEN_ASSERT(It != RunningJobs.end()); - It->second->State.CurrentOp = CurrentOp; - It->second->State.CurrentOpPercentComplete = CurrentOpPercentComplete; + It->second->State.CurrentOp = CurrentOp; + It->second->State.CurrentOpDetails = Details; + It->second->State.TotalCount = TotalCount; + It->second->State.RemainingCount = RemainingCount; }); } @@ -351,6 +370,7 @@ public: ZEN_DEBUG("Executing background job {}:'{}'", CurrentJob->Id.Id, CurrentJob->Name); CurrentJob->Callback(*CurrentJob); ZEN_DEBUG("Completed background job {}:'{}'", CurrentJob->Id.Id, CurrentJob->Name); + CurrentJob->ReturnCode = 0; QueueLock.WithExclusiveLock([&]() { CurrentJob->EndTick = JobClock::Now(); CurrentJob->WorkerThreadId = 0; @@ -358,13 +378,40 @@ public: CompletedJobs.insert_or_assign(CurrentJob->Id.Id, std::move(CurrentJob)); }); } - catch (std::exception& Ex) + catch (const AssertException& Ex) + { + ZEN_DEBUG("Background job {}:'{}' asserted. Reason: {}", CurrentJob->Id.Id, CurrentJob->Name, Ex.FullDescription()); + QueueLock.WithExclusiveLock([&]() { + CurrentJob->State.AbortReason = Ex.FullDescription(); + CurrentJob->EndTick = JobClock::Now(); + CurrentJob->WorkerThreadId = 0; + RunningJobs.erase(CurrentJob->Id.Id); + AbortedJobs.insert_or_assign(CurrentJob->Id.Id, std::move(CurrentJob)); + }); + } + catch (const JobError& Ex) + { + ZEN_DEBUG("Background job {}:'{}' failed. Reason: '{}'. Return code {}", + CurrentJob->Id.Id, + CurrentJob->Name, + Ex.what(), + Ex.m_ReturnCode); + QueueLock.WithExclusiveLock([&]() { + CurrentJob->State.AbortReason = Ex.what(); + CurrentJob->EndTick = JobClock::Now(); + CurrentJob->WorkerThreadId = 0; + CurrentJob->ReturnCode = Ex.m_ReturnCode; + RunningJobs.erase(CurrentJob->Id.Id); + AbortedJobs.insert_or_assign(CurrentJob->Id.Id, std::move(CurrentJob)); + }); + } + catch (const std::exception& Ex) { ZEN_DEBUG("Background job {}:'{}' aborted. Reason: '{}'", CurrentJob->Id.Id, CurrentJob->Name, Ex.what()); QueueLock.WithExclusiveLock([&]() { - CurrentJob->State.Messages.push_back(Ex.what()); - CurrentJob->EndTick = JobClock::Now(); - CurrentJob->WorkerThreadId = 0; + CurrentJob->State.AbortReason = Ex.what(); + CurrentJob->EndTick = JobClock::Now(); + CurrentJob->WorkerThreadId = 0; RunningJobs.erase(CurrentJob->Id.Id); AbortedJobs.insert_or_assign(CurrentJob->Id.Id, std::move(CurrentJob)); }); @@ -418,37 +465,39 @@ TEST_CASE("JobQueue") std::unique_ptr<JobQueue> Queue(MakeJobQueue(2, "queue")); WorkerThreadPool Pool(4); Latch JobsLatch(1); - for (uint32_t I = 0; I < 100; I++) + for (uint32_t I = 0; I < 32; I++) { JobsLatch.AddCount(1); - Pool.ScheduleWork([&Queue, &JobsLatch, I]() { - auto _ = MakeGuard([&JobsLatch]() { JobsLatch.CountDown(); }); - JobsLatch.AddCount(1); - auto Id = Queue->QueueJob(fmt::format("busy {}", I), [&JobsLatch, I](JobContext& Context) { - auto $ = MakeGuard([&JobsLatch]() { JobsLatch.CountDown(); }); - if (Context.IsCancelled()) - { - return; - } - Context.ReportProgress("going to sleep", 0); - Sleep(10); - if (Context.IsCancelled()) - { - return; - } - Context.ReportProgress("going to sleep again", 50); - if ((I & 0xFF) == 0x10) - { - zen::ThrowSystemError(8, fmt::format("Job {} forced to fail", I)); - } - Sleep(10); - if (Context.IsCancelled()) - { - return; - } - Context.ReportProgress("done", 100); - }); - }); + Pool.ScheduleWork( + [&Queue, &JobsLatch, I]() { + auto _ = MakeGuard([&JobsLatch]() { JobsLatch.CountDown(); }); + JobsLatch.AddCount(1); + auto Id = Queue->QueueJob(fmt::format("busy {}", I), [&JobsLatch, I](JobContext& Context) { + auto $ = MakeGuard([&JobsLatch]() { JobsLatch.CountDown(); }); + if (Context.IsCancelled()) + { + return; + } + Context.ReportProgress("going to sleep", "", 100, 100); + Sleep(5); + if (Context.IsCancelled()) + { + return; + } + Context.ReportProgress("going to sleep again", "", 100, 50); + if ((I & 0xFF) == 0x10) + { + zen::ThrowSystemError(8, fmt::format("Job {} forced to fail", I)); + } + Sleep(5); + if (Context.IsCancelled()) + { + return; + } + Context.ReportProgress("done", "", 100, 0); + }); + }, + WorkerThreadPool::EMode::EnableBacklog); } auto Join = [](std::span<std::string> Strings, std::string_view Delimiter) -> std::string { @@ -495,15 +544,20 @@ TEST_CASE("JobQueue") RemainingJobs.push_back(Id); break; case JobQueue::Status::Running: - ZEN_DEBUG("{} running. '{}' {}% '{}'", - Id.Id, - CurrentState->State.CurrentOp, - CurrentState->State.CurrentOpPercentComplete, - Join(CurrentState->State.Messages, " "sv)); + ZEN_DEBUG( + "{} running. '{}{}' {}% '{}'", + Id.Id, + CurrentState->State.CurrentOp, + CurrentState->State.CurrentOpDetails.empty() ? ""sv : fmt::format(", {}", CurrentState->State.CurrentOpDetails), + CurrentState->State.TotalCount > 0 + ? gsl::narrow<uint32_t>((100 * (CurrentState->State.TotalCount - CurrentState->State.RemainingCount)) / + CurrentState->State.TotalCount) + : 0, + Join(CurrentState->State.Messages, " "sv)); RemainingJobs.push_back(Id); break; case JobQueue::Status::Aborted: - ZEN_DEBUG("{} aborted. Reason: '{}'", Id.Id, Join(CurrentState->State.Messages, " "sv)); + ZEN_DEBUG("{} aborted. Reason: '{}'", Id.Id, CurrentState->State.AbortReason); break; case JobQueue::Status::Completed: ZEN_DEBUG("{} completed. '{}'", Id.Id, Join(CurrentState->State.Messages, " "sv)); @@ -521,7 +575,7 @@ TEST_CASE("JobQueue") RemainingJobs.size(), PendingCount, RemainingJobs.size() - PendingCount); - Sleep(100); + Sleep(5); } JobsLatch.Wait(); } diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp index 90f4e2428..a6697c443 100644 --- a/src/zencore/logging.cpp +++ b/src/zencore/logging.cpp @@ -6,6 +6,8 @@ #include <zencore/testing.h> #include <zencore/thread.h> +#include <zencore/memory/llm.h> + ZEN_THIRD_PARTY_INCLUDES_START #include <spdlog/details/registry.h> #include <spdlog/sinks/null_sink.h> @@ -66,6 +68,7 @@ static_assert(offsetof(spdlog::source_loc, funcname) == offsetof(SourceLocation, void EmitLogMessage(LoggerRef& Logger, int LogLevel, const std::string_view Message) { + ZEN_MEMSCOPE(ELLMTag::Logging); const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; Logger.SpdLogger->log(InLevel, Message); if (IsErrorLevel(LogLevel)) @@ -80,6 +83,7 @@ EmitLogMessage(LoggerRef& Logger, int LogLevel, const std::string_view Message) void EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Format, fmt::format_args Args) { + ZEN_MEMSCOPE(ELLMTag::Logging); zen::logging::LoggingContext LogCtx; fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args); zen::logging::EmitLogMessage(Logger, LogLevel, LogCtx.Message()); @@ -88,6 +92,7 @@ EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Format, fmt::fo void EmitLogMessage(LoggerRef& Logger, const SourceLocation& InLocation, int LogLevel, const std::string_view Message) { + ZEN_MEMSCOPE(ELLMTag::Logging); const spdlog::source_loc& Location = *reinterpret_cast<const spdlog::source_loc*>(&InLocation); const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; Logger.SpdLogger->log(Location, InLevel, Message); @@ -103,6 +108,7 @@ EmitLogMessage(LoggerRef& Logger, const SourceLocation& InLocation, int LogLevel void EmitLogMessage(LoggerRef& Logger, const SourceLocation& InLocation, int LogLevel, std::string_view Format, fmt::format_args Args) { + ZEN_MEMSCOPE(ELLMTag::Logging); zen::logging::LoggingContext LogCtx; fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args); zen::logging::EmitLogMessage(Logger, InLocation, LogLevel, LogCtx.Message()); @@ -111,14 +117,39 @@ EmitLogMessage(LoggerRef& Logger, const SourceLocation& InLocation, int LogLevel void EmitConsoleLogMessage(int LogLevel, const std::string_view Message) { + ZEN_MEMSCOPE(ELLMTag::Logging); const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; ConsoleLog().SpdLogger->log(InLevel, Message); } +#define ZEN_COLOR_YELLOW "\033[0;33m" +#define ZEN_COLOR_RED "\033[0;31m" +#define ZEN_BRIGHT_COLOR_RED "\033[1;31m" +#define ZEN_COLOR_RESET "\033[0m" + void EmitConsoleLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) { + ZEN_MEMSCOPE(ELLMTag::Logging); zen::logging::LoggingContext LogCtx; + + // We are not using a format option for console which include log level since it would interfere with normal console output + + const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; + switch (InLevel) + { + case spdlog::level::level_enum::warn: + fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_COLOR_YELLOW "Warning: " ZEN_COLOR_RESET); + break; + case spdlog::level::level_enum::err: + fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_BRIGHT_COLOR_RED "Error: " ZEN_COLOR_RESET); + break; + case spdlog::level::level_enum::critical: + fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_COLOR_RED "Critical: " ZEN_COLOR_RESET); + break; + default: + break; + } fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args); zen::logging::EmitConsoleLogMessage(LogLevel, LogCtx.Message()); } @@ -192,6 +223,8 @@ std::string LogLevels[level::LogLevelCount]; void ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers) { + ZEN_MEMSCOPE(ELLMTag::Logging); + RwLock::ExclusiveLockScope _(LogLevelsLock); LogLevels[Level] = Loggers; } @@ -199,6 +232,8 @@ ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers) void RefreshLogLevels(level::LogLevel* DefaultLevel) { + ZEN_MEMSCOPE(ELLMTag::Logging); + spdlog::details::registry::log_levels Levels; { @@ -275,6 +310,8 @@ Default() void SetDefault(std::string_view NewDefaultLoggerId) { + ZEN_MEMSCOPE(ELLMTag::Logging); + auto NewDefaultLogger = spdlog::get(std::string(NewDefaultLoggerId)); ZEN_ASSERT(NewDefaultLogger); @@ -293,6 +330,8 @@ ErrorLog() void SetErrorLog(std::string_view NewErrorLoggerId) { + ZEN_MEMSCOPE(ELLMTag::Logging); + if (NewErrorLoggerId.empty()) { TheErrorLogger = {}; @@ -307,16 +346,27 @@ SetErrorLog(std::string_view NewErrorLoggerId) } } +RwLock g_LoggerMutex; + LoggerRef Get(std::string_view Name) { + ZEN_MEMSCOPE(ELLMTag::Logging); + std::shared_ptr<spdlog::logger> Logger = spdlog::get(std::string(Name)); if (!Logger) { - Logger = Default().SpdLogger->clone(std::string(Name)); - spdlog::apply_logger_env_levels(Logger); - spdlog::register_logger(Logger); + g_LoggerMutex.WithExclusiveLock([&] { + Logger = spdlog::get(std::string(Name)); + + if (!Logger) + { + Logger = Default().SpdLogger->clone(std::string(Name)); + spdlog::apply_logger_env_levels(Logger); + spdlog::register_logger(Logger); + } + }); } return *Logger; @@ -339,6 +389,8 @@ SuppressConsoleLog() LoggerRef ConsoleLog() { + ZEN_MEMSCOPE(ELLMTag::Logging); + std::call_once(ConsoleInitFlag, [&] { if (!ConLogger) { @@ -355,6 +407,8 @@ ConsoleLog() void InitializeLogging() { + ZEN_MEMSCOPE(ELLMTag::Logging); + TheDefaultLogger = *spdlog::default_logger_raw(); } @@ -392,6 +446,12 @@ EnableVTMode() return true; } +void +FlushLogging() +{ + spdlog::details::registry::instance().flush_all(); +} + } // namespace zen::logging namespace zen { diff --git a/src/zencore/memory.cpp b/src/zencore/memory.cpp deleted file mode 100644 index 808c9fcb6..000000000 --- a/src/zencore/memory.cpp +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zencore/except.h> -#include <zencore/fmtutils.h> -#include <zencore/intmath.h> -#include <zencore/memory.h> -#include <zencore/testing.h> -#include <zencore/zencore.h> - -#include <cstdlib> - -#if ZEN_USE_MIMALLOC -ZEN_THIRD_PARTY_INCLUDES_START -# include <mimalloc.h> -ZEN_THIRD_PARTY_INCLUDES_END -#endif - -namespace zen { - -////////////////////////////////////////////////////////////////////////// - -static void* -AlignedAllocImpl(size_t Size, size_t Alignment) -{ - // aligned_alloc() states that size must be a multiple of alignment. Some - // platforms return null if this requirement isn't met. - Size = (Size + Alignment - 1) & ~(Alignment - 1); - -#if ZEN_USE_MIMALLOC - return mi_aligned_alloc(Alignment, Size); -#elif ZEN_PLATFORM_WINDOWS - return _aligned_malloc(Size, Alignment); -#else - return std::aligned_alloc(Alignment, Size); -#endif -} - -void -AlignedFreeImpl(void* ptr) -{ - if (ptr == nullptr) - return; - -#if ZEN_USE_MIMALLOC - return mi_free(ptr); -#elif ZEN_PLATFORM_WINDOWS - _aligned_free(ptr); -#else - std::free(ptr); -#endif -} - -////////////////////////////////////////////////////////////////////////// - -MemoryArena::MemoryArena() -{ -} - -MemoryArena::~MemoryArena() -{ -} - -void* -MemoryArena::Alloc(size_t Size, size_t Alignment) -{ - return AlignedAllocImpl(Size, Alignment); -} - -void -MemoryArena::Free(void* ptr) -{ - AlignedFreeImpl(ptr); -} - -////////////////////////////////////////////////////////////////////////// - -void* -Memory::Alloc(size_t Size, size_t Alignment) -{ - return AlignedAllocImpl(Size, Alignment); -} - -void -Memory::Free(void* ptr) -{ - AlignedFreeImpl(ptr); -} - -////////////////////////////////////////////////////////////////////////// - -ChunkingLinearAllocator::ChunkingLinearAllocator(uint64_t ChunkSize, uint64_t ChunkAlignment) -: m_ChunkSize(ChunkSize) -, m_ChunkAlignment(ChunkAlignment) -{ -} - -ChunkingLinearAllocator::~ChunkingLinearAllocator() -{ - Reset(); -} - -void -ChunkingLinearAllocator::Reset() -{ - for (void* ChunkEntry : m_ChunkList) - { - Memory::Free(ChunkEntry); - } - m_ChunkList.clear(); - - m_ChunkCursor = nullptr; - m_ChunkBytesRemain = 0; -} - -void* -ChunkingLinearAllocator::Alloc(size_t Size, size_t Alignment) -{ - ZEN_ASSERT_SLOW(zen::IsPow2(Alignment)); - - // This could be improved in a bunch of ways - // - // * We pessimistically allocate memory even though there may be enough memory available for a single allocation due to the way we take - // alignment into account below - // * The block allocation size could be chosen to minimize slack for the case when multiple oversize allocations are made rather than - // minimizing the number of chunks - // * ... - - const uint64_t AllocationSize = zen::RoundUp(Size, Alignment); - - if (m_ChunkBytesRemain < (AllocationSize + Alignment - 1)) - { - const uint64_t ChunkSize = zen::RoundUp(zen::Max(m_ChunkSize, Size), m_ChunkSize); - void* ChunkPtr = Memory::Alloc(ChunkSize, m_ChunkAlignment); - if (!ChunkPtr) - { - ThrowOutOfMemory(fmt::format("failed allocating {:#x} bytes aligned to {:#x}", ChunkSize, m_ChunkAlignment)); - } - m_ChunkCursor = reinterpret_cast<uint8_t*>(ChunkPtr); - m_ChunkBytesRemain = ChunkSize; - m_ChunkList.push_back(ChunkPtr); - } - - const uint64_t AlignFixup = (Alignment - reinterpret_cast<uintptr_t>(m_ChunkCursor)) & (Alignment - 1); - void* ReturnPtr = m_ChunkCursor + AlignFixup; - const uint64_t Delta = AlignFixup + AllocationSize; - - ZEN_ASSERT_SLOW(m_ChunkBytesRemain >= Delta); - - m_ChunkCursor += Delta; - m_ChunkBytesRemain -= Delta; - - ZEN_ASSERT_SLOW(IsPointerAligned(ReturnPtr, Alignment)); - - return ReturnPtr; -} - -////////////////////////////////////////////////////////////////////////// -// -// Unit tests -// - -#if ZEN_WITH_TESTS - -TEST_CASE("ChunkingLinearAllocator") -{ - ChunkingLinearAllocator Allocator(4096); - - void* p1 = Allocator.Alloc(1, 1); - void* p2 = Allocator.Alloc(1, 1); - - CHECK(p1 != p2); - - void* p3 = Allocator.Alloc(1, 4); - CHECK(IsPointerAligned(p3, 4)); - - void* p3_2 = Allocator.Alloc(1, 4); - CHECK(IsPointerAligned(p3_2, 4)); - - void* p4 = Allocator.Alloc(1, 8); - CHECK(IsPointerAligned(p4, 8)); - - for (int i = 0; i < 100; ++i) - { - void* p0 = Allocator.Alloc(64); - ZEN_UNUSED(p0); - } -} - -TEST_CASE("MemoryView") -{ - { - uint8_t Array1[16] = {}; - MemoryView View1 = MakeMemoryView(Array1); - CHECK(View1.GetSize() == 16); - } - - { - uint32_t Array2[16] = {}; - MemoryView View2 = MakeMemoryView(Array2); - CHECK(View2.GetSize() == 64); - } - - CHECK(MakeMemoryView<float>({1.0f, 1.2f}).GetSize() == 8); -} - -void -memory_forcelink() -{ -} - -#endif - -} // namespace zen diff --git a/src/zencore/memory/fmalloc.cpp b/src/zencore/memory/fmalloc.cpp new file mode 100644 index 000000000..3e96003f5 --- /dev/null +++ b/src/zencore/memory/fmalloc.cpp @@ -0,0 +1,156 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <string.h> +#include <zencore/memory/fmalloc.h> +#include <zencore/memory/memory.h> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +class FInitialMalloc : public FMalloc +{ + virtual void* Malloc(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT) override + { + Memory::Initialize(); + return GMalloc->Malloc(Count, Alignment); + } + virtual void* TryMalloc(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT) override + { + Memory::Initialize(); + return GMalloc->TryMalloc(Count, Alignment); + } + virtual void* Realloc(void* Original, size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT) override + { + Memory::Initialize(); + return GMalloc->Realloc(Original, Count, Alignment); + } + virtual void* TryRealloc(void* Original, size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT) override + { + Memory::Initialize(); + return GMalloc->TryRealloc(Original, Count, Alignment); + } + virtual void Free(void* Original) override + { + Memory::Initialize(); + return GMalloc->Free(Original); + } + virtual void* MallocZeroed(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT) override + { + Memory::Initialize(); + return GMalloc->MallocZeroed(Count, Alignment); + } + + virtual void* TryMallocZeroed(size_t Count, uint32_t Alignment = DEFAULT_ALIGNMENT) override + { + Memory::Initialize(); + return GMalloc->TryMallocZeroed(Count, Alignment); + } + virtual size_t QuantizeSize(size_t Count, uint32_t Alignment) override + { + Memory::Initialize(); + return GMalloc->QuantizeSize(Count, Alignment); + } + virtual bool GetAllocationSize(void* Original, size_t& SizeOut) override + { + Memory::Initialize(); + return GMalloc->GetAllocationSize(Original, SizeOut); + } + virtual void OnMallocInitialized() override {} + virtual void Trim(bool bTrimThreadCaches) override { ZEN_UNUSED(bTrimThreadCaches); } +} GInitialMalloc; + +FMalloc* GMalloc = &GInitialMalloc; /* Memory allocator */ + +////////////////////////////////////////////////////////////////////////// + +void* +FUseSystemMallocForNew::operator new(size_t Size) +{ + return Memory::SystemMalloc(Size); +} + +void +FUseSystemMallocForNew::operator delete(void* Ptr) +{ + Memory::SystemFree(Ptr); +} + +void* +FUseSystemMallocForNew::operator new[](size_t Size) +{ + return Memory::SystemMalloc(Size); +} + +void +FUseSystemMallocForNew::operator delete[](void* Ptr) +{ + Memory::SystemFree(Ptr); +} + +////////////////////////////////////////////////////////////////////////// + +void* +FMalloc::TryRealloc(void* Original, size_t Count, uint32_t Alignment) +{ + return Realloc(Original, Count, Alignment); +} + +void* +FMalloc::TryMalloc(size_t Count, uint32_t Alignment) +{ + return Malloc(Count, Alignment); +} + +void* +FMalloc::TryMallocZeroed(size_t Count, uint32_t Alignment) +{ + return MallocZeroed(Count, Alignment); +} + +void* +FMalloc::MallocZeroed(size_t Count, uint32_t Alignment) +{ + void* const Memory = Malloc(Count, Alignment); + + if (Memory) + { + ::memset(Memory, 0, Count); + } + + return Memory; +} + +void +FMalloc::OutOfMemory(size_t Size, uint32_t Alignment) +{ + ZEN_UNUSED(Size, Alignment); + // no-op by default +} + +void +FMalloc::Trim(bool bTrimThreadCaches) +{ + ZEN_UNUSED(bTrimThreadCaches); +} + +void +FMalloc::OnMallocInitialized() +{ +} + +bool +FMalloc::GetAllocationSize(void* Original, size_t& SizeOut) +{ + ZEN_UNUSED(Original, SizeOut); + return false; // Generic implementation has no way of determining this +} + +size_t +FMalloc::QuantizeSize(size_t Count, uint32_t Alignment) +{ + ZEN_UNUSED(Alignment); + return Count; // Generic implementation has no way of determining this +} + +} // namespace zen diff --git a/src/zencore/memory/llm.cpp b/src/zencore/memory/llm.cpp new file mode 100644 index 000000000..61fa29a66 --- /dev/null +++ b/src/zencore/memory/llm.cpp @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/memory/llm.h> + +#include <zencore/string.h> +#include <zencore/thread.h> + +#include <atomic> + +namespace zen { + +static std::atomic<int32_t> CustomTagCounter = 257; // NOTE: hard-coded TRACE_TAG = 257 + +static const int32_t TagNamesBaseIndex = 256; +static const int32_t TrackedTagNameCount = 256; +static const char* TagNames[TrackedTagNameCount]; +static uint32_t TagNameHashes[TrackedTagNameCount]; +static int32_t ParentTags[TrackedTagNameCount]; + +static RwLock TableLock; + +FLLMTag::FLLMTag(const char* TagName) +{ + // NOTE: should add verification to prevent multiple definitions of same name? + + AssignAndAnnounceNewTag(TagName); +} + +FLLMTag::FLLMTag(const char* TagName, const FLLMTag& ParentTag) +{ + // NOTE: should add verification to prevent multiple definitions of same name? + + m_ParentTag = ParentTag.GetTag(); + + AssignAndAnnounceNewTag(TagName); +} + +void +FLLMTag::AssignAndAnnounceNewTag(const char* TagName) +{ + const uint32_t TagNameHash = HashStringDjb2(TagName); + + { + RwLock::ExclusiveLockScope _(TableLock); + + const int32_t CurrentMaxTagIndex = CustomTagCounter - TagNamesBaseIndex; + + for (int TagIndex = 0; TagIndex <= CurrentMaxTagIndex; ++TagIndex) + { + if (TagNameHashes[TagIndex] == TagNameHash && ParentTags[TagIndex] == m_ParentTag) + { + m_Tag = TagIndex + TagNamesBaseIndex; + // could verify the string matches here to catch hash collisions + + // return early, no need to announce the tag as it is already known + return; + } + } + + m_Tag = ++CustomTagCounter; + + const int TagIndex = m_Tag - TagNamesBaseIndex; + + if (TagIndex < TrackedTagNameCount) + { + TagNameHashes[TagIndex] = TagNameHash; + TagNames[TagIndex] = TagName; + ParentTags[TagIndex] = m_ParentTag; + } + else + { + // should really let user know there's an overflow + } + } + + MemoryTrace_AnnounceCustomTag(m_Tag, m_ParentTag, TagName); +} + +} // namespace zen diff --git a/src/zencore/memory/mallocansi.cpp b/src/zencore/memory/mallocansi.cpp new file mode 100644 index 000000000..9c3936172 --- /dev/null +++ b/src/zencore/memory/mallocansi.cpp @@ -0,0 +1,251 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/memory/mallocansi.h> + +#include <zencore/intmath.h> +#include <zencore/memory/align.h> +#include <zencore/windows.h> + +#if ZEN_PLATFORM_LINUX +# define PLATFORM_USE_ANSI_POSIX_MALLOC 1 +#endif + +#if ZEN_PLATFORM_MAC +# define PLATFORM_USE_CUSTOM_MEMALIGN 1 +#endif + +#ifndef PLATFORM_USE_ANSI_MEMALIGN +# define PLATFORM_USE_ANSI_MEMALIGN 0 +#endif + +#ifndef PLATFORM_USE_ANSI_POSIX_MALLOC +# define PLATFORM_USE_ANSI_POSIX_MALLOC 0 +#endif + +#ifndef PLATFORM_USE_CUSTOM_MEMALIGN +# define PLATFORM_USE_CUSTOM_MEMALIGN 0 +#endif + +#if PLATFORM_USE_ANSI_POSIX_MALLOC +# include <malloc.h> +# include <string.h> +#endif + +#define MALLOC_ANSI_USES__ALIGNED_MALLOC ZEN_PLATFORM_WINDOWS + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +void* +AnsiMalloc(size_t Size, uint32_t Alignment) +{ +#if MALLOC_ANSI_USES__ALIGNED_MALLOC + void* Result = _aligned_malloc(Size, Alignment); +#elif PLATFORM_USE_ANSI_POSIX_MALLOC + void* Result; + if (posix_memalign(&Result, Alignment, Size) != 0) + { + Result = nullptr; + } +#elif PLATFORM_USE_ANSI_MEMALIGN + Result = reallocalign(Ptr, NewSize, Alignment); +#elif PLATFORM_USE_CUSTOM_MEMALIGN + void* Ptr = malloc(Size + Alignment + sizeof(void*) + sizeof(size_t)); + void* Result = nullptr; + if (Ptr) + { + Result = Align((uint8_t*)Ptr + sizeof(void*) + sizeof(size_t), Alignment); + *((void**)((uint8_t*)Result - sizeof(void*))) = Ptr; + *((size_t*)((uint8_t*)Result - sizeof(void*) - sizeof(size_t))) = Size; + } +#else +# error Unknown allocation path +#endif + + return Result; +} + +size_t +AnsiGetAllocationSize(void* Original) +{ +#if MALLOC_ANSI_USES__ALIGNED_MALLOC + return _aligned_msize(Original, 16, 0); // TODO: incorrectly assumes alignment of 16 +#elif PLATFORM_USE_ANSI_POSIX_MALLOC || PLATFORM_USE_ANSI_MEMALIGN + return malloc_usable_size(Original); +#elif PLATFORM_USE_CUSTOM_MEMALIGN + return *((size_t*)((uint8_t*)Original - sizeof(void*) - sizeof(size_t))); +#else +# error Unknown allocation path +#endif +} + +void* +AnsiRealloc(void* Ptr, size_t NewSize, uint32_t Alignment) +{ + void* Result = nullptr; + +#if MALLOC_ANSI_USES__ALIGNED_MALLOC + if (Ptr && NewSize) + { + Result = _aligned_realloc(Ptr, NewSize, Alignment); + } + else if (Ptr == nullptr) + { + Result = _aligned_malloc(NewSize, Alignment); + } + else + { + _aligned_free(Ptr); + Result = nullptr; + } +#elif PLATFORM_USE_ANSI_POSIX_MALLOC + if (Ptr && NewSize) + { + size_t UsableSize = malloc_usable_size(Ptr); + if (posix_memalign(&Result, Alignment, NewSize) != 0) + { + Result = nullptr; + } + else if (UsableSize) + { + memcpy(Result, Ptr, Min(NewSize, UsableSize)); + } + free(Ptr); + } + else if (Ptr == nullptr) + { + if (posix_memalign(&Result, Alignment, NewSize) != 0) + { + Result = nullptr; + } + } + else + { + free(Ptr); + Result = nullptr; + } +#elif PLATFORM_USE_CUSTOM_MEMALIGN + if (Ptr && NewSize) + { + // Can't use realloc as it might screw with alignment. + Result = AnsiMalloc(NewSize, Alignment); + size_t PtrSize = AnsiGetAllocationSize(Ptr); + memcpy(Result, Ptr, Min(NewSize, PtrSize)); + AnsiFree(Ptr); + } + else if (Ptr == nullptr) + { + Result = AnsiMalloc(NewSize, Alignment); + } + else + { + free(*((void**)((uint8_t*)Ptr - sizeof(void*)))); + Result = nullptr; + } +#else +# error Unknown allocation path +#endif + + return Result; +} + +void +AnsiFree(void* Ptr) +{ +#if MALLOC_ANSI_USES__ALIGNED_MALLOC + _aligned_free(Ptr); +#elif PLATFORM_USE_ANSI_POSIX_MALLOC || PLATFORM_USE_ANSI_MEMALIGN + free(Ptr); +#elif PLATFORM_USE_CUSTOM_MEMALIGN + if (Ptr) + { + free(*((void**)((uint8_t*)Ptr - sizeof(void*)))); + } +#else +# error Unknown allocation path +#endif +} + +////////////////////////////////////////////////////////////////////////// + +FMallocAnsi::FMallocAnsi() +{ +#if ZEN_PLATFORM_WINDOWS + // Enable low fragmentation heap - http://msdn2.microsoft.com/en-US/library/aa366750.aspx + intptr_t CrtHeapHandle = _get_heap_handle(); + ULONG EnableLFH = 2; + HeapSetInformation((void*)CrtHeapHandle, HeapCompatibilityInformation, &EnableLFH, sizeof(EnableLFH)); +#endif +} + +void* +FMallocAnsi::TryMalloc(size_t Size, uint32_t Alignment) +{ + Alignment = Max(Size >= 16 ? (uint32_t)16 : (uint32_t)8, Alignment); + + void* Result = AnsiMalloc(Size, Alignment); + + return Result; +} + +void* +FMallocAnsi::Malloc(size_t Size, uint32_t Alignment) +{ + void* Result = TryMalloc(Size, Alignment); + + if (Result == nullptr && Size) + { + OutOfMemory(Size, Alignment); + } + + return Result; +} + +void* +FMallocAnsi::TryRealloc(void* Ptr, size_t NewSize, uint32_t Alignment) +{ + Alignment = Max(NewSize >= 16 ? (uint32_t)16 : (uint32_t)8, Alignment); + + void* Result = AnsiRealloc(Ptr, NewSize, Alignment); + + return Result; +} + +void* +FMallocAnsi::Realloc(void* Ptr, size_t NewSize, uint32_t Alignment) +{ + void* Result = TryRealloc(Ptr, NewSize, Alignment); + + if (Result == nullptr && NewSize != 0) + { + OutOfMemory(NewSize, Alignment); + } + + return Result; +} + +void +FMallocAnsi::Free(void* Ptr) +{ + AnsiFree(Ptr); +} + +bool +FMallocAnsi::GetAllocationSize(void* Original, size_t& SizeOut) +{ + if (!Original) + { + return false; + } + +#if MALLOC_ANSI_USES__ALIGNED_MALLOC + ZEN_UNUSED(SizeOut); + return false; +#else + SizeOut = AnsiGetAllocationSize(Original); + return true; +#endif +} + +} // namespace zen diff --git a/src/zencore/memory/mallocmimalloc.cpp b/src/zencore/memory/mallocmimalloc.cpp new file mode 100644 index 000000000..1f9aff404 --- /dev/null +++ b/src/zencore/memory/mallocmimalloc.cpp @@ -0,0 +1,199 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <cstring> + +#include <zencore/intmath.h> +#include <zencore/memory/align.h> +#include <zencore/memory/mallocmimalloc.h> + +#if ZEN_MIMALLOC_ENABLED + +# include <mimalloc.h> + +/** Value we fill a memory block with after it is free, in UE_BUILD_DEBUG **/ +# define DEBUG_FILL_FREED (0xdd) + +/** Value we fill a new memory block with, in UE_BUILD_DEBUG **/ +# define DEBUG_FILL_NEW (0xcd) + +# define ZEN_ENABLE_DEBUG_FILL 1 + +namespace zen { + +// Dramatically reduce memory zeroing and page faults during alloc intense workloads +// by keeping freed pages for a little while instead of releasing them +// right away to the OS, effectively acting like a scratch buffer +// until pages are both freed and inactive for the delay specified +// in milliseconds. +int32_t GMiMallocMemoryResetDelay = 10000; + +FMallocMimalloc::FMallocMimalloc() +{ + mi_option_set(mi_option_reset_delay, GMiMallocMemoryResetDelay); +} + +void* +FMallocMimalloc::TryMalloc(size_t Size, uint32_t Alignment) +{ + void* NewPtr = nullptr; + + if (Alignment != DEFAULT_ALIGNMENT) + { + Alignment = Max(uint32_t(Size >= 16 ? 16 : 8), Alignment); + NewPtr = mi_malloc_aligned(Size, Alignment); + } + else + { + NewPtr = mi_malloc_aligned(Size, uint32_t(Size >= 16 ? 16 : 8)); + } + +# if ZEN_BUILD_DEBUG && ZEN_ENABLE_DEBUG_FILL + if (Size && NewPtr != nullptr) + { + memset(NewPtr, DEBUG_FILL_NEW, mi_usable_size(NewPtr)); + } +# endif + + return NewPtr; +} + +void* +FMallocMimalloc::Malloc(size_t Size, uint32_t Alignment) +{ + void* Result = TryMalloc(Size, Alignment); + + if (Result == nullptr && Size) + { + OutOfMemory(Size, Alignment); + } + + return Result; +} + +void* +FMallocMimalloc::TryRealloc(void* Ptr, size_t NewSize, uint32_t Alignment) +{ +# if ZEN_BUILD_DEBUG && ZEN_ENABLE_DEBUG_FILL + size_t OldSize = 0; + if (Ptr) + { + OldSize = mi_malloc_size(Ptr); + if (NewSize < OldSize) + { + memset((uint8_t*)Ptr + NewSize, DEBUG_FILL_FREED, OldSize - NewSize); + } + } +# endif + void* NewPtr = nullptr; + + if (NewSize == 0) + { + mi_free(Ptr); + + return nullptr; + } + +# if ZEN_PLATFORM_MAC + // macOS expects all allocations to be aligned to 16 bytes, so on Mac we always have to use mi_realloc_aligned + Alignment = AlignArbitrary(Max((uint32_t)16, Alignment), (uint32_t)16); + NewPtr = mi_realloc_aligned(Ptr, NewSize, Alignment); +# else + if (Alignment != DEFAULT_ALIGNMENT) + { + Alignment = Max(NewSize >= 16 ? (uint32_t)16 : (uint32_t)8, Alignment); + NewPtr = mi_realloc_aligned(Ptr, NewSize, Alignment); + } + else + { + NewPtr = mi_realloc(Ptr, NewSize); + } +# endif + +# if ZEN_BUILD_DEBUG && ZEN_ENABLE_DEBUG_FILL + if (NewPtr && NewSize > OldSize) + { + memset((uint8_t*)NewPtr + OldSize, DEBUG_FILL_NEW, mi_usable_size(NewPtr) - OldSize); + } +# endif + + return NewPtr; +} + +void* +FMallocMimalloc::Realloc(void* Ptr, size_t NewSize, uint32_t Alignment) +{ + void* Result = TryRealloc(Ptr, NewSize, Alignment); + + if (Result == nullptr && NewSize) + { + OutOfMemory(NewSize, Alignment); + } + + return Result; +} + +void +FMallocMimalloc::Free(void* Ptr) +{ + if (!Ptr) + { + return; + } + +# if ZEN_BUILD_DEBUG && ZEN_ENABLE_DEBUG_FILL + memset(Ptr, DEBUG_FILL_FREED, mi_usable_size(Ptr)); +# endif + + mi_free(Ptr); +} + +void* +FMallocMimalloc::MallocZeroed(size_t Size, uint32_t Alignment) +{ + void* Result = TryMallocZeroed(Size, Alignment); + + if (Result == nullptr && Size) + { + OutOfMemory(Size, Alignment); + } + + return Result; +} + +void* +FMallocMimalloc::TryMallocZeroed(size_t Size, uint32_t Alignment) +{ + void* NewPtr = nullptr; + + if (Alignment != DEFAULT_ALIGNMENT) + { + Alignment = Max(uint32_t(Size >= 16 ? 16 : 8), Alignment); + NewPtr = mi_zalloc_aligned(Size, Alignment); + } + else + { + NewPtr = mi_zalloc_aligned(Size, uint32_t(Size >= 16 ? 16 : 8)); + } + + return NewPtr; +} + +bool +FMallocMimalloc::GetAllocationSize(void* Original, size_t& SizeOut) +{ + SizeOut = mi_malloc_size(Original); + return true; +} + +void +FMallocMimalloc::Trim(bool bTrimThreadCaches) +{ + mi_collect(bTrimThreadCaches); +} + +# undef DEBUG_FILL_FREED +# undef DEBUG_FILL_NEW + +} // namespace zen + +#endif // MIMALLOC_ENABLED diff --git a/src/zencore/memory/mallocrpmalloc.cpp b/src/zencore/memory/mallocrpmalloc.cpp new file mode 100644 index 000000000..ffced27c9 --- /dev/null +++ b/src/zencore/memory/mallocrpmalloc.cpp @@ -0,0 +1,189 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/intmath.h> +#include <zencore/memory/align.h> +#include <zencore/memory/mallocrpmalloc.h> + +#if ZEN_RPMALLOC_ENABLED + +# include "rpmalloc.h" + +/** Value we fill a memory block with after it is free, in UE_BUILD_DEBUG **/ +# define DEBUG_FILL_FREED (0xdd) + +/** Value we fill a new memory block with, in UE_BUILD_DEBUG **/ +# define DEBUG_FILL_NEW (0xcd) + +# define ZEN_ENABLE_DEBUG_FILL 1 + +namespace zen { + +FMallocRpmalloc::FMallocRpmalloc() +{ + rpmalloc_initialize(nullptr); +} + +FMallocRpmalloc::~FMallocRpmalloc() +{ + rpmalloc_finalize(); +} + +void* +FMallocRpmalloc::TryMalloc(size_t Size, uint32_t Alignment) +{ + void* NewPtr = nullptr; + + if (Alignment != DEFAULT_ALIGNMENT) + { + Alignment = Max(uint32_t(Size >= 16 ? 16 : 8), Alignment); + NewPtr = rpaligned_alloc(Alignment, Size); + } + else + { + NewPtr = rpaligned_alloc(uint32_t(Size >= 16 ? 16 : 8), Size); + } + +# if ZEN_BUILD_DEBUG && ZEN_ENABLE_DEBUG_FILL + if (Size && NewPtr != nullptr) + { + memset(NewPtr, DEBUG_FILL_NEW, rpmalloc_usable_size(NewPtr)); + } +# endif + + return NewPtr; +} + +void* +FMallocRpmalloc::Malloc(size_t Size, uint32_t Alignment) +{ + void* Result = TryMalloc(Size, Alignment); + + if (Result == nullptr && Size) + { + OutOfMemory(Size, Alignment); + } + + return Result; +} + +void* +FMallocRpmalloc::Realloc(void* Ptr, size_t NewSize, uint32_t Alignment) +{ + void* Result = TryRealloc(Ptr, NewSize, Alignment); + + if (Result == nullptr && NewSize) + { + OutOfMemory(NewSize, Alignment); + } + + return Result; +} + +void* +FMallocRpmalloc::TryRealloc(void* Ptr, size_t NewSize, uint32_t Alignment) +{ +# if ZEN_BUILD_DEBUG && ZEN_ENABLE_DEBUG_FILL + size_t OldSize = 0; + if (Ptr) + { + OldSize = rpmalloc_usable_size(Ptr); + if (NewSize < OldSize) + { + memset((uint8_t*)Ptr + NewSize, DEBUG_FILL_FREED, OldSize - NewSize); + } + } +# endif + void* NewPtr = nullptr; + + if (NewSize == 0) + { + rpfree(Ptr); + + return nullptr; + } + +# if ZEN_PLATFORM_MAC + // macOS expects all allocations to be aligned to 16 bytes, so on Mac we always have to use mi_realloc_aligned + Alignment = AlignArbitrary(Max((uint32_t)16, Alignment), (uint32_t)16); + NewPtr = rpaligned_realloc(Ptr, Alignment, NewSize, /* OldSize */ 0, /* flags */ 0); +# else + if (Alignment != DEFAULT_ALIGNMENT) + { + Alignment = Max(NewSize >= 16 ? (uint32_t)16 : (uint32_t)8, Alignment); + NewPtr = rpaligned_realloc(Ptr, Alignment, NewSize, /* OldSize */ 0, /* flags */ 0); + } + else + { + NewPtr = rprealloc(Ptr, NewSize); + } +# endif + +# if ZEN_BUILD_DEBUG && ZEN_ENABLE_DEBUG_FILL + if (NewPtr && NewSize > OldSize) + { + memset((uint8_t*)NewPtr + OldSize, DEBUG_FILL_NEW, rpmalloc_usable_size(NewPtr) - OldSize); + } +# endif + + return NewPtr; +} + +void +FMallocRpmalloc::Free(void* Ptr) +{ + if (!Ptr) + { + return; + } + +# if ZEN_BUILD_DEBUG && ZEN_ENABLE_DEBUG_FILL + memset(Ptr, DEBUG_FILL_FREED, rpmalloc_usable_size(Ptr)); +# endif + + rpfree(Ptr); +} + +void* +FMallocRpmalloc::MallocZeroed(size_t Size, uint32_t Alignment) +{ + void* Result = TryMallocZeroed(Size, Alignment); + + if (Result == nullptr && Size) + { + OutOfMemory(Size, Alignment); + } + + return Result; +} +void* +FMallocRpmalloc::TryMallocZeroed(size_t Size, uint32_t Alignment) +{ + void* NewPtr = nullptr; + + if (Alignment != DEFAULT_ALIGNMENT) + { + Alignment = Max(uint32_t(Size >= 16 ? 16 : 8), Alignment); + NewPtr = rpaligned_zalloc(Alignment, Size); + } + else + { + NewPtr = rpaligned_zalloc(uint32_t(Size >= 16 ? 16 : 8), Size); + } + + return NewPtr; +} + +bool +FMallocRpmalloc::GetAllocationSize(void* Original, size_t& SizeOut) +{ + // this is not the same as the allocation size - is that ok? + SizeOut = rpmalloc_usable_size(Original); + return true; +} +void +FMallocRpmalloc::Trim(bool bTrimThreadCaches) +{ + ZEN_UNUSED(bTrimThreadCaches); +} +} // namespace zen +#endif diff --git a/src/zencore/memory/mallocstomp.cpp b/src/zencore/memory/mallocstomp.cpp new file mode 100644 index 000000000..db9e1535e --- /dev/null +++ b/src/zencore/memory/mallocstomp.cpp @@ -0,0 +1,283 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/memory/mallocstomp.h> + +#if ZEN_WITH_MALLOC_STOMP + +# include <zencore/memory/align.h> +# include <zencore/xxhash.h> + +# if ZEN_PLATFORM_LINUX +# include <sys/mman.h> +# endif + +# if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +# endif + +# if ZEN_PLATFORM_WINDOWS +// MallocStomp can keep virtual address range reserved after memory block is freed, while releasing the physical memory. +// This dramatically increases accuracy of use-after-free detection, but consumes significant amount of memory for the OS page table. +// Virtual memory limit for a process on Win10 is 128 TB, which means we can afford to keep virtual memory reserved for a very long time. +// Running Infiltrator demo consumes ~700MB of virtual address space per second. +# define MALLOC_STOMP_KEEP_VIRTUAL_MEMORY 1 +# else +# define MALLOC_STOMP_KEEP_VIRTUAL_MEMORY 0 +# endif + +// 64-bit ABIs on x86_64 expect a 16-byte alignment +# define STOMPALIGNMENT 16U + +namespace zen { + +struct FMallocStomp::FAllocationData +{ + /** Pointer to the full allocation. Needed so the OS knows what to free. */ + void* FullAllocationPointer; + /** Full size of the allocation including the extra page. */ + size_t FullSize; + /** Size of the allocation requested. */ + size_t Size; + /** Sentinel used to check for underrun. */ + size_t Sentinel; + + /** Calculate the expected sentinel value for this allocation data. */ + size_t CalculateSentinel() const + { + XXH3_128 Xxh = XXH3_128::HashMemory(this, offsetof(FAllocationData, Sentinel)); + + size_t Hash; + memcpy(&Hash, Xxh.Hash, sizeof(Hash)); + + return Hash; + } +}; + +FMallocStomp::FMallocStomp(const bool InUseUnderrunMode) : PageSize(4096 /* TODO: make dynamic */), bUseUnderrunMode(InUseUnderrunMode) +{ +} + +void* +FMallocStomp::Malloc(size_t Size, uint32_t Alignment) +{ + void* Result = TryMalloc(Size, Alignment); + + if (Result == nullptr) + { + OutOfMemory(Size, Alignment); + } + + return Result; +} + +void* +FMallocStomp::TryMalloc(size_t Size, uint32_t Alignment) +{ + if (Size == 0U) + { + Size = 1U; + } + + Alignment = Max<uint32_t>(Alignment, STOMPALIGNMENT); + + constexpr static size_t AllocationDataSize = sizeof(FAllocationData); + + const size_t AlignedSize = Alignment ? ((Size + Alignment - 1) & -(int32_t)Alignment) : Size; + const size_t AlignmentSize = Alignment > PageSize ? Alignment - PageSize : 0; + const size_t AllocFullPageSize = (AlignedSize + AlignmentSize + AllocationDataSize + PageSize - 1) & ~(PageSize - 1); + const size_t TotalAllocationSize = AllocFullPageSize + PageSize; + +# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + void* FullAllocationPointer = mmap(nullptr, TotalAllocationSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); +# elif ZEN_PLATFORM_WINDOWS && MALLOC_STOMP_KEEP_VIRTUAL_MEMORY + // Allocate virtual address space from current block using linear allocation strategy. + // If there is not enough space, try to allocate new block from OS. Report OOM if block allocation fails. + void* FullAllocationPointer = nullptr; + + { + RwLock::ExclusiveLockScope _(Lock); + + if (VirtualAddressCursor + TotalAllocationSize <= VirtualAddressMax) + { + FullAllocationPointer = (void*)(VirtualAddressCursor); + } + else + { + const size_t ReserveSize = Max(VirtualAddressBlockSize, TotalAllocationSize); + + // Reserve a new block of virtual address space that will be linearly sub-allocated + // We intentionally don't keep track of reserved blocks, as we never need to explicitly release them. + FullAllocationPointer = VirtualAlloc(nullptr, ReserveSize, MEM_RESERVE, PAGE_NOACCESS); + + VirtualAddressCursor = uintptr_t(FullAllocationPointer); + VirtualAddressMax = VirtualAddressCursor + ReserveSize; + } + + VirtualAddressCursor += TotalAllocationSize; + } +# else + void* FullAllocationPointer = FPlatformMemory::BinnedAllocFromOS(TotalAllocationSize); +# endif // PLATFORM_UNIX || PLATFORM_MAC + + if (!FullAllocationPointer) + { + return nullptr; + } + + void* ReturnedPointer = nullptr; + + ZEN_ASSERT_SLOW(IsAligned(FullAllocationPointer, PageSize)); + + if (bUseUnderrunMode) + { + ReturnedPointer = Align((uint8_t*)FullAllocationPointer + PageSize + AllocationDataSize, Alignment); + void* AllocDataPointerStart = static_cast<FAllocationData*>(ReturnedPointer) - 1; + ZEN_ASSERT_SLOW(AllocDataPointerStart >= FullAllocationPointer); + +# if ZEN_PLATFORM_WINDOWS && MALLOC_STOMP_KEEP_VIRTUAL_MEMORY + // Commit physical pages to the used range, leaving the first page unmapped. + void* CommittedMemory = VirtualAlloc(AllocDataPointerStart, AllocationDataSize + AlignedSize, MEM_COMMIT, PAGE_READWRITE); + if (!CommittedMemory) + { + // Failed to allocate and commit physical memory pages. + return nullptr; + } + ZEN_ASSERT(CommittedMemory == AlignDown(AllocDataPointerStart, PageSize)); +# else + // Page protect the first page, this will cause the exception in case there is an underrun. + FPlatformMemory::PageProtect((uint8*)AlignDown(AllocDataPointerStart, PageSize) - PageSize, PageSize, false, false); +# endif + } //-V773 + else + { + ReturnedPointer = AlignDown((uint8_t*)FullAllocationPointer + AllocFullPageSize - AlignedSize, Alignment); + void* ReturnedPointerEnd = (uint8_t*)ReturnedPointer + AlignedSize; + ZEN_ASSERT_SLOW(IsAligned(ReturnedPointerEnd, PageSize)); + + void* AllocDataPointerStart = static_cast<FAllocationData*>(ReturnedPointer) - 1; + ZEN_ASSERT_SLOW(AllocDataPointerStart >= FullAllocationPointer); + +# if ZEN_PLATFORM_WINDOWS && MALLOC_STOMP_KEEP_VIRTUAL_MEMORY + // Commit physical pages to the used range, leaving the last page unmapped. + void* CommitPointerStart = AlignDown(AllocDataPointerStart, PageSize); + void* CommittedMemory = VirtualAlloc(CommitPointerStart, + size_t((uint8_t*)ReturnedPointerEnd - (uint8_t*)CommitPointerStart), + MEM_COMMIT, + PAGE_READWRITE); + if (!CommittedMemory) + { + // Failed to allocate and commit physical memory pages. + return nullptr; + } + ZEN_ASSERT(CommittedMemory == CommitPointerStart); +# else + // Page protect the last page, this will cause the exception in case there is an overrun. + FPlatformMemory::PageProtect(ReturnedPointerEnd, PageSize, false, false); +# endif + } //-V773 + + ZEN_ASSERT_SLOW(IsAligned(FullAllocationPointer, PageSize)); + ZEN_ASSERT_SLOW(IsAligned(TotalAllocationSize, PageSize)); + ZEN_ASSERT_SLOW(IsAligned(ReturnedPointer, Alignment)); + ZEN_ASSERT_SLOW((uint8_t*)ReturnedPointer + AlignedSize <= (uint8_t*)FullAllocationPointer + TotalAllocationSize); + + FAllocationData& AllocationData = static_cast<FAllocationData*>(ReturnedPointer)[-1]; + AllocationData = {FullAllocationPointer, TotalAllocationSize, AlignedSize, 0}; + AllocationData.Sentinel = AllocationData.CalculateSentinel(); + + return ReturnedPointer; +} + +void* +FMallocStomp::Realloc(void* InPtr, size_t NewSize, uint32_t Alignment) +{ + void* Result = TryRealloc(InPtr, NewSize, Alignment); + + if (Result == nullptr && NewSize) + { + OutOfMemory(NewSize, Alignment); + } + + return Result; +} + +void* +FMallocStomp::TryRealloc(void* InPtr, size_t NewSize, uint32_t Alignment) +{ + if (NewSize == 0U) + { + Free(InPtr); + return nullptr; + } + + void* ReturnPtr = nullptr; + + if (InPtr != nullptr) + { + ReturnPtr = TryMalloc(NewSize, Alignment); + + if (ReturnPtr != nullptr) + { + FAllocationData* AllocDataPtr = reinterpret_cast<FAllocationData*>(reinterpret_cast<uint8_t*>(InPtr) - sizeof(FAllocationData)); + memcpy(ReturnPtr, InPtr, Min(AllocDataPtr->Size, NewSize)); + Free(InPtr); + } + } + else + { + ReturnPtr = TryMalloc(NewSize, Alignment); + } + + return ReturnPtr; +} + +void +FMallocStomp::Free(void* InPtr) +{ + if (InPtr == nullptr) + { + return; + } + + FAllocationData* AllocDataPtr = reinterpret_cast<FAllocationData*>(InPtr); + AllocDataPtr--; + + // Check the sentinel to verify that the allocation data is intact. + if (AllocDataPtr->Sentinel != AllocDataPtr->CalculateSentinel()) + { + // There was a memory underrun related to this allocation. + ZEN_DEBUG_BREAK(); + } + +# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + munmap(AllocDataPtr->FullAllocationPointer, AllocDataPtr->FullSize); +# elif ZEN_PLATFORM_WINDOWS && MALLOC_STOMP_KEEP_VIRTUAL_MEMORY + // Unmap physical memory, but keep virtual address range reserved to catch use-after-free errors. + + VirtualFree(AllocDataPtr->FullAllocationPointer, AllocDataPtr->FullSize, MEM_DECOMMIT); + +# else + FPlatformMemory::BinnedFreeToOS(AllocDataPtr->FullAllocationPointer, AllocDataPtr->FullSize); +# endif // PLATFORM_UNIX || PLATFORM_MAC +} + +bool +FMallocStomp::GetAllocationSize(void* Original, size_t& SizeOut) +{ + if (Original == nullptr) + { + SizeOut = 0U; + } + else + { + FAllocationData* AllocDataPtr = reinterpret_cast<FAllocationData*>(Original); + AllocDataPtr--; + SizeOut = AllocDataPtr->Size; + } + + return true; +} + +} // namespace zen + +#endif // WITH_MALLOC_STOMP diff --git a/src/zencore/memory/memory.cpp b/src/zencore/memory/memory.cpp new file mode 100644 index 000000000..fced2a4d3 --- /dev/null +++ b/src/zencore/memory/memory.cpp @@ -0,0 +1,312 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/commandline.h> +#include <zencore/memory/fmalloc.h> +#include <zencore/memory/mallocansi.h> +#include <zencore/memory/mallocmimalloc.h> +#include <zencore/memory/mallocrpmalloc.h> +#include <zencore/memory/mallocstomp.h> +#include <zencore/memory/memory.h> +#include <zencore/memory/memorytrace.h> +#include <zencore/string.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +ZEN_THIRD_PARTY_INCLUDES_START +# include <shellapi.h> // For command line parsing +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +#if ZEN_PLATFORM_LINUX +# include <stdio.h> +#endif + +namespace zen { + +enum class MallocImpl +{ + None = 0, + Ansi, + Stomp, + Mimalloc, + Rpmalloc +}; + +static int +InitGMalloc() +{ + MallocImpl Malloc = MallocImpl::None; + FMalloc* InitMalloc = GMalloc; + + // Pick a default base allocator based on availability/platform + +#if ZEN_MIMALLOC_ENABLED + if (Malloc == MallocImpl::None) + { + Malloc = MallocImpl::Mimalloc; + } +#endif + +#if ZEN_RPMALLOC_ENABLED + if (Malloc == MallocImpl::None) + { + Malloc = MallocImpl::Rpmalloc; + } +#endif + + // Process any command line overrides + // + // Note that calls can come into this function before we enter the regular main function + // and we can therefore not rely on the regular command line parsing for the application + + using namespace std::literals; + + auto ProcessMallocArg = [&](const std::string_view& Arg) { +#if ZEN_RPMALLOC_ENABLED + if (Arg == "rpmalloc"sv) + { + Malloc = MallocImpl::Rpmalloc; + } +#endif + +#if ZEN_MIMALLOC_ENABLED + if (Arg == "mimalloc"sv) + { + Malloc = MallocImpl::Mimalloc; + } +#endif + + if (Arg == "ansi"sv) + { + Malloc = MallocImpl::Ansi; + } + + if (Arg == "stomp"sv) + { + Malloc = MallocImpl::Stomp; + } + }; + + constexpr std::string_view MallocOption = "--malloc"sv; + + std::function<void(const std::string_view&)> ProcessArg = [&](const std::string_view& Arg) { + if (Arg.starts_with(MallocOption)) + { + std::string_view::value_type DelimChar = Arg[MallocOption.length()]; + + if (DelimChar == ' ' || DelimChar == '=') + { + const std::string_view OptionArgs = Arg.substr(MallocOption.size() + 1); + + IterateCommaSeparatedValue(OptionArgs, ProcessMallocArg); + } + } + }; + + IterateCommandlineArgs(ProcessArg); + + switch (Malloc) + { +#if ZEN_WITH_MALLOC_STOMP + case MallocImpl::Stomp: + GMalloc = new FMallocStomp(); + break; +#endif + +#if ZEN_RPMALLOC_ENABLED + case MallocImpl::Rpmalloc: + GMalloc = new FMallocRpmalloc(); + break; +#endif + +#if ZEN_MIMALLOC_ENABLED + case MallocImpl::Mimalloc: + GMalloc = new FMallocMimalloc(); + break; +#endif + default: + break; + } + + if (GMalloc == InitMalloc) + { + GMalloc = new FMallocAnsi(); + } + + return 1; +} + +void +Memory::GCreateMalloc() +{ + static int InitFlag = InitGMalloc(); +} + +void +Memory::Initialize() +{ + GCreateMalloc(); +} + +////////////////////////////////////////////////////////////////////////// + +void* +Memory::SystemMalloc(size_t Size) +{ + void* Ptr = ::malloc(Size); + MemoryTrace_Alloc(uint64_t(Ptr), Size, 0, EMemoryTraceRootHeap::SystemMemory); + return Ptr; +} + +void +Memory::SystemFree(void* Ptr) +{ + MemoryTrace_Free(uint64_t(Ptr), EMemoryTraceRootHeap::SystemMemory); + ::free(Ptr); +} + +} // namespace zen + +////////////////////////////////////////////////////////////////////////// + +static ZEN_NOINLINE bool +InvokeNewHandler(bool NoThrow) +{ + std::new_handler h = std::get_new_handler(); + + if (!h) + { +#if defined(_CPPUNWIND) || defined(__cpp_exceptions) + if (NoThrow == false) + throw std::bad_alloc(); +#else + ZEN_UNUSED(NoThrow); +#endif + return false; + } + else + { + h(); + return true; + } +} + +////////////////////////////////////////////////////////////////////////// + +ZEN_NOINLINE void* +RetryNew(size_t Size, bool NoThrow) +{ + void* Ptr = nullptr; + while (!Ptr && InvokeNewHandler(NoThrow)) + { + Ptr = zen::Memory::Malloc(Size, zen::DEFAULT_ALIGNMENT); + } + return Ptr; +} + +void* +zen_new(size_t Size) +{ + void* Ptr = zen::Memory::Malloc(Size, zen::DEFAULT_ALIGNMENT); + + if (!Ptr) [[unlikely]] + { + const bool NoThrow = false; + return RetryNew(Size, NoThrow); + } + + return Ptr; +} + +void* +zen_new_nothrow(size_t Size) noexcept +{ + void* Ptr = zen::Memory::Malloc(Size, zen::DEFAULT_ALIGNMENT); + + if (!Ptr) [[unlikely]] + { + const bool NoThrow = true; + return RetryNew(Size, NoThrow); + } + + return Ptr; +} + +void* +zen_new_aligned(size_t Size, size_t Alignment) +{ + void* Ptr; + + do + { + Ptr = zen::Memory::Malloc(Size, uint32_t(Alignment)); + } while (!Ptr && InvokeNewHandler(/* NoThrow */ false)); + + return Ptr; +} + +void* +zen_new_aligned_nothrow(size_t Size, size_t Alignment) noexcept +{ + void* Ptr; + + do + { + Ptr = zen::Memory::Malloc(Size, uint32_t(Alignment)); + } while (!Ptr && InvokeNewHandler(/* NoThrow */ true)); + + return Ptr; +} + +void +zen_free(void* Ptr) noexcept +{ + zen::Memory::Free(Ptr); +} + +void +zen_free_size(void* Ptr, size_t Size) noexcept +{ + ZEN_UNUSED(Size); + zen::Memory::Free(Ptr); +} + +void +zen_free_size_aligned(void* Ptr, size_t Size, size_t Alignment) noexcept +{ + ZEN_UNUSED(Size, Alignment); + zen::Memory::Free(Ptr); +} + +void +zen_free_aligned(void* Ptr, size_t Alignment) noexcept +{ + ZEN_UNUSED(Alignment); + zen::Memory::Free(Ptr); +} + +// EASTL operator new + +void* +operator new[](size_t size, const char* pName, int flags, unsigned debugFlags, const char* file, int line) +{ + ZEN_UNUSED(pName, flags, debugFlags, file, line); + return zen_new(size); +} + +void* +operator new[](size_t size, + size_t alignment, + size_t alignmentOffset, + const char* pName, + int flags, + unsigned debugFlags, + const char* file, + int line) +{ + ZEN_UNUSED(alignmentOffset, pName, flags, debugFlags, file, line); + + ZEN_ASSERT_SLOW(alignmentOffset == 0); // currently not supported + + return zen_new_aligned(size, alignment); +} diff --git a/src/zencore/memoryview.cpp b/src/zencore/memoryview.cpp new file mode 100644 index 000000000..1f6a6996c --- /dev/null +++ b/src/zencore/memoryview.cpp @@ -0,0 +1,45 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/except.h> +#include <zencore/fmtutils.h> +#include <zencore/intmath.h> +#include <zencore/memory/memory.h> +#include <zencore/memoryview.h> +#include <zencore/testing.h> +#include <zencore/zencore.h> + +#include <cstdlib> + +namespace zen { + +// +// Unit tests +// + +#if ZEN_WITH_TESTS + +TEST_CASE("MemoryView") +{ + { + uint8_t Array1[16] = {}; + MemoryView View1 = MakeMemoryView(Array1); + CHECK(View1.GetSize() == 16); + } + + { + uint32_t Array2[16] = {}; + MemoryView View2 = MakeMemoryView(Array2); + CHECK(View2.GetSize() == 64); + } + + CHECK(MakeMemoryView<float>({1.0f, 1.2f}).GetSize() == 8); +} + +void +memory_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zencore/memtrack/callstacktrace.cpp b/src/zencore/memtrack/callstacktrace.cpp new file mode 100644 index 000000000..d860c05d1 --- /dev/null +++ b/src/zencore/memtrack/callstacktrace.cpp @@ -0,0 +1,1059 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "callstacktrace.h" + +#include <zenbase/zenbase.h> +#include <zencore/string.h> + +#if UE_CALLSTACK_TRACE_ENABLED + +namespace zen { + +// Platform implementations of back tracing +//////////////////////////////////////////////////////////////////////////////// +void CallstackTrace_CreateInternal(FMalloc*); +void CallstackTrace_InitializeInternal(); + +//////////////////////////////////////////////////////////////////////////////// +UE_TRACE_CHANNEL_DEFINE(CallstackChannel) +UE_TRACE_EVENT_DEFINE(Memory, CallstackSpec) + +uint32 GCallStackTracingTlsSlotIndex = FPlatformTLS::InvalidTlsSlot; + +//////////////////////////////////////////////////////////////////////////////// +void +CallstackTrace_Create(class FMalloc* InMalloc) +{ + static auto InitOnce = [&] { + CallstackTrace_CreateInternal(InMalloc); + return true; + }(); +} + +//////////////////////////////////////////////////////////////////////////////// +void +CallstackTrace_Initialize() +{ + GCallStackTracingTlsSlotIndex = FPlatformTLS::AllocTlsSlot(); + + static auto InitOnce = [&] { + CallstackTrace_InitializeInternal(); + return true; + }(); +} + +} // namespace zen + +#endif + +#if ZEN_PLATFORM_WINDOWS +# include "moduletrace.h" + +# include "growonlylockfreehash.h" + +# include <zencore/scopeguard.h> +# include <zencore/thread.h> +# include <zencore/trace.h> + +# include <atomic> +# include <span> + +# include <zencore/windows.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <winnt.h> +# include <winternl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +# ifndef UE_CALLSTACK_TRACE_FULL_CALLSTACKS +# define UE_CALLSTACK_TRACE_FULL_CALLSTACKS 0 +# endif + +// 0=off, 1=stats, 2=validation, 3=truth_compare +# define BACKTRACE_DBGLVL 0 + +# define BACKTRACE_LOCK_FREE (1 && (BACKTRACE_DBGLVL == 0)) + +static bool GModulesAreInitialized = false; + +// This implementation is using unwind tables which is results in very fast +// stack walking. In some cases this is not suitable, and we then fall back +// to the standard stack walking implementation. +# if !defined(UE_CALLSTACK_TRACE_USE_UNWIND_TABLES) +# if defined(__clang__) +# define UE_CALLSTACK_TRACE_USE_UNWIND_TABLES 0 +# else +# define UE_CALLSTACK_TRACE_USE_UNWIND_TABLES 1 +# endif +# endif + +// stacktrace tracking using clang intrinsic __builtin_frame_address(0) doesn't work correctly on all windows platforms +# if !defined(PLATFORM_USE_CALLSTACK_ADDRESS_POINTER) +# if defined(__clang__) +# define PLATFORM_USE_CALLSTACK_ADDRESS_POINTER 0 +# else +# define PLATFORM_USE_CALLSTACK_ADDRESS_POINTER 1 +# endif +# endif + +# if !defined(UE_CALLSTACK_TRACE_RESERVE_MB) +// Initial size of the known set of callstacks +# define UE_CALLSTACK_TRACE_RESERVE_MB 8 // ~500k callstacks +# endif + +# if !defined(UE_CALLSTACK_TRACE_RESERVE_GROWABLE) +// If disabled the known set will not grow. New callstacks will not be +// reported if the set is full +# define UE_CALLSTACK_TRACE_RESERVE_GROWABLE 1 +# endif + +namespace zen { + +class FMalloc; + +UE_TRACE_CHANNEL_EXTERN(CallstackChannel) + +UE_TRACE_EVENT_BEGIN_EXTERN(Memory, CallstackSpec, NoSync) + UE_TRACE_EVENT_FIELD(uint32, CallstackId) + UE_TRACE_EVENT_FIELD(uint64[], Frames) +UE_TRACE_EVENT_END() + +class FCallstackTracer +{ +public: + struct FBacktraceEntry + { + uint64_t Hash = 0; + uint32_t FrameCount = 0; + uint64_t* Frames; + }; + + FCallstackTracer(FMalloc* InMalloc) : KnownSet(InMalloc) {} + + uint32_t AddCallstack(const FBacktraceEntry& Entry) + { + bool bAlreadyAdded = false; + + // Our set implementation doesn't allow for zero entries (zero represents an empty element + // in the hash table), so if we get one due to really bad luck in our 64-bit Id calculation, + // treat it as a "1" instead, for purposes of tracking if we've seen that callstack. + const uint64_t Hash = FMath::Max(Entry.Hash, 1ull); + uint32_t Id; + KnownSet.Find(Hash, &Id, &bAlreadyAdded); + if (!bAlreadyAdded) + { + Id = CallstackIdCounter.fetch_add(1, std::memory_order_relaxed); + // On the first callstack reserve memory up front + if (Id == 1) + { + KnownSet.Reserve(InitialReserveCount); + } +# if !UE_CALLSTACK_TRACE_RESERVE_GROWABLE + // If configured as not growable, start returning unknown id's when full. + if (Id >= InitialReserveCount) + { + return 0; + } +# endif + KnownSet.Emplace(Hash, Id); + UE_TRACE_LOG(Memory, CallstackSpec, CallstackChannel) + << CallstackSpec.CallstackId(Id) << CallstackSpec.Frames(Entry.Frames, Entry.FrameCount); + } + + return Id; + } + +private: + struct FEncounteredCallstackSetEntry + { + std::atomic_uint64_t Key; + std::atomic_uint32_t Value; + + inline uint64 GetKey() const { return Key.load(std::memory_order_relaxed); } + inline uint32_t GetValue() const { return Value.load(std::memory_order_relaxed); } + inline bool IsEmpty() const { return Key.load(std::memory_order_relaxed) == 0; } + inline void SetKeyValue(uint64_t InKey, uint32_t InValue) + { + Value.store(InValue, std::memory_order_release); + Key.store(InKey, std::memory_order_relaxed); + } + static inline uint32_t KeyHash(uint64_t Key) { return static_cast<uint32_t>(Key); } + static inline void ClearEntries(FEncounteredCallstackSetEntry* Entries, int32_t EntryCount) + { + memset(Entries, 0, EntryCount * sizeof(FEncounteredCallstackSetEntry)); + } + }; + + typedef TGrowOnlyLockFreeHash<FEncounteredCallstackSetEntry, uint64_t, uint32_t> FEncounteredCallstackSet; + + constexpr static uint32_t InitialReserveBytes = UE_CALLSTACK_TRACE_RESERVE_MB * 1024 * 1024; + constexpr static uint32_t InitialReserveCount = InitialReserveBytes / sizeof(FEncounteredCallstackSetEntry); + + FEncounteredCallstackSet KnownSet; + std::atomic_uint32_t CallstackIdCounter{1}; // 0 is reserved for "unknown callstack" +}; + +# if UE_CALLSTACK_TRACE_USE_UNWIND_TABLES + +/* + * Windows' x64 binaries contain a ".pdata" section that describes the location + * and size of its functions and details on how to unwind them. The unwind + * information includes descriptions about a function's stack frame size and + * the non-volatile registers it pushes onto the stack. From this we can + * calculate where a call instruction wrote its return address. This is enough + * to walk the callstack and by caching this information it can be done + * efficiently. + * + * Some functions need a variable amount of stack (such as those that use + * alloc() for example) will use a frame pointer. Frame pointers involve saving + * and restoring the stack pointer in the function's prologue/epilogue. This + * frees the function up to modify the stack pointer arbitrarily. This + * significantly complicates establishing where a return address is, so this + * pdata scheme of walking the stack just doesn't support functions like this. + * Walking stops if it encounters such a function. Fortunately there are + * usually very few such functions, saving us from having to read and track + * non-volatile registers which adds a significant amount of work. + * + * A further optimisation is to to assume we are only interested methods that + * are part of engine or game code. As such we only build lookup tables for + * such modules and never accept OS or third party modules. Backtracing stops + * if an address is encountered which doesn't map to a known module. + */ + +//////////////////////////////////////////////////////////////////////////////// +static uint32_t +AddressToId(uintptr_t Address) +{ + return uint32_t(Address >> 16); +} + +static uintptr_t +IdToAddress(uint32_t Id) +{ + return static_cast<uint32_t>(uintptr_t(Id) << 16); +} + +struct FIdPredicate +{ + template<class T> + bool operator()(uint32_t Id, const T& Item) const + { + return Id < Item.Id; + } + template<class T> + bool operator()(const T& Item, uint32_t Id) const + { + return Item.Id < Id; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +struct FUnwindInfo +{ + uint8_t Version : 3; + uint8_t Flags : 5; + uint8_t PrologBytes; + uint8_t NumUnwindCodes; + uint8_t FrameReg : 4; + uint8_t FrameRspBias : 4; +}; + +# pragma warning(push) +# pragma warning(disable : 4200) +struct FUnwindCode +{ + uint8_t PrologOffset; + uint8_t OpCode : 4; + uint8_t OpInfo : 4; + uint16_t Params[]; +}; +# pragma warning(pop) + +enum +{ + UWOP_PUSH_NONVOL = 0, // 1 node + UWOP_ALLOC_LARGE = 1, // 2 or 3 nodes + UWOP_ALLOC_SMALL = 2, // 1 node + UWOP_SET_FPREG = 3, // 1 node + UWOP_SAVE_NONVOL = 4, // 2 nodes + UWOP_SAVE_NONVOL_FAR = 5, // 3 nodes + UWOP_SAVE_XMM128 = 8, // 2 nodes + UWOP_SAVE_XMM128_FAR = 9, // 3 nodes + UWOP_PUSH_MACHFRAME = 10, // 1 node +}; + +//////////////////////////////////////////////////////////////////////////////// +class FBacktracer +{ +public: + FBacktracer(FMalloc* InMalloc); + ~FBacktracer(); + static FBacktracer* Get(); + void AddModule(uintptr_t Base, const char16_t* Name); + void RemoveModule(uintptr_t Base); + uint32_t GetBacktraceId(void* AddressOfReturnAddress); + +private: + struct FFunction + { + uint32_t Id; + int32_t RspBias; +# if BACKTRACE_DBGLVL >= 2 + uint32_t Size; + const FUnwindInfo* UnwindInfo; +# endif + }; + + struct FModule + { + uint32_t Id; + uint32_t IdSize; + uint32_t NumFunctions; +# if BACKTRACE_DBGLVL >= 1 + uint16 NumFpTypes; + // uint16 *padding* +# else + // uint32_t *padding* +# endif + FFunction* Functions; + }; + + struct FLookupState + { + FModule Module; + }; + + struct FFunctionLookupSetEntry + { + // Bottom 48 bits are key (pointer), top 16 bits are data (RSP bias for function) + std::atomic_uint64_t Data; + + inline uint64_t GetKey() const { return Data.load(std::memory_order_relaxed) & 0xffffffffffffull; } + inline int32_t GetValue() const { return static_cast<int64_t>(Data.load(std::memory_order_relaxed)) >> 48; } + inline bool IsEmpty() const { return Data.load(std::memory_order_relaxed) == 0; } + inline void SetKeyValue(uint64_t Key, int32_t Value) + { + Data.store(Key | (static_cast<int64_t>(Value) << 48), std::memory_order_relaxed); + } + static inline uint32_t KeyHash(uint64_t Key) + { + // 64 bit pointer to 32 bit hash + Key = (~Key) + (Key << 21); + Key = Key ^ (Key >> 24); + Key = Key * 265; + Key = Key ^ (Key >> 14); + Key = Key * 21; + Key = Key ^ (Key >> 28); + Key = Key + (Key << 31); + return static_cast<uint32_t>(Key); + } + static void ClearEntries(FFunctionLookupSetEntry* Entries, int32_t EntryCount) + { + memset(Entries, 0, EntryCount * sizeof(FFunctionLookupSetEntry)); + } + }; + typedef TGrowOnlyLockFreeHash<FFunctionLookupSetEntry, uint64_t, int32_t> FFunctionLookupSet; + + const FFunction* LookupFunction(uintptr_t Address, FLookupState& State) const; + static FBacktracer* Instance; + mutable zen::RwLock Lock; + FModule* Modules; + int32_t ModulesNum; + int32_t ModulesCapacity; + FMalloc* Malloc; + FCallstackTracer CallstackTracer; +# if BACKTRACE_LOCK_FREE + mutable FFunctionLookupSet FunctionLookups; + mutable bool bReentranceCheck = false; +# endif +# if BACKTRACE_DBGLVL >= 1 + mutable uint32_t NumFpTruncations = 0; + mutable uint32_t TotalFunctions = 0; +# endif +}; + +//////////////////////////////////////////////////////////////////////////////// +FBacktracer* FBacktracer::Instance = nullptr; + +//////////////////////////////////////////////////////////////////////////////// +FBacktracer::FBacktracer(FMalloc* InMalloc) +: Malloc(InMalloc) +, CallstackTracer(InMalloc) +# if BACKTRACE_LOCK_FREE +, FunctionLookups(InMalloc) +# endif +{ +# if BACKTRACE_LOCK_FREE + FunctionLookups.Reserve(512 * 1024); // 4 MB +# endif + ModulesCapacity = 8; + ModulesNum = 0; + Modules = (FModule*)Malloc->Malloc(sizeof(FModule) * ModulesCapacity); + + Instance = this; +} + +//////////////////////////////////////////////////////////////////////////////// +FBacktracer::~FBacktracer() +{ + std::span<FModule> ModulesView(Modules, ModulesNum); + for (FModule& Module : ModulesView) + { + Malloc->Free(Module.Functions); + } +} + +//////////////////////////////////////////////////////////////////////////////// +FBacktracer* +FBacktracer::Get() +{ + return Instance; +} + +bool GFullBacktraces = false; + +//////////////////////////////////////////////////////////////////////////////// +void +FBacktracer::AddModule(uintptr_t ModuleBase, const char16_t* Name) +{ + if (!GFullBacktraces) + { + const size_t NameLen = StringLength(Name); + if (!(NameLen > 4 && StringEquals(Name + NameLen - 4, u".exe"))) + { + return; + } + } + + const auto* DosHeader = (IMAGE_DOS_HEADER*)ModuleBase; + const auto* NtHeader = (IMAGE_NT_HEADERS*)(ModuleBase + DosHeader->e_lfanew); + const IMAGE_FILE_HEADER* FileHeader = &(NtHeader->FileHeader); + + uint32_t NumSections = FileHeader->NumberOfSections; + const auto* Sections = (IMAGE_SECTION_HEADER*)(uintptr_t(&(NtHeader->OptionalHeader)) + FileHeader->SizeOfOptionalHeader); + + // Find ".pdata" section + uintptr_t PdataBase = 0; + uintptr_t PdataEnd = 0; + for (uint32_t i = 0; i < NumSections; ++i) + { + const IMAGE_SECTION_HEADER* Section = Sections + i; + if (*(uint64_t*)(Section->Name) == + 0x61'74'61'64'70'2eull) // Sections names are eight bytes and zero padded. This constant is '.pdata' + { + PdataBase = ModuleBase + Section->VirtualAddress; + PdataEnd = PdataBase + Section->SizeOfRawData; + break; + } + } + + if (PdataBase == 0) + { + return; + } + + // Count the number of functions. The assumption here is that if we have got this far then there is at least one function + uint32_t NumFunctions = uint32_t(PdataEnd - PdataBase) / sizeof(RUNTIME_FUNCTION); + if (NumFunctions == 0) + { + return; + } + + const auto* FunctionTables = (RUNTIME_FUNCTION*)PdataBase; + do + { + const RUNTIME_FUNCTION* Function = FunctionTables + NumFunctions - 1; + if (uint32_t(Function->BeginAddress) < uint32_t(Function->EndAddress)) + { + break; + } + + --NumFunctions; + } while (NumFunctions != 0); + + // Allocate some space for the module's function-to-frame-size table + auto* OutTable = (FFunction*)Malloc->Malloc(sizeof(FFunction) * NumFunctions); + FFunction* OutTableCursor = OutTable; + + // Extract frame size for each function from pdata's unwind codes. + uint32_t NumFpFuncs = 0; + for (uint32_t i = 0; i < NumFunctions; ++i) + { + const RUNTIME_FUNCTION* FunctionTable = FunctionTables + i; + + uintptr_t UnwindInfoAddr = ModuleBase + FunctionTable->UnwindInfoAddress; + const auto* UnwindInfo = (FUnwindInfo*)UnwindInfoAddr; + + if (UnwindInfo->Version != 1) + { + /* some v2s have been seen in msvc. Always seem to be assembly + * routines (memset, memcpy, etc) */ + continue; + } + + int32_t FpInfo = 0; + int32_t RspBias = 0; + +# if BACKTRACE_DBGLVL >= 2 + uint32_t PrologVerify = UnwindInfo->PrologBytes; +# endif + + const auto* Code = (FUnwindCode*)(UnwindInfo + 1); + const auto* EndCode = Code + UnwindInfo->NumUnwindCodes; + while (Code < EndCode) + { +# if BACKTRACE_DBGLVL >= 2 + if (Code->PrologOffset > PrologVerify) + { + PLATFORM_BREAK(); + } + PrologVerify = Code->PrologOffset; +# endif + + switch (Code->OpCode) + { + case UWOP_PUSH_NONVOL: + RspBias += 8; + Code += 1; + break; + + case UWOP_ALLOC_LARGE: + if (Code->OpInfo) + { + RspBias += *(uint32_t*)(Code->Params); + Code += 3; + } + else + { + RspBias += Code->Params[0] * 8; + Code += 2; + } + break; + + case UWOP_ALLOC_SMALL: + RspBias += (Code->OpInfo * 8) + 8; + Code += 1; + break; + + case UWOP_SET_FPREG: + // Function will adjust RSP (e.g. through use of alloca()) so it + // uses a frame pointer register. There's instructions like; + // + // push FRAME_REG + // lea FRAME_REG, [rsp + (FRAME_RSP_BIAS * 16)] + // ... + // add rsp, rax + // ... + // sub rsp, FRAME_RSP_BIAS * 16 + // pop FRAME_REG + // ret + // + // To recover the stack frame we would need to track non-volatile + // registers which adds a lot of overhead for a small subset of + // functions. Instead we'll end backtraces at these functions. + + // MSB is set to detect variable sized frames that we can't proceed + // past when back-tracing. + NumFpFuncs++; + FpInfo |= 0x80000000 | (uint32_t(UnwindInfo->FrameReg) << 27) | (uint32_t(UnwindInfo->FrameRspBias) << 23); + Code += 1; + break; + + case UWOP_PUSH_MACHFRAME: + RspBias = Code->OpInfo ? 48 : 40; + Code += 1; + break; + + case UWOP_SAVE_NONVOL: + Code += 2; + break; /* saves are movs instead of pushes */ + case UWOP_SAVE_NONVOL_FAR: + Code += 3; + break; + case UWOP_SAVE_XMM128: + Code += 2; + break; + case UWOP_SAVE_XMM128_FAR: + Code += 3; + break; + + default: +# if BACKTRACE_DBGLVL >= 2 + PLATFORM_BREAK(); +# endif + break; + } + } + + // "Chained" simply means that multiple RUNTIME_FUNCTIONs pertains to a + // single actual function in the .text segment. + bool bIsChained = (UnwindInfo->Flags & UNW_FLAG_CHAININFO); + + RspBias /= sizeof(void*); // stack push/popds in units of one machine word + RspBias += !bIsChained; // and one extra push for the ret address + RspBias |= FpInfo; // pack in details about possible frame pointer + + if (bIsChained) + { + OutTableCursor[-1].RspBias += RspBias; +# if BACKTRACE_DBGLVL >= 2 + OutTableCursor[-1].Size += (FunctionTable->EndAddress - FunctionTable->BeginAddress); +# endif + } + else + { + *OutTableCursor = { + FunctionTable->BeginAddress, + RspBias, +# if BACKTRACE_DBGLVL >= 2 + FunctionTable->EndAddress - FunctionTable->BeginAddress, + UnwindInfo, +# endif + }; + + ++OutTableCursor; + } + } + + uintptr_t ModuleSize = NtHeader->OptionalHeader.SizeOfImage; + ModuleSize += 0xffff; // to align up to next 64K page. it'll get shifted by AddressToId() + + FModule Module = { + AddressToId(ModuleBase), + AddressToId(ModuleSize), + uint32_t(uintptr_t(OutTableCursor - OutTable)), +# if BACKTRACE_DBGLVL >= 1 + uint16(NumFpFuncs), +# endif + OutTable, + }; + + { + zen::RwLock::ExclusiveLockScope _(Lock); + + if (ModulesNum + 1 > ModulesCapacity) + { + ModulesCapacity += 8; + Modules = (FModule*)Malloc->Realloc(Modules, sizeof(FModule) * ModulesCapacity); + } + Modules[ModulesNum++] = Module; + + std::sort(Modules, Modules + ModulesNum, [](const FModule& A, const FModule& B) { return A.Id < B.Id; }); + } + +# if BACKTRACE_DBGLVL >= 1 + NumFpTruncations += NumFpFuncs; + TotalFunctions += NumFunctions; +# endif +} + +//////////////////////////////////////////////////////////////////////////////// +void +FBacktracer::RemoveModule(uintptr_t ModuleBase) +{ + // When Windows' RequestExit() is called it hard-terminates all threads except + // the main thread and then proceeds to unload the process' DLLs. This hard + // thread termination can result is dangling locked locks. Not an issue as + // the rule is "do not do anything multithreaded in DLL load/unload". And here + // we are, taking write locks during DLL unload which is, quite unsurprisingly, + // deadlocking. In reality tracking Windows' DLL unloads doesn't tell us + // anything due to how DLLs and processes' address spaces work. So we will... +# if defined PLATFORM_WINDOWS + ZEN_UNUSED(ModuleBase); + + return; +# else + + zen::RwLock::ExclusiveLockScope _(Lock); + + uint32_t ModuleId = AddressToId(ModuleBase); + TArrayView<FModule> ModulesView(Modules, ModulesNum); + int32_t Index = Algo::LowerBound(ModulesView, ModuleId, FIdPredicate()); + if (Index >= ModulesNum) + { + return; + } + + const FModule& Module = Modules[Index]; + if (Module.Id != ModuleId) + { + return; + } + +# if BACKTRACE_DBGLVL >= 1 + NumFpTruncations -= Module.NumFpTypes; + TotalFunctions -= Module.NumFunctions; +# endif + + // no code should be executing at this point so we can safely free the + // table knowing know one is looking at it. + Malloc->Free(Module.Functions); + + for (SIZE_T i = Index; i < ModulesNum; i++) + { + Modules[i] = Modules[i + 1]; + } + + --ModulesNum; +# endif +} + +//////////////////////////////////////////////////////////////////////////////// +const FBacktracer::FFunction* +FBacktracer::LookupFunction(uintptr_t Address, FLookupState& State) const +{ + // This function caches the previous module look up. The theory here is that + // a series of return address in a backtrace often cluster around one module + + FIdPredicate IdPredicate; + + // Look up the module that Address belongs to. + uint32_t AddressId = AddressToId(Address); + if ((AddressId - State.Module.Id) >= State.Module.IdSize) + { + auto FindIt = std::upper_bound(Modules, Modules + ModulesNum, AddressId, IdPredicate); + + if (FindIt == Modules) + { + return nullptr; + } + + State.Module = *--FindIt; + } + + // Check that the address is within the address space of the best-found module + const FModule* Module = &(State.Module); + if ((AddressId - Module->Id) >= Module->IdSize) + { + return nullptr; + } + + // Now we've a module we have a table of functions and their stack sizes so + // we can get the frame size for Address + uint32_t FuncId = uint32_t(Address - IdToAddress(Module->Id)); + std::span<FFunction> FuncsView(Module->Functions, Module->NumFunctions); + auto FindIt = std::upper_bound(begin(FuncsView), end(FuncsView), FuncId, IdPredicate); + if (FindIt == begin(FuncsView)) + { + return nullptr; + } + + const FFunction* Function = &(*--FindIt); +# if BACKTRACE_DBGLVL >= 2 + if ((FuncId - Function->Id) >= Function->Size) + { + PLATFORM_BREAK(); + return nullptr; + } +# endif + return Function; +} + +//////////////////////////////////////////////////////////////////////////////// +uint32_t +FBacktracer::GetBacktraceId(void* AddressOfReturnAddress) +{ + FLookupState LookupState = {}; + uint64_t Frames[256]; + + uintptr_t* StackPointer = (uintptr_t*)AddressOfReturnAddress; + +# if BACKTRACE_DBGLVL >= 3 + uintptr_t TruthBacktrace[1024]; + uint32_t NumTruth = RtlCaptureStackBackTrace(0, 1024, (void**)TruthBacktrace, nullptr); + uintptr_t* TruthCursor = TruthBacktrace; + for (; *TruthCursor != *StackPointer; ++TruthCursor) + ; +# endif + +# if BACKTRACE_DBGLVL >= 2 + struct + { + void* Sp; + void* Ip; + const FFunction* Function; + } Backtrace[1024] = {}; + uint32_t NumBacktrace = 0; +# endif + + uint64_t BacktraceHash = 0; + uint32_t FrameIdx = 0; + +# if BACKTRACE_LOCK_FREE + // When running lock free, we defer the lock until a lock free function lookup fails + bool Locked = false; +# else + FScopeLock _(&Lock); +# endif + do + { + uintptr_t RetAddr = *StackPointer; + + Frames[FrameIdx++] = RetAddr; + + // This is a simple order-dependent LCG. Should be sufficient enough + BacktraceHash += RetAddr; + BacktraceHash *= 0x30be8efa499c249dull; + +# if BACKTRACE_LOCK_FREE + int32_t RspBias; + bool bIsAlreadyInTable; + FunctionLookups.Find(RetAddr, &RspBias, &bIsAlreadyInTable); + if (bIsAlreadyInTable) + { + if (RspBias < 0) + { + break; + } + else + { + StackPointer += RspBias; + continue; + } + } + if (!Locked) + { + Lock.AcquireExclusive(); + Locked = true; + + // If FunctionLookups.Emplace triggers a reallocation, it can cause an infinite recursion + // when the allocation reenters the stack trace code. We need to break out of the recursion + // in that case, and let the allocation complete, with the assumption that we don't care + // about call stacks for internal allocations in the memory reporting system. The "Lock()" + // above will only fall through with this flag set if it's a second lock in the same thread. + if (bReentranceCheck) + { + break; + } + } +# endif // BACKTRACE_LOCK_FREE + + const FFunction* Function = LookupFunction(RetAddr, LookupState); + if (Function == nullptr) + { +# if BACKTRACE_LOCK_FREE + // LookupFunction fails when modules are not yet registered. In this case, we do not want the address + // to be added to the lookup map, but to retry the lookup later when modules are properly registered. + if (GModulesAreInitialized) + { + bReentranceCheck = true; + auto OnExit = zen::MakeGuard([&] { bReentranceCheck = false; }); + FunctionLookups.Emplace(RetAddr, -1); + } +# endif + break; + } + +# if BACKTRACE_LOCK_FREE + { + // This conversion improves probing performance for the hash set. Additionally it is critical + // to avoid incorrect values when RspBias is compressed into 16 bits in the hash map. + int32_t StoreBias = Function->RspBias < 0 ? -1 : Function->RspBias; + bReentranceCheck = true; + auto OnExit = zen::MakeGuard([&] { bReentranceCheck = false; }); + FunctionLookups.Emplace(RetAddr, StoreBias); + } +# endif + +# if BACKTRACE_DBGLVL >= 2 + if (NumBacktrace < 1024) + { + Backtrace[NumBacktrace++] = { + StackPointer, + (void*)RetAddr, + Function, + }; + } +# endif + + if (Function->RspBias < 0) + { + // This is a frame with a variable-sized stack pointer. We don't + // track enough information to proceed. +# if BACKTRACE_DBGLVL >= 1 + NumFpTruncations++; +# endif + break; + } + + StackPointer += Function->RspBias; + } + // Trunkate callstacks longer than MaxStackDepth + while (*StackPointer && FrameIdx < ZEN_ARRAY_COUNT(Frames)); + + // Build the backtrace entry for submission + FCallstackTracer::FBacktraceEntry BacktraceEntry; + BacktraceEntry.Hash = BacktraceHash; + BacktraceEntry.FrameCount = FrameIdx; + BacktraceEntry.Frames = Frames; + +# if BACKTRACE_DBGLVL >= 3 + for (uint32_t i = 0; i < NumBacktrace; ++i) + { + if ((void*)TruthCursor[i] != Backtrace[i].Ip) + { + PLATFORM_BREAK(); + break; + } + } +# endif + +# if BACKTRACE_LOCK_FREE + if (Locked) + { + Lock.ReleaseExclusive(); + } +# endif + // Add to queue to be processed. This might block until there is room in the + // queue (i.e. the processing thread has caught up processing). + return CallstackTracer.AddCallstack(BacktraceEntry); +} +} + +# else // UE_CALLSTACK_TRACE_USE_UNWIND_TABLES + +namespace zen { + + //////////////////////////////////////////////////////////////////////////////// + class FBacktracer + { + public: + FBacktracer(FMalloc* InMalloc); + ~FBacktracer(); + static FBacktracer* Get(); + inline uint32_t GetBacktraceId(void* AddressOfReturnAddress); + uint32_t GetBacktraceId(uint64_t ReturnAddress); + void AddModule(uintptr_t Base, const char16_t* Name) {} + void RemoveModule(uintptr_t Base) {} + + private: + static FBacktracer* Instance; + FMalloc* Malloc; + FCallstackTracer CallstackTracer; + }; + + //////////////////////////////////////////////////////////////////////////////// + FBacktracer* FBacktracer::Instance = nullptr; + + //////////////////////////////////////////////////////////////////////////////// + FBacktracer::FBacktracer(FMalloc* InMalloc) : Malloc(InMalloc), CallstackTracer(InMalloc) { Instance = this; } + + //////////////////////////////////////////////////////////////////////////////// + FBacktracer::~FBacktracer() {} + + //////////////////////////////////////////////////////////////////////////////// + FBacktracer* FBacktracer::Get() { return Instance; } + + //////////////////////////////////////////////////////////////////////////////// + uint32_t FBacktracer::GetBacktraceId(void* AddressOfReturnAddress) + { + const uint64_t ReturnAddress = *(uint64_t*)AddressOfReturnAddress; + return GetBacktraceId(ReturnAddress); + } + + //////////////////////////////////////////////////////////////////////////////// + uint32_t FBacktracer::GetBacktraceId(uint64_t ReturnAddress) + { +# if !UE_BUILD_SHIPPING + uint64_t StackFrames[256]; + int32_t NumStackFrames = FPlatformStackWalk::CaptureStackBackTrace(StackFrames, UE_ARRAY_COUNT(StackFrames)); + if (NumStackFrames > 0) + { + FCallstackTracer::FBacktraceEntry BacktraceEntry; + uint64_t BacktraceId = 0; + uint32_t FrameIdx = 0; + bool bUseAddress = false; + for (int32_t Index = 0; Index < NumStackFrames; Index++) + { + if (!bUseAddress) + { + // start using backtrace only after ReturnAddress + if (StackFrames[Index] == (uint64_t)ReturnAddress) + { + bUseAddress = true; + } + } + if (bUseAddress || NumStackFrames == 1) + { + uint64_t RetAddr = StackFrames[Index]; + StackFrames[FrameIdx++] = RetAddr; + + // This is a simple order-dependent LCG. Should be sufficient enough + BacktraceId += RetAddr; + BacktraceId *= 0x30be8efa499c249dull; + } + } + + // Save the collected id + BacktraceEntry.Hash = BacktraceId; + BacktraceEntry.FrameCount = FrameIdx; + BacktraceEntry.Frames = StackFrames; + + // Add to queue to be processed. This might block until there is room in the + // queue (i.e. the processing thread has caught up processing). + return CallstackTracer.AddCallstack(BacktraceEntry); + } +# endif + + return 0; + } + +} + +# endif // UE_CALLSTACK_TRACE_USE_UNWIND_TABLES + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// +void +CallstackTrace_CreateInternal(FMalloc* Malloc) +{ + if (FBacktracer::Get() != nullptr) + { + return; + } + + // Allocate, construct and intentionally leak backtracer + void* Alloc = Malloc->Malloc(sizeof(FBacktracer), alignof(FBacktracer)); + new (Alloc) FBacktracer(Malloc); + + Modules_Create(Malloc); + Modules_Subscribe([](bool bLoad, void* Module, const char16_t* Name) { + bLoad ? FBacktracer::Get()->AddModule(uintptr_t(Module), Name) //-V522 + : FBacktracer::Get()->RemoveModule(uintptr_t(Module)); + }); +} + +//////////////////////////////////////////////////////////////////////////////// +void +CallstackTrace_InitializeInternal() +{ + Modules_Initialize(); + GModulesAreInitialized = true; +} + +//////////////////////////////////////////////////////////////////////////////// +uint32_t +CallstackTrace_GetCurrentId() +{ + if (!UE_TRACE_CHANNELEXPR_IS_ENABLED(CallstackChannel)) + { + return 0; + } + + void* StackAddress = PLATFORM_RETURN_ADDRESS_FOR_CALLSTACKTRACING(); + if (FBacktracer* Instance = FBacktracer::Get()) + { +# if PLATFORM_USE_CALLSTACK_ADDRESS_POINTER + return Instance->GetBacktraceId(StackAddress); +# else + return Instance->GetBacktraceId((uint64_t)StackAddress); +# endif + } + + return 0; +} + +} // namespace zen + +#endif diff --git a/src/zencore/memtrack/callstacktrace.h b/src/zencore/memtrack/callstacktrace.h new file mode 100644 index 000000000..3e191490b --- /dev/null +++ b/src/zencore/memtrack/callstacktrace.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/trace.h> + +#if ZEN_PLATFORM_WINDOWS +# include <intrin.h> + +# define PLATFORM_RETURN_ADDRESS() _ReturnAddress() +# define PLATFORM_RETURN_ADDRESS_POINTER() _AddressOfReturnAddress() +# define PLATFORM_RETURN_ADDRESS_FOR_CALLSTACKTRACING PLATFORM_RETURN_ADDRESS_POINTER +#endif + +//////////////////////////////////////////////////////////////////////////////// +#if !defined(UE_CALLSTACK_TRACE_ENABLED) +# if UE_TRACE_ENABLED +# if ZEN_PLATFORM_WINDOWS +# define UE_CALLSTACK_TRACE_ENABLED 1 +# endif +# endif +#endif + +#if !defined(UE_CALLSTACK_TRACE_ENABLED) +# define UE_CALLSTACK_TRACE_ENABLED 0 +#endif + +//////////////////////////////////////////////////////////////////////////////// +#if UE_CALLSTACK_TRACE_ENABLED + +# include "platformtls.h" + +namespace zen { + +/** + * Creates callstack tracing. + * @param Malloc Allocator instance to use. + */ +void CallstackTrace_Create(class FMalloc* Malloc); + +/** + * Initializes callstack tracing. On some platforms this has to be delayed due to initialization order. + */ +void CallstackTrace_Initialize(); + +/** + * Capture the current callstack, and trace the definition if it has not already been encountered. The returned value + * can be used in trace events and be resolved in analysis. + * @return Unique id identifying the current callstack. + */ +uint32_t CallstackTrace_GetCurrentId(); + +/** + * Callstack Trace Scoped Macro to avoid resolving the full callstack + * can be used when some external libraries are not compiled with frame pointers + * preventing us to resolve it without crashing. Instead the callstack will be + * only the caller address. + */ +# define CALLSTACK_TRACE_LIMIT_CALLSTACKRESOLVE_SCOPE() FCallStackTraceLimitResolveScope PREPROCESSOR_JOIN(FCTLMScope, __LINE__) + +extern uint32_t GCallStackTracingTlsSlotIndex; + +/** + * @return the fallback callstack address + */ +inline void* +CallstackTrace_GetFallbackPlatformReturnAddressData() +{ + if (FPlatformTLS::IsValidTlsSlot(GCallStackTracingTlsSlotIndex)) + return FPlatformTLS::GetTlsValue(GCallStackTracingTlsSlotIndex); + else + return nullptr; +} + +/** + * @return Needs full callstack resolve + */ +inline bool +CallstackTrace_ResolveFullCallStack() +{ + return CallstackTrace_GetFallbackPlatformReturnAddressData() == nullptr; +} + +/* + * Callstack Trace scope for override CallStack + */ +class FCallStackTraceLimitResolveScope +{ +public: + ZEN_FORCENOINLINE FCallStackTraceLimitResolveScope() + { + if (FPlatformTLS::IsValidTlsSlot(GCallStackTracingTlsSlotIndex)) + { + FPlatformTLS::SetTlsValue(GCallStackTracingTlsSlotIndex, PLATFORM_RETURN_ADDRESS_FOR_CALLSTACKTRACING()); + } + } + + ZEN_FORCENOINLINE ~FCallStackTraceLimitResolveScope() + { + if (FPlatformTLS::IsValidTlsSlot(GCallStackTracingTlsSlotIndex)) + { + FPlatformTLS::SetTlsValue(GCallStackTracingTlsSlotIndex, nullptr); + } + } +}; + +} // namespace zen + +#else // UE_CALLSTACK_TRACE_ENABLED + +namespace zen { + +inline void +CallstackTrace_Create(class FMalloc* /*Malloc*/) +{ +} + +inline void +CallstackTrace_Initialize() +{ +} + +inline uint32_t +CallstackTrace_GetCurrentId() +{ + return 0; +} + +inline void* +CallstackTrace_GetCurrentReturnAddressData() +{ + return nullptr; +} + +inline void* +CallstackTrace_GetFallbackPlatformReturnAddressData() +{ + return nullptr; +} + +inline bool +CallstackTrace_ResolveFullCallStack() +{ + return true; +} + +# define CALLSTACK_TRACE_LIMIT_CALLSTACKRESOLVE_SCOPE() + +} // namespace zen + +#endif // UE_CALLSTACK_TRACE_ENABLED diff --git a/src/zencore/memtrack/growonlylockfreehash.h b/src/zencore/memtrack/growonlylockfreehash.h new file mode 100644 index 000000000..d6ff4fc32 --- /dev/null +++ b/src/zencore/memtrack/growonlylockfreehash.h @@ -0,0 +1,255 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> +#include <zencore/intmath.h> +#include <zencore/thread.h> + +#include <zencore/memory/fmalloc.h> + +#include <atomic> + +namespace zen { + +// Hash table with fast lock free reads, that only supports insertion of items, and no modification of +// values. KeyType must be an integer. EntryType should be a POD with an identifiable "empty" state +// that can't occur in the table, and include the following member functions: +// +// KeyType GetKey() const; // Get the key from EntryType +// ValueType GetValue() const; // Get the value from EntryType +// bool IsEmpty() const; // Query whether EntryType is empty +// void SetKeyValue(KeyType Key, ValueType Value); // Write key and value into EntryType (ATOMICALLY! See below) +// static uint32 KeyHash(KeyType Key); // Convert Key to more well distributed hash +// static void ClearEntries(EntryType* Entries, int32 EntryCount); // Fill an array of entries with empty values +// +// The function "SetKeyValue" must be multi-thread safe when writing new items! This means writing the +// Key last and atomically, or writing the entire EntryType in a single write (say if the key and value +// are packed into a single integer word). Inline is recommended, since these functions are called a +// lot in the inner loop of the algorithm. A simple implementation of "KeyHash" can just return the +// Key (if it's already reasonable as a hash), or mix the bits if better distribution is required. A +// simple implementation of "ClearEntries" can just be a memset, if zero represents an empty entry. +// +// A set can be approximated by making "GetValue" a nop function, and just paying attention to the bool +// result from FindEntry, although you do need to either reserve a certain Key as invalid, or add +// space to store a valid flag as the Value. This class should only be used for small value types, as +// the values are embedded into the hash table, and not stored separately. +// +// Writes are implemented using a lock -- it would be possible to make writes lock free (or lock free +// when resizing doesn't occur), but it adds complexity. If we were to go that route, it would make +// sense to create a fully generic lock free set, which would be much more involved to implement and +// validate than this simple class, and might also offer somewhat worse read perf. Lock free containers +// that support item removal either need additional synchronization overhead on readers, so writers can +// tell if a reader is active and spin, or need graveyard markers and a garbage collection pass called +// periodically, which makes it no longer a simple standalone container. +// +// Lock free reads are accomplished by the reader atomically pulling the hash table pointer from the +// class. The hash table is self contained, with its size stored in the table itself, and hash tables +// are not freed until the class's destruction. So if the table needs to be reallocated due to a write, +// active readers will still have valid memory. This does mean that tables leak, but worst case, you +// end up with half of the memory being waste. It would be possible to garbage collect the excess +// tables, but you'd need some kind of global synchronization to make sure no readers are active. +// +// Besides cleanup of wasted tables, it might be useful to provide a function to clear a table. This +// would involve clearing the Key for all the elements in the table (but leaving the memory allocated), +// and can be done safely with active readers. It's not possible to safely remove individual items due +// to the need to potentially move other items, which would break an active reader that has already +// searched past a moved item. But in the case of removing all items, we don't care when a reader fails, +// it's expected that eventually all readers will fail, regardless of where they are searching. A clear +// function could be useful if a lot of the data you are caching is no longer used, and you want to +// reset the cache. +// +template<typename EntryType, typename KeyType, typename ValueType> +class TGrowOnlyLockFreeHash +{ +public: + TGrowOnlyLockFreeHash(FMalloc* InMalloc) : Malloc(InMalloc), HashTable(nullptr) {} + + ~TGrowOnlyLockFreeHash() + { + FHashHeader* HashTableNext; + for (FHashHeader* HashTableCurrent = HashTable; HashTableCurrent; HashTableCurrent = HashTableNext) + { + HashTableNext = HashTableCurrent->Next; + + Malloc->Free(HashTableCurrent); + } + } + + /** + * Preallocate the hash table to a certain size + * @param Count - Number of EntryType elements to allocate + * @warning Can only be called once, and only before any items have been added! + */ + void Reserve(uint32_t Count) + { + zen::RwLock::ExclusiveLockScope _(WriteCriticalSection); + ZEN_ASSERT(HashTable.load(std::memory_order_relaxed) == nullptr); + + if (Count <= 0) + { + Count = DEFAULT_INITIAL_SIZE; + } + Count = uint32_t(zen::NextPow2(Count)); + FHashHeader* HashTableLocal = (FHashHeader*)Malloc->Malloc(sizeof(FHashHeader) + (Count - 1) * sizeof(EntryType)); + + HashTableLocal->Next = nullptr; + HashTableLocal->TableSize = Count; + HashTableLocal->Used = 0; + EntryType::ClearEntries(HashTableLocal->Elements, Count); + + HashTable.store(HashTableLocal, std::memory_order_release); + } + + /** + * Find an entry in the hash table + * @param Key - Key to search for + * @param OutValue - Memory location to write result value to. Left unmodified if Key isn't found. + * @param bIsAlreadyInTable - Optional result for whether key was found in table. + */ + void Find(KeyType Key, ValueType* OutValue, bool* bIsAlreadyInTable = nullptr) const + { + FHashHeader* HashTableLocal = HashTable.load(std::memory_order_acquire); + if (HashTableLocal) + { + uint32_t TableMask = HashTableLocal->TableSize - 1; + + // Linear probing + for (uint32_t TableIndex = EntryType::KeyHash(Key) & TableMask; !HashTableLocal->Elements[TableIndex].IsEmpty(); + TableIndex = (TableIndex + 1) & TableMask) + { + if (HashTableLocal->Elements[TableIndex].GetKey() == Key) + { + if (OutValue) + { + *OutValue = HashTableLocal->Elements[TableIndex].GetValue(); + } + if (bIsAlreadyInTable) + { + *bIsAlreadyInTable = true; + } + return; + } + } + } + + if (bIsAlreadyInTable) + { + *bIsAlreadyInTable = false; + } + } + + /** + * Add an entry with the given Key to the hash table, will do nothing if the item already exists + * @param Key - Key to add + * @param Value - Value to add for key + * @param bIsAlreadyInTable -- Optional result for whether item was already in table + */ + void Emplace(KeyType Key, ValueType Value, bool* bIsAlreadyInTable = nullptr) + { + zen::RwLock::ExclusiveLockScope _(WriteCriticalSection); + + // After locking, check if the item is already in the hash table. + ValueType ValueIgnore; + bool bFindResult; + Find(Key, &ValueIgnore, &bFindResult); + if (bFindResult == true) + { + if (bIsAlreadyInTable) + { + *bIsAlreadyInTable = true; + } + return; + } + + // Check if there is space in the hash table for a new item. We resize when the hash + // table gets half full or more. @todo: allow client to specify max load factor? + FHashHeader* HashTableLocal = HashTable; + + if (!HashTableLocal || (HashTableLocal->Used >= HashTableLocal->TableSize / 2)) + { + int32_t GrowCount = HashTableLocal ? HashTableLocal->TableSize * 2 : DEFAULT_INITIAL_SIZE; + FHashHeader* HashTableGrow = (FHashHeader*)Malloc->Malloc(sizeof(FHashHeader) + (GrowCount - 1) * sizeof(EntryType)); + + HashTableGrow->Next = HashTableLocal; + HashTableGrow->TableSize = GrowCount; + HashTableGrow->Used = 0; + EntryType::ClearEntries(HashTableGrow->Elements, GrowCount); + + if (HashTableLocal) + { + // Copy existing elements from the old table to the new table + for (int32_t TableIndex = 0; TableIndex < HashTableLocal->TableSize; TableIndex++) + { + EntryType& Entry = HashTableLocal->Elements[TableIndex]; + if (!Entry.IsEmpty()) + { + HashInsertInternal(HashTableGrow, Entry.GetKey(), Entry.GetValue()); + } + } + } + + HashTableLocal = HashTableGrow; + HashTable.store(HashTableGrow, std::memory_order_release); + } + + // Then add our new item + HashInsertInternal(HashTableLocal, Key, Value); + + if (bIsAlreadyInTable) + { + *bIsAlreadyInTable = false; + } + } + + void FindOrAdd(KeyType Key, ValueType Value, bool* bIsAlreadyInTable = nullptr) + { + // Attempt to find the item lock free, before calling "Emplace", which locks the container + bool bFindResult; + ValueType IgnoreResult; + Find(Key, &IgnoreResult, &bFindResult); + if (bFindResult) + { + if (bIsAlreadyInTable) + { + *bIsAlreadyInTable = true; + } + return; + } + + Emplace(Key, Value, bIsAlreadyInTable); + } + +private: + struct FHashHeader + { + FHashHeader* Next; // Old buffers are stored in a linked list for cleanup + int32_t TableSize; + int32_t Used; + EntryType Elements[1]; // Variable sized + }; + + FMalloc* Malloc; + std::atomic<FHashHeader*> HashTable; + zen::RwLock WriteCriticalSection; + + static constexpr int32_t DEFAULT_INITIAL_SIZE = 1024; + + static void HashInsertInternal(FHashHeader* HashTableLocal, KeyType Key, ValueType Value) + { + int32_t TableMask = HashTableLocal->TableSize - 1; + + // Linear probing + for (int32_t TableIndex = EntryType::KeyHash(Key) & TableMask;; TableIndex = (TableIndex + 1) & TableMask) + { + if (HashTableLocal->Elements[TableIndex].IsEmpty()) + { + HashTableLocal->Elements[TableIndex].SetKeyValue(Key, Value); + HashTableLocal->Used++; + break; + } + } + } +}; + +} // namespace zen diff --git a/src/zencore/memtrack/memorytrace.cpp b/src/zencore/memtrack/memorytrace.cpp new file mode 100644 index 000000000..8f723866d --- /dev/null +++ b/src/zencore/memtrack/memorytrace.cpp @@ -0,0 +1,817 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/memory/memorytrace.h> +#include <zencore/memory/tagtrace.h> + +#include "callstacktrace.h" +#include "tracemalloc.h" +#include "vatrace.h" + +#include <zencore/commandline.h> +#include <zencore/enumflags.h> +#include <zencore/guardvalue.h> +#include <zencore/intmath.h> +#include <zencore/string.h> +#include <zencore/trace.h> + +#include <string.h> + +#if ZEN_PLATFORM_WINDOWS +# include <shellapi.h> +#endif + +class FMalloc; + +#if UE_TRACE_ENABLED +namespace zen { +UE_TRACE_CHANNEL_DEFINE(MemAllocChannel, "Memory allocations", true) +} +#endif + +#if UE_MEMORY_TRACE_ENABLED + +//////////////////////////////////////////////////////////////////////////////// + +namespace zen { + +void MemoryTrace_InitTags(FMalloc*); +void MemoryTrace_EnableTracePump(); + +} // namespace zen + +//////////////////////////////////////////////////////////////////////////////// +namespace { +// Controls how often time markers are emitted (must be POW2-1 as this is used as a mask) +constexpr uint32_t MarkerSamplePeriod = 128 - 1; + +// Number of shifted bits to SizeLower +constexpr uint32_t SizeShift = 3; + +// Counter to track when time marker is emitted +std::atomic<uint32_t> GMarkerCounter(0); + +// If enabled also pumps the Trace system itself. Used on process shutdown +// when worker thread has been killed, but memory events still occurs. +bool GDoPumpTrace; + +// Temporarily disables any internal operation that causes allocations. Used to +// avoid recursive behaviour when memory tracing needs to allocate memory through +// TraceMalloc. +thread_local bool GDoNotAllocateInTrace; + +// Set on initialization; on some platforms we hook allocator functions very early +// before Trace has the ability to allocate memory. +bool GTraceAllowed; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +namespace UE { namespace Trace { + TRACELOG_API void Update(); +}} // namespace UE::Trace + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// +UE_TRACE_EVENT_BEGIN(Memory, Init, NoSync | Important) + UE_TRACE_EVENT_FIELD(uint64_t, PageSize) // new in UE 5.5 + UE_TRACE_EVENT_FIELD(uint32_t, MarkerPeriod) + UE_TRACE_EVENT_FIELD(uint8, Version) + UE_TRACE_EVENT_FIELD(uint8, MinAlignment) + UE_TRACE_EVENT_FIELD(uint8, SizeShift) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, Marker) + UE_TRACE_EVENT_FIELD(uint64_t, Cycle) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, Alloc) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(uint32_t, Size) + UE_TRACE_EVENT_FIELD(uint8, AlignmentPow2_SizeLower) + UE_TRACE_EVENT_FIELD(uint8, RootHeap) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, AllocSystem) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(uint32_t, Size) + UE_TRACE_EVENT_FIELD(uint8, AlignmentPow2_SizeLower) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, AllocVideo) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(uint32_t, Size) + UE_TRACE_EVENT_FIELD(uint8, AlignmentPow2_SizeLower) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, Free) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(uint8, RootHeap) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, FreeSystem) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, FreeVideo) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, ReallocAlloc) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(uint32_t, Size) + UE_TRACE_EVENT_FIELD(uint8, AlignmentPow2_SizeLower) + UE_TRACE_EVENT_FIELD(uint8, RootHeap) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, ReallocAllocSystem) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(uint32_t, Size) + UE_TRACE_EVENT_FIELD(uint8, AlignmentPow2_SizeLower) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, ReallocFree) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(uint8, RootHeap) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, ReallocFreeSystem) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, MemorySwapOp) + UE_TRACE_EVENT_FIELD(uint64_t, Address) // page fault real address + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(uint32_t, CompressedSize) + UE_TRACE_EVENT_FIELD(uint8, SwapOp) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, HeapSpec, NoSync | Important) + UE_TRACE_EVENT_FIELD(HeapId, Id) + UE_TRACE_EVENT_FIELD(HeapId, ParentId) + UE_TRACE_EVENT_FIELD(uint16, Flags) + UE_TRACE_EVENT_FIELD(UE::Trace::WideString, Name) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, HeapMarkAlloc) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(uint16, Flags) + UE_TRACE_EVENT_FIELD(HeapId, Heap) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, HeapUnmarkAlloc) + UE_TRACE_EVENT_FIELD(uint64_t, Address) + UE_TRACE_EVENT_FIELD(uint32_t, CallstackId) + UE_TRACE_EVENT_FIELD(HeapId, Heap) +UE_TRACE_EVENT_END() + +// If the layout of the above events is changed, bump this version number. +// version 1: Initial version (UE 5.0, UE 5.1) +// version 2: Added CallstackId for Free events and also for HeapMarkAlloc, HeapUnmarkAlloc events (UE 5.2). +constexpr uint8 MemoryTraceVersion = 2; + +//////////////////////////////////////////////////////////////////////////////// +class FMallocWrapper : public FMalloc +{ +public: + FMallocWrapper(FMalloc* InMalloc); + +private: + struct FCookie + { + uint64_t Tag : 16; + uint64_t Bias : 8; + uint64_t Size : 40; + }; + + static uint32_t GetActualAlignment(SIZE_T Size, uint32_t Alignment); + + virtual void* Malloc(SIZE_T Size, uint32_t Alignment) override; + virtual void* Realloc(void* PrevAddress, SIZE_T NewSize, uint32_t Alignment) override; + virtual void Free(void* Address) override; + virtual bool GetAllocationSize(void* Address, SIZE_T& SizeOut) override { return InnerMalloc->GetAllocationSize(Address, SizeOut); } + virtual void OnMallocInitialized() override { InnerMalloc->OnMallocInitialized(); } + + FMalloc* InnerMalloc; +}; + +//////////////////////////////////////////////////////////////////////////////// +FMallocWrapper::FMallocWrapper(FMalloc* InMalloc) : InnerMalloc(InMalloc) +{ +} + +//////////////////////////////////////////////////////////////////////////////// +uint32_t +FMallocWrapper::GetActualAlignment(SIZE_T Size, uint32_t Alignment) +{ + // Defaults; if size is < 16 then alignment is 8 else 16. + uint32_t DefaultAlignment = 8 << uint32_t(Size >= 16); + return (Alignment < DefaultAlignment) ? DefaultAlignment : Alignment; +} + +//////////////////////////////////////////////////////////////////////////////// +void* +FMallocWrapper::Malloc(SIZE_T Size, uint32_t Alignment) +{ + uint32_t ActualAlignment = GetActualAlignment(Size, Alignment); + void* Address = InnerMalloc->Malloc(Size, Alignment); + + MemoryTrace_Alloc((uint64_t)Address, Size, ActualAlignment); + + return Address; +} + +//////////////////////////////////////////////////////////////////////////////// +void* +FMallocWrapper::Realloc(void* PrevAddress, SIZE_T NewSize, uint32_t Alignment) +{ + // This simplifies things and means reallocs trace events are true reallocs + if (PrevAddress == nullptr) + { + return Malloc(NewSize, Alignment); + } + + MemoryTrace_ReallocFree((uint64_t)PrevAddress); + + void* RetAddress = InnerMalloc->Realloc(PrevAddress, NewSize, Alignment); + + Alignment = GetActualAlignment(NewSize, Alignment); + MemoryTrace_ReallocAlloc((uint64_t)RetAddress, NewSize, Alignment); + + return RetAddress; +} + +//////////////////////////////////////////////////////////////////////////////// +void +FMallocWrapper::Free(void* Address) +{ + if (Address == nullptr) + { + return; + } + + MemoryTrace_Free((uint64_t)Address); + + void* InnerAddress = Address; + + return InnerMalloc->Free(InnerAddress); +} + +//////////////////////////////////////////////////////////////////////////////// +template<class T> +class alignas(alignof(T)) FUndestructed +{ +public: + template<typename... ArgTypes> + void Construct(ArgTypes... Args) + { + ::new (Buffer) T(Args...); + bIsConstructed = true; + } + + bool IsConstructed() const { return bIsConstructed; } + + T* operator&() { return (T*)Buffer; } + T* operator->() { return (T*)Buffer; } + +protected: + uint8 Buffer[sizeof(T)]; + bool bIsConstructed; +}; + +//////////////////////////////////////////////////////////////////////////////// +static FUndestructed<FTraceMalloc> GTraceMalloc; + +//////////////////////////////////////////////////////////////////////////////// +static EMemoryTraceInit +MemoryTrace_ShouldEnable(const TraceOptions& Options) +{ + EMemoryTraceInit Mode = EMemoryTraceInit::Disabled; + + // Process any command line trace options + // + // Note that calls can come into this function before we enter the regular main function + // and we can therefore not rely on the regular command line parsing for the application + + using namespace std::literals; + + auto ProcessTraceArg = [&](const std::string_view& Arg) { + if (Arg == "memalloc"sv) + { + Mode |= EMemoryTraceInit::AllocEvents; + } + else if (Arg == "callstack"sv) + { + Mode |= EMemoryTraceInit::Callstacks; + } + else if (Arg == "memtag"sv) + { + Mode |= EMemoryTraceInit::Tags; + } + else if (Arg == "memory"sv) + { + Mode |= EMemoryTraceInit::Full; + } + else if (Arg == "memory_light"sv) + { + Mode |= EMemoryTraceInit::Light; + } + }; + + IterateCommaSeparatedValue(Options.Channels, ProcessTraceArg); + return Mode; +} + +//////////////////////////////////////////////////////////////////////////////// +FMalloc* +MemoryTrace_CreateInternal(FMalloc* InMalloc, EMemoryTraceInit Mode) +{ + using namespace zen; + + // If allocation events are not desired we don't need to do anything, even + // if user has enabled only callstacks it will be enabled later. + if (!EnumHasAnyFlags(Mode, EMemoryTraceInit::AllocEvents)) + { + return InMalloc; + } + + // Some OSes (i.e. Windows) will terminate all threads except the main + // one as part of static deinit. However we may receive more memory + // trace events that would get lost as Trace's worker thread has been + // terminated. So flush the last remaining memory events trace needs + // to be updated which we will do that in response to to memory events. + // We'll use an atexit can to know when Trace is probably no longer + // getting ticked. + atexit([]() { MemoryTrace_EnableTracePump(); }); + + GTraceMalloc.Construct(InMalloc); + + // Both tag and callstack tracing need to use the wrapped trace malloc + // so we can break out tracing memory overhead (and not cause recursive behaviour). + if (EnumHasAnyFlags(Mode, EMemoryTraceInit::Tags)) + { + MemoryTrace_InitTags(>raceMalloc); + } + + if (EnumHasAnyFlags(Mode, EMemoryTraceInit::Callstacks)) + { + CallstackTrace_Create(>raceMalloc); + } + + static FUndestructed<FMallocWrapper> SMallocWrapper; + SMallocWrapper.Construct(InMalloc); + + return &SMallocWrapper; +} + +//////////////////////////////////////////////////////////////////////////////// +FMalloc* +MemoryTrace_CreateInternal(FMalloc* InMalloc, const TraceOptions& Options) +{ + const EMemoryTraceInit Mode = MemoryTrace_ShouldEnable(Options); + return MemoryTrace_CreateInternal(InMalloc, Mode); +} + +//////////////////////////////////////////////////////////////////////////////// +FMalloc* +MemoryTrace_Create(FMalloc* InMalloc, const TraceOptions& Options) +{ + FMalloc* OutMalloc = MemoryTrace_CreateInternal(InMalloc, Options); + + if (OutMalloc != InMalloc) + { +# if PLATFORM_SUPPORTS_TRACE_WIN32_VIRTUAL_MEMORY_HOOKS + FVirtualWinApiHooks::Initialize(false); +# endif + } + + return OutMalloc; +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_Initialize() +{ + // At this point we initialized the system to allow tracing. + GTraceAllowed = true; + + const int MIN_ALIGNMENT = 8; + + UE_TRACE_LOG(Memory, Init, MemAllocChannel) + << Init.PageSize(4096) << Init.MarkerPeriod(MarkerSamplePeriod + 1) << Init.Version(MemoryTraceVersion) + << Init.MinAlignment(uint8(MIN_ALIGNMENT)) << Init.SizeShift(uint8(SizeShift)); + + const HeapId SystemRootHeap = MemoryTrace_RootHeapSpec(u"System memory"); + ZEN_ASSERT(SystemRootHeap == EMemoryTraceRootHeap::SystemMemory); + const HeapId VideoRootHeap = MemoryTrace_RootHeapSpec(u"Video memory"); + ZEN_ASSERT(VideoRootHeap == EMemoryTraceRootHeap::VideoMemory); + + static_assert((1 << SizeShift) - 1 <= MIN_ALIGNMENT, "Not enough bits to pack size fields"); + +# if !UE_MEMORY_TRACE_LATE_INIT + // On some platforms callstack initialization cannot happen this early in the process. It is initialized + // in other locations when UE_MEMORY_TRACE_LATE_INIT is defined. Until that point allocations cannot have + // callstacks. + CallstackTrace_Initialize(); +# endif +} + +void +MemoryTrace_Shutdown() +{ + // Disable any further activity + GTraceAllowed = false; +} + +//////////////////////////////////////////////////////////////////////////////// +bool +MemoryTrace_IsActive() +{ + return GTraceAllowed; +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_EnableTracePump() +{ + GDoPumpTrace = true; +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_UpdateInternal() +{ + const uint32_t TheCount = GMarkerCounter.fetch_add(1, std::memory_order_relaxed); + if ((TheCount & MarkerSamplePeriod) == 0) + { + UE_TRACE_LOG(Memory, Marker, MemAllocChannel) << Marker.Cycle(UE::Trace::Private::TimeGetTimestamp()); + } + + if (GDoPumpTrace) + { + UE::Trace::Update(); + } +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_Alloc(uint64_t Address, uint64_t Size, uint32_t Alignment, HeapId RootHeap, uint32_t ExternalCallstackId) +{ + if (!GTraceAllowed) + { + return; + } + + ZEN_ASSERT_SLOW(RootHeap < 16); + + const uint32_t AlignmentPow2 = uint32_t(zen::CountTrailingZeros64(Alignment)); + const uint32_t Alignment_SizeLower = (AlignmentPow2 << SizeShift) | uint32_t(Size & ((1 << SizeShift) - 1)); + const uint32_t CallstackId = ExternalCallstackId ? ExternalCallstackId : GDoNotAllocateInTrace ? 0 : CallstackTrace_GetCurrentId(); + + switch (RootHeap) + { + case EMemoryTraceRootHeap::SystemMemory: + { + UE_TRACE_LOG(Memory, AllocSystem, MemAllocChannel) + << AllocSystem.Address(uint64_t(Address)) << AllocSystem.CallstackId(CallstackId) + << AllocSystem.Size(uint32_t(Size >> SizeShift)) << AllocSystem.AlignmentPow2_SizeLower(uint8(Alignment_SizeLower)); + break; + } + + case EMemoryTraceRootHeap::VideoMemory: + { + UE_TRACE_LOG(Memory, AllocVideo, MemAllocChannel) + << AllocVideo.Address(uint64_t(Address)) << AllocVideo.CallstackId(CallstackId) + << AllocVideo.Size(uint32_t(Size >> SizeShift)) << AllocVideo.AlignmentPow2_SizeLower(uint8(Alignment_SizeLower)); + break; + } + + default: + { + UE_TRACE_LOG(Memory, Alloc, MemAllocChannel) + << Alloc.Address(uint64_t(Address)) << Alloc.CallstackId(CallstackId) << Alloc.Size(uint32_t(Size >> SizeShift)) + << Alloc.AlignmentPow2_SizeLower(uint8(Alignment_SizeLower)) << Alloc.RootHeap(uint8(RootHeap)); + break; + } + } + + MemoryTrace_UpdateInternal(); +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_Free(uint64_t Address, HeapId RootHeap, uint32_t ExternalCallstackId) +{ + if (!GTraceAllowed) + { + return; + } + + ZEN_ASSERT_SLOW(RootHeap < 16); + + const uint32_t CallstackId = ExternalCallstackId ? ExternalCallstackId : GDoNotAllocateInTrace ? 0 : CallstackTrace_GetCurrentId(); + + switch (RootHeap) + { + case EMemoryTraceRootHeap::SystemMemory: + { + UE_TRACE_LOG(Memory, FreeSystem, MemAllocChannel) + << FreeSystem.Address(uint64_t(Address)) << FreeSystem.CallstackId(CallstackId); + break; + } + case EMemoryTraceRootHeap::VideoMemory: + { + UE_TRACE_LOG(Memory, FreeVideo, MemAllocChannel) + << FreeVideo.Address(uint64_t(Address)) << FreeVideo.CallstackId(CallstackId); + break; + } + default: + { + UE_TRACE_LOG(Memory, Free, MemAllocChannel) + << Free.Address(uint64_t(Address)) << Free.CallstackId(CallstackId) << Free.RootHeap(uint8(RootHeap)); + break; + } + } + + MemoryTrace_UpdateInternal(); +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_ReallocAlloc(uint64_t Address, uint64_t Size, uint32_t Alignment, HeapId RootHeap, uint32_t ExternalCallstackId) +{ + if (!GTraceAllowed) + { + return; + } + + ZEN_ASSERT_SLOW(RootHeap < 16); + + const uint32_t AlignmentPow2 = uint32_t(zen::CountTrailingZeros64(Alignment)); + const uint32_t Alignment_SizeLower = (AlignmentPow2 << SizeShift) | uint32_t(Size & ((1 << SizeShift) - 1)); + const uint32_t CallstackId = ExternalCallstackId ? ExternalCallstackId : GDoNotAllocateInTrace ? 0 : CallstackTrace_GetCurrentId(); + + switch (RootHeap) + { + case EMemoryTraceRootHeap::SystemMemory: + { + UE_TRACE_LOG(Memory, ReallocAllocSystem, MemAllocChannel) + << ReallocAllocSystem.Address(uint64_t(Address)) << ReallocAllocSystem.CallstackId(CallstackId) + << ReallocAllocSystem.Size(uint32_t(Size >> SizeShift)) + << ReallocAllocSystem.AlignmentPow2_SizeLower(uint8(Alignment_SizeLower)); + break; + } + + default: + { + UE_TRACE_LOG(Memory, ReallocAlloc, MemAllocChannel) + << ReallocAlloc.Address(uint64_t(Address)) << ReallocAlloc.CallstackId(CallstackId) + << ReallocAlloc.Size(uint32_t(Size >> SizeShift)) << ReallocAlloc.AlignmentPow2_SizeLower(uint8(Alignment_SizeLower)) + << ReallocAlloc.RootHeap(uint8(RootHeap)); + break; + } + } + + MemoryTrace_UpdateInternal(); +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_ReallocFree(uint64_t Address, HeapId RootHeap, uint32_t ExternalCallstackId) +{ + if (!GTraceAllowed) + { + return; + } + + ZEN_ASSERT_SLOW(RootHeap < 16); + + const uint32_t CallstackId = ExternalCallstackId ? ExternalCallstackId : GDoNotAllocateInTrace ? 0 : CallstackTrace_GetCurrentId(); + + switch (RootHeap) + { + case EMemoryTraceRootHeap::SystemMemory: + { + UE_TRACE_LOG(Memory, ReallocFreeSystem, MemAllocChannel) + << ReallocFreeSystem.Address(uint64_t(Address)) << ReallocFreeSystem.CallstackId(CallstackId); + break; + } + + default: + { + UE_TRACE_LOG(Memory, ReallocFree, MemAllocChannel) + << ReallocFree.Address(uint64_t(Address)) << ReallocFree.CallstackId(CallstackId) + << ReallocFree.RootHeap(uint8(RootHeap)); + break; + } + } + + MemoryTrace_UpdateInternal(); +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_SwapOp(uint64_t PageAddress, EMemoryTraceSwapOperation SwapOperation, uint32_t CompressedSize, uint32_t CallstackId) +{ + if (!GTraceAllowed) + { + return; + } + + UE_TRACE_LOG(Memory, MemorySwapOp, MemAllocChannel) + << MemorySwapOp.Address(PageAddress) << MemorySwapOp.CallstackId(CallstackId) << MemorySwapOp.CompressedSize(CompressedSize) + << MemorySwapOp.SwapOp((uint8)SwapOperation); + + MemoryTrace_UpdateInternal(); +} + +//////////////////////////////////////////////////////////////////////////////// +HeapId +MemoryTrace_HeapSpec(HeapId ParentId, const char16_t* Name, EMemoryTraceHeapFlags Flags) +{ + if (!GTraceAllowed) + { + return 0; + } + + static std::atomic<HeapId> HeapIdCount(EMemoryTraceRootHeap::EndReserved + 1); // Reserve indexes for root heaps + const HeapId Id = HeapIdCount.fetch_add(1); + const uint32_t NameLen = uint32_t(zen::StringLength(Name)); + const uint32_t DataSize = NameLen * sizeof(char16_t); + ZEN_ASSERT(ParentId < Id); + + UE_TRACE_LOG(Memory, HeapSpec, MemAllocChannel, DataSize) + << HeapSpec.Id(Id) << HeapSpec.ParentId(ParentId) << HeapSpec.Name(Name, NameLen) << HeapSpec.Flags(uint16(Flags)); + + return Id; +} + +//////////////////////////////////////////////////////////////////////////////// +HeapId +MemoryTrace_RootHeapSpec(const char16_t* Name, EMemoryTraceHeapFlags Flags) +{ + if (!GTraceAllowed) + { + return 0; + } + + static std::atomic<HeapId> RootHeapCount(0); + const HeapId Id = RootHeapCount.fetch_add(1); + ZEN_ASSERT(Id <= EMemoryTraceRootHeap::EndReserved); + + const uint32_t NameLen = uint32_t(zen::StringLength(Name)); + const uint32_t DataSize = NameLen * sizeof(char16_t); + + UE_TRACE_LOG(Memory, HeapSpec, MemAllocChannel, DataSize) + << HeapSpec.Id(Id) << HeapSpec.ParentId(HeapId(~0)) << HeapSpec.Name(Name, NameLen) + << HeapSpec.Flags(uint16(EMemoryTraceHeapFlags::Root | Flags)); + + return Id; +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_MarkAllocAsHeap(uint64_t Address, HeapId Heap, EMemoryTraceHeapAllocationFlags Flags, uint32_t ExternalCallstackId) +{ + if (!GTraceAllowed) + { + return; + } + + const uint32_t CallstackId = ExternalCallstackId ? ExternalCallstackId : GDoNotAllocateInTrace ? 0 : CallstackTrace_GetCurrentId(); + + UE_TRACE_LOG(Memory, HeapMarkAlloc, MemAllocChannel) + << HeapMarkAlloc.Address(uint64_t(Address)) << HeapMarkAlloc.CallstackId(CallstackId) + << HeapMarkAlloc.Flags(uint16(EMemoryTraceHeapAllocationFlags::Heap | Flags)) << HeapMarkAlloc.Heap(Heap); +} + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_UnmarkAllocAsHeap(uint64_t Address, HeapId Heap, uint32_t ExternalCallstackId) +{ + if (!GTraceAllowed) + { + return; + } + + const uint32_t CallstackId = ExternalCallstackId ? ExternalCallstackId : GDoNotAllocateInTrace ? 0 : CallstackTrace_GetCurrentId(); + + // Sets all flags to zero + UE_TRACE_LOG(Memory, HeapUnmarkAlloc, MemAllocChannel) + << HeapUnmarkAlloc.Address(uint64_t(Address)) << HeapUnmarkAlloc.CallstackId(CallstackId) << HeapUnmarkAlloc.Heap(Heap); +} + +} // namespace zen + +#else // UE_MEMORY_TRACE_ENABLED + +///////////////////////////////////////////////////////////////////////////// +bool +MemoryTrace_IsActive() +{ + return false; +} + +#endif // UE_MEMORY_TRACE_ENABLED + +namespace zen { + +///////////////////////////////////////////////////////////////////////////// +FTraceMalloc::FTraceMalloc(FMalloc* InMalloc) +{ + WrappedMalloc = InMalloc; +} + +///////////////////////////////////////////////////////////////////////////// +FTraceMalloc::~FTraceMalloc() +{ +} + +///////////////////////////////////////////////////////////////////////////// +void* +FTraceMalloc::Malloc(SIZE_T Count, uint32_t Alignment) +{ +#if UE_MEMORY_TRACE_ENABLED + // UE_TRACE_METADATA_CLEAR_SCOPE(); + ZEN_MEMSCOPE(TRACE_TAG); + + void* NewPtr; + { + zen::TGuardValue<bool> _(GDoNotAllocateInTrace, true); + NewPtr = WrappedMalloc->Malloc(Count, Alignment); + } + + const uint64_t Size = Count; + const uint32_t AlignmentPow2 = uint32_t(zen::CountTrailingZeros64(Alignment)); + const uint32_t Alignment_SizeLower = (AlignmentPow2 << SizeShift) | uint32_t(Size & ((1 << SizeShift) - 1)); + + UE_TRACE_LOG(Memory, Alloc, MemAllocChannel) + << Alloc.Address(uint64_t(NewPtr)) << Alloc.CallstackId(0) << Alloc.Size(uint32_t(Size >> SizeShift)) + << Alloc.AlignmentPow2_SizeLower(uint8(Alignment_SizeLower)) << Alloc.RootHeap(uint8(EMemoryTraceRootHeap::SystemMemory)); + + return NewPtr; +#else + return WrappedMalloc->Malloc(Count, Alignment); +#endif // UE_MEMORY_TRACE_ENABLED +} + +///////////////////////////////////////////////////////////////////////////// +void* +FTraceMalloc::Realloc(void* Original, SIZE_T Count, uint32_t Alignment) +{ +#if UE_MEMORY_TRACE_ENABLED + // UE_TRACE_METADATA_CLEAR_SCOPE(); + ZEN_MEMSCOPE(TRACE_TAG); + + UE_TRACE_LOG(Memory, ReallocFree, MemAllocChannel) + << ReallocFree.Address(uint64_t(Original)) << ReallocFree.RootHeap(uint8(EMemoryTraceRootHeap::SystemMemory)); + + void* NewPtr; + { + zen::TGuardValue<bool> _(GDoNotAllocateInTrace, true); + NewPtr = WrappedMalloc->Realloc(Original, Count, Alignment); + } + + const uint64_t Size = Count; + const uint32_t AlignmentPow2 = uint32_t(zen::CountTrailingZeros64(Alignment)); + const uint32_t Alignment_SizeLower = (AlignmentPow2 << SizeShift) | uint32_t(Size & ((1 << SizeShift) - 1)); + + UE_TRACE_LOG(Memory, ReallocAlloc, MemAllocChannel) + << ReallocAlloc.Address(uint64_t(NewPtr)) << ReallocAlloc.CallstackId(0) << ReallocAlloc.Size(uint32_t(Size >> SizeShift)) + << ReallocAlloc.AlignmentPow2_SizeLower(uint8(Alignment_SizeLower)) + << ReallocAlloc.RootHeap(uint8(EMemoryTraceRootHeap::SystemMemory)); + + return NewPtr; +#else + return WrappedMalloc->Realloc(Original, Count, Alignment); +#endif // UE_MEMORY_TRACE_ENABLED +} + +///////////////////////////////////////////////////////////////////////////// +void +FTraceMalloc::Free(void* Original) +{ +#if UE_MEMORY_TRACE_ENABLED + UE_TRACE_LOG(Memory, Free, MemAllocChannel) + << Free.Address(uint64_t(Original)) << Free.RootHeap(uint8(EMemoryTraceRootHeap::SystemMemory)); + + { + zen::TGuardValue<bool> _(GDoNotAllocateInTrace, true); + WrappedMalloc->Free(Original); + } +#else + WrappedMalloc->Free(Original); +#endif // UE_MEMORY_TRACE_ENABLED +} + +} // namespace zen diff --git a/src/zencore/memtrack/moduletrace.cpp b/src/zencore/memtrack/moduletrace.cpp new file mode 100644 index 000000000..cf37c5932 --- /dev/null +++ b/src/zencore/memtrack/moduletrace.cpp @@ -0,0 +1,296 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenbase/zenbase.h> +#include <zencore/memory/llm.h> +#include <zencore/memory/memorytrace.h> +#include <zencore/memory/tagtrace.h> + +#if ZEN_PLATFORM_WINDOWS +# define PLATFORM_SUPPORTS_TRACE_WIN32_MODULE_DIAGNOSTICS 1 +#else +# define PLATFORM_SUPPORTS_TRACE_WIN32_MODULE_DIAGNOSTICS 0 +#endif + +#include "moduletrace_events.h" + +#if PLATFORM_SUPPORTS_TRACE_WIN32_MODULE_DIAGNOSTICS + +# include <zencore/windows.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <winternl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +# include <zencore/trace.h> + +# include <array> + +namespace zen { + +class FMalloc; + +typedef uint32_t HeapId; + +//////////////////////////////////////////////////////////////////////////////// +struct FNtDllFunction +{ + FARPROC Addr; + + FNtDllFunction(const char* Name) + { + HMODULE NtDll = LoadLibraryW(L"ntdll.dll"); + ZEN_ASSERT(NtDll); + Addr = GetProcAddress(NtDll, Name); + } + + template<typename... ArgTypes> + unsigned int operator()(ArgTypes... Args) + { + typedef unsigned int(NTAPI * Prototype)(ArgTypes...); + return (Prototype((void*)Addr))(Args...); + } +}; + +////////////////////////////////////////////////////////////////////////////////7777 +class FModuleTrace +{ +public: + typedef void (*SubscribeFunc)(bool, void*, const char16_t*); + + FModuleTrace(FMalloc* InMalloc); + ~FModuleTrace(); + static FModuleTrace* Get(); + void Initialize(); + void Subscribe(SubscribeFunc Function); + +private: + void OnDllLoaded(const UNICODE_STRING& Name, uintptr_t Base); + void OnDllUnloaded(uintptr_t Base); + void OnDllNotification(unsigned int Reason, const void* DataPtr); + static FModuleTrace* Instance; + SubscribeFunc Subscribers[64]; + int SubscriberCount = 0; + void* CallbackCookie = nullptr; + HeapId ProgramHeapId = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// +FModuleTrace* FModuleTrace::Instance = nullptr; + +//////////////////////////////////////////////////////////////////////////////// +FModuleTrace::FModuleTrace(FMalloc* InMalloc) +{ + ZEN_UNUSED(InMalloc); + Instance = this; +} + +//////////////////////////////////////////////////////////////////////////////// +FModuleTrace::~FModuleTrace() +{ + if (CallbackCookie) + { + FNtDllFunction UnregisterFunc("LdrUnregisterDllNotification"); + UnregisterFunc(CallbackCookie); + } +} + +//////////////////////////////////////////////////////////////////////////////// +FModuleTrace* +FModuleTrace::Get() +{ + return Instance; +} + +//////////////////////////////////////////////////////////////////////////////// +void +FModuleTrace::Initialize() +{ + using namespace UE::Trace; + + ProgramHeapId = MemoryTrace_HeapSpec(SystemMemory, u"Module", EMemoryTraceHeapFlags::None); + + UE_TRACE_LOG(Diagnostics, ModuleInit, ModuleChannel, sizeof(char) * 3) + << ModuleInit.SymbolFormat("pdb", 3) << ModuleInit.ModuleBaseShift(uint8(0)); + + // Register for DLL load/unload notifications. + auto Thunk = [](ULONG Reason, const void* Data, void* Context) { + auto* Self = (FModuleTrace*)Context; + Self->OnDllNotification(Reason, Data); + }; + + typedef void(CALLBACK * ThunkType)(ULONG, const void*, void*); + auto ThunkImpl = ThunkType(Thunk); + + FNtDllFunction RegisterFunc("LdrRegisterDllNotification"); + RegisterFunc(0, ThunkImpl, this, &CallbackCookie); + + // Enumerate already loaded modules. + const TEB* ThreadEnvBlock = NtCurrentTeb(); + const PEB* ProcessEnvBlock = ThreadEnvBlock->ProcessEnvironmentBlock; + const LIST_ENTRY* ModuleIter = ProcessEnvBlock->Ldr->InMemoryOrderModuleList.Flink; + const LIST_ENTRY* ModuleIterEnd = ModuleIter->Blink; + do + { + const auto& ModuleData = *(LDR_DATA_TABLE_ENTRY*)(ModuleIter - 1); + if (ModuleData.DllBase == 0) + { + break; + } + + OnDllLoaded(ModuleData.FullDllName, UPTRINT(ModuleData.DllBase)); + ModuleIter = ModuleIter->Flink; + } while (ModuleIter != ModuleIterEnd); +} + +//////////////////////////////////////////////////////////////////////////////// +void +FModuleTrace::Subscribe(SubscribeFunc Function) +{ + ZEN_ASSERT(SubscriberCount < ZEN_ARRAY_COUNT(Subscribers)); + Subscribers[SubscriberCount++] = Function; +} + +//////////////////////////////////////////////////////////////////////////////// +void +FModuleTrace::OnDllNotification(unsigned int Reason, const void* DataPtr) +{ + enum + { + LDR_DLL_NOTIFICATION_REASON_LOADED = 1, + LDR_DLL_NOTIFICATION_REASON_UNLOADED = 2, + }; + + struct FNotificationData + { + uint32_t Flags; + const UNICODE_STRING& FullPath; + const UNICODE_STRING& BaseName; + uintptr_t Base; + }; + const auto& Data = *(FNotificationData*)DataPtr; + + switch (Reason) + { + case LDR_DLL_NOTIFICATION_REASON_LOADED: + OnDllLoaded(Data.FullPath, Data.Base); + break; + case LDR_DLL_NOTIFICATION_REASON_UNLOADED: + OnDllUnloaded(Data.Base); + break; + } +} + +//////////////////////////////////////////////////////////////////////////////// +void +FModuleTrace::OnDllLoaded(const UNICODE_STRING& Name, UPTRINT Base) +{ + const auto* DosHeader = (IMAGE_DOS_HEADER*)Base; + const auto* NtHeaders = (IMAGE_NT_HEADERS*)(Base + DosHeader->e_lfanew); + const IMAGE_OPTIONAL_HEADER& OptionalHeader = NtHeaders->OptionalHeader; + uint8_t ImageId[20]; + + // Find the guid and age of the binary, used to match debug files + const IMAGE_DATA_DIRECTORY& DebugInfoEntry = OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_DEBUG]; + const auto* DebugEntries = (IMAGE_DEBUG_DIRECTORY*)(Base + DebugInfoEntry.VirtualAddress); + for (uint32_t i = 0, n = DebugInfoEntry.Size / sizeof(DebugEntries[0]); i < n; ++i) + { + const IMAGE_DEBUG_DIRECTORY& Entry = DebugEntries[i]; + if (Entry.Type == IMAGE_DEBUG_TYPE_CODEVIEW) + { + struct FCodeView7 + { + uint32_t Signature; + uint32_t Guid[4]; + uint32_t Age; + }; + + if (Entry.SizeOfData < sizeof(FCodeView7)) + { + continue; + } + + const auto* CodeView7 = (FCodeView7*)(Base + Entry.AddressOfRawData); + if (CodeView7->Signature != 'SDSR') + { + continue; + } + + memcpy(ImageId, (uint8_t*)&CodeView7->Guid, sizeof(uint32_t) * 4); + memcpy(&ImageId[16], (uint8_t*)&CodeView7->Age, sizeof(uint32_t)); + break; + } + } + + // Note: UNICODE_STRING.Length is the size in bytes of the string buffer. + UE_TRACE_LOG(Diagnostics, ModuleLoad, ModuleChannel, uint32_t(Name.Length + sizeof(ImageId))) + << ModuleLoad.Name((const char16_t*)Name.Buffer, Name.Length / 2) << ModuleLoad.Base(uint64_t(Base)) + << ModuleLoad.Size(OptionalHeader.SizeOfImage) << ModuleLoad.ImageId(ImageId, uint32_t(sizeof(ImageId))); + +# if UE_MEMORY_TRACE_ENABLED + { + ZEN_MEMSCOPE(ELLMTag::ProgramSize); + MemoryTrace_Alloc(Base, OptionalHeader.SizeOfImage, 4 * 1024, EMemoryTraceRootHeap::SystemMemory); + MemoryTrace_MarkAllocAsHeap(Base, ProgramHeapId); + MemoryTrace_Alloc(Base, OptionalHeader.SizeOfImage, 4 * 1024, EMemoryTraceRootHeap::SystemMemory); + } +# endif // UE_MEMORY_TRACE_ENABLED + + for (int i = 0; i < SubscriberCount; ++i) + { + Subscribers[i](true, (void*)Base, (const char16_t*)Name.Buffer); + } +} + +//////////////////////////////////////////////////////////////////////////////// +void +FModuleTrace::OnDllUnloaded(UPTRINT Base) +{ +# if UE_MEMORY_TRACE_ENABLED + MemoryTrace_Free(Base, EMemoryTraceRootHeap::SystemMemory); + MemoryTrace_UnmarkAllocAsHeap(Base, ProgramHeapId); + MemoryTrace_Free(Base, EMemoryTraceRootHeap::SystemMemory); +# endif // UE_MEMORY_TRACE_ENABLED + + UE_TRACE_LOG(Diagnostics, ModuleUnload, ModuleChannel) << ModuleUnload.Base(uint64(Base)); + + for (int i = 0; i < SubscriberCount; ++i) + { + Subscribers[i](false, (void*)Base, nullptr); + } +} + +//////////////////////////////////////////////////////////////////////////////// +void +Modules_Create(FMalloc* Malloc) +{ + if (FModuleTrace::Get() != nullptr) + { + return; + } + + static FModuleTrace Instance(Malloc); +} + +//////////////////////////////////////////////////////////////////////////////// +void +Modules_Initialize() +{ + if (FModuleTrace* Instance = FModuleTrace::Get()) + { + Instance->Initialize(); + } +} + +//////////////////////////////////////////////////////////////////////////////// +void +Modules_Subscribe(void (*Function)(bool, void*, const char16_t*)) +{ + if (FModuleTrace* Instance = FModuleTrace::Get()) + { + Instance->Subscribe(Function); + } +} + +} // namespace zen + +#endif // PLATFORM_SUPPORTS_WIN32_MEMORY_TRACE diff --git a/src/zencore/memtrack/moduletrace.h b/src/zencore/memtrack/moduletrace.h new file mode 100644 index 000000000..5e7374faa --- /dev/null +++ b/src/zencore/memtrack/moduletrace.h @@ -0,0 +1,11 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +namespace zen { + +void Modules_Create(class FMalloc*); +void Modules_Subscribe(void (*)(bool, void*, const char16_t*)); +void Modules_Initialize(); + +} // namespace zen diff --git a/src/zencore/memtrack/moduletrace_events.cpp b/src/zencore/memtrack/moduletrace_events.cpp new file mode 100644 index 000000000..9c6a9b648 --- /dev/null +++ b/src/zencore/memtrack/moduletrace_events.cpp @@ -0,0 +1,16 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/trace.h> + +#include "moduletrace_events.h" + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// +UE_TRACE_CHANNEL_DEFINE(ModuleChannel, "Module information needed for symbols resolution", true) + +UE_TRACE_EVENT_DEFINE(Diagnostics, ModuleInit) +UE_TRACE_EVENT_DEFINE(Diagnostics, ModuleLoad) +UE_TRACE_EVENT_DEFINE(Diagnostics, ModuleUnload) + +} // namespace zen diff --git a/src/zencore/memtrack/moduletrace_events.h b/src/zencore/memtrack/moduletrace_events.h new file mode 100644 index 000000000..1bda42fe8 --- /dev/null +++ b/src/zencore/memtrack/moduletrace_events.h @@ -0,0 +1,27 @@ +// Copyright Epic Games, Inc. All Rights Reserved. +#pragma once + +#include <zencore/trace.h> + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// +UE_TRACE_CHANNEL_EXTERN(ModuleChannel) + +UE_TRACE_EVENT_BEGIN_EXTERN(Diagnostics, ModuleInit, NoSync | Important) + UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, SymbolFormat) + UE_TRACE_EVENT_FIELD(uint8, ModuleBaseShift) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN_EXTERN(Diagnostics, ModuleLoad, NoSync | Important) + UE_TRACE_EVENT_FIELD(UE::Trace::WideString, Name) + UE_TRACE_EVENT_FIELD(uint64, Base) + UE_TRACE_EVENT_FIELD(uint32, Size) + UE_TRACE_EVENT_FIELD(uint8[], ImageId) // Platform specific id for this image, used to match debug files were available +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN_EXTERN(Diagnostics, ModuleUnload, NoSync | Important) + UE_TRACE_EVENT_FIELD(uint64, Base) +UE_TRACE_EVENT_END() + +} // namespace zen diff --git a/src/zencore/memtrack/platformtls.h b/src/zencore/memtrack/platformtls.h new file mode 100644 index 000000000..f134e68a8 --- /dev/null +++ b/src/zencore/memtrack/platformtls.h @@ -0,0 +1,107 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> + +/** + * It should be possible to provide a generic implementation as long as a threadID is provided. We don't do that yet. + */ +struct FGenericPlatformTLS +{ + static const uint32_t InvalidTlsSlot = 0xFFFFFFFF; + + /** + * Return false if this is an invalid TLS slot + * @param SlotIndex the TLS index to check + * @return true if this looks like a valid slot + */ + static bool IsValidTlsSlot(uint32_t SlotIndex) { return SlotIndex != InvalidTlsSlot; } +}; + +#if ZEN_PLATFORM_WINDOWS + +# include <zencore/windows.h> + +class FWindowsPlatformTLS : public FGenericPlatformTLS +{ +public: + static uint32_t AllocTlsSlot() { return ::TlsAlloc(); } + + static void FreeTlsSlot(uint32_t SlotIndex) { ::TlsFree(SlotIndex); } + + static void SetTlsValue(uint32_t SlotIndex, void* Value) { ::TlsSetValue(SlotIndex, Value); } + + /** + * Reads the value stored at the specified TLS slot + * + * @return the value stored in the slot + */ + static void* GetTlsValue(uint32_t SlotIndex) { return ::TlsGetValue(SlotIndex); } + + /** + * Return false if this is an invalid TLS slot + * @param SlotIndex the TLS index to check + * @return true if this looks like a valid slot + */ + static bool IsValidTlsSlot(uint32_t SlotIndex) { return SlotIndex != InvalidTlsSlot; } +}; + +typedef FWindowsPlatformTLS FPlatformTLS; + +#elif ZEN_PLATFORM_MAC + +# include <pthread.h + +/** + * Apple implementation of the TLS OS functions + **/ +struct FApplePlatformTLS : public FGenericPlatformTLS +{ + /** + * Returns the currently executing thread's id + */ + static uint32_t GetCurrentThreadId(void) { return (uint32_t)pthread_mach_thread_np(pthread_self()); } + + /** + * Allocates a thread local store slot + */ + static uint32_t AllocTlsSlot(void) + { + // allocate a per-thread mem slot + pthread_key_t SlotKey = 0; + if (pthread_key_create(&SlotKey, NULL) != 0) + { + SlotKey = InvalidTlsSlot; // matches the Windows TlsAlloc() retval. + } + return SlotKey; + } + + /** + * Sets a value in the specified TLS slot + * + * @param SlotIndex the TLS index to store it in + * @param Value the value to store in the slot + */ + static void SetTlsValue(uint32_t SlotIndex, void* Value) { pthread_setspecific((pthread_key_t)SlotIndex, Value); } + + /** + * Reads the value stored at the specified TLS slot + * + * @return the value stored in the slot + */ + static void* GetTlsValue(uint32_t SlotIndex) { return pthread_getspecific((pthread_key_t)SlotIndex); } + + /** + * Frees a previously allocated TLS slot + * + * @param SlotIndex the TLS index to store it in + */ + static void FreeTlsSlot(uint32_t SlotIndex) { pthread_key_delete((pthread_key_t)SlotIndex); } +}; + +typedef FApplePlatformTLS FPlatformTLS; + +#else +# error Platform not yet supported +#endif diff --git a/src/zencore/memtrack/tagtrace.cpp b/src/zencore/memtrack/tagtrace.cpp new file mode 100644 index 000000000..575b1fe53 --- /dev/null +++ b/src/zencore/memtrack/tagtrace.cpp @@ -0,0 +1,247 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/memory/fmalloc.h> +#include <zencore/memory/llm.h> +#include <zencore/memory/tagtrace.h> + +#include "growonlylockfreehash.h" + +#if UE_MEMORY_TAGS_TRACE_ENABLED && UE_TRACE_ENABLED + +# include <zencore/string.h> + +namespace zen { +//////////////////////////////////////////////////////////////////////////////// + +UE_TRACE_CHANNEL_EXTERN(MemAllocChannel); + +UE_TRACE_EVENT_BEGIN(Memory, TagSpec, Important | NoSync) + UE_TRACE_EVENT_FIELD(int32, Tag) + UE_TRACE_EVENT_FIELD(int32, Parent) + UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, Display) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, MemoryScope, NoSync) + UE_TRACE_EVENT_FIELD(int32, Tag) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Memory, MemoryScopePtr, NoSync) + UE_TRACE_EVENT_FIELD(uint64, Ptr) +UE_TRACE_EVENT_END() + +//////////////////////////////////////////////////////////////////////////////// +// Per thread active tag, i.e. the top level FMemScope +thread_local int32 GActiveTag; + +//////////////////////////////////////////////////////////////////////////////// +FMemScope::FMemScope() +{ +} + +FMemScope::FMemScope(int32_t InTag, bool bShouldActivate /*= true*/) +{ + if (UE_TRACE_CHANNELEXPR_IS_ENABLED(MemAllocChannel) & bShouldActivate) + { + ActivateScope(InTag); + } +} + +//////////////////////////////////////////////////////////////////////////////// +FMemScope::FMemScope(ELLMTag InTag, bool bShouldActivate /*= true*/) +{ + if (UE_TRACE_CHANNELEXPR_IS_ENABLED(MemAllocChannel) & bShouldActivate) + { + ActivateScope(static_cast<int32_t>(InTag)); + } +} + +FMemScope::FMemScope(FLLMTag InTag, bool bShouldActivate /*= true*/) +{ + if (UE_TRACE_CHANNELEXPR_IS_ENABLED(MemAllocChannel) & bShouldActivate) + { + ActivateScope(static_cast<int32_t>(InTag.GetTag())); + } +} + +//////////////////////////////////////////////////////////////////////////////// +void +FMemScope::ActivateScope(int32_t InTag) +{ + if (InTag == GActiveTag) + return; + + if (auto LogScope = FMemoryMemoryScopeFields::LogScopeType::ScopedEnter<FMemoryMemoryScopeFields>()) + { + if (const auto& __restrict MemoryScope = *(FMemoryMemoryScopeFields*)(&LogScope)) + { + Inner.SetActive(); + LogScope += LogScope << MemoryScope.Tag(InTag); + PrevTag = GActiveTag; + GActiveTag = InTag; + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +FMemScope::~FMemScope() +{ + if (Inner.bActive) + { + GActiveTag = PrevTag; + } +} + +//////////////////////////////////////////////////////////////////////////////// +FMemScopePtr::FMemScopePtr(uint64_t InPtr) +{ + if (InPtr != 0 && TRACE_PRIVATE_CHANNELEXPR_IS_ENABLED(MemAllocChannel)) + { + if (auto LogScope = FMemoryMemoryScopePtrFields::LogScopeType::ScopedEnter<FMemoryMemoryScopePtrFields>()) + { + if (const auto& __restrict MemoryScope = *(FMemoryMemoryScopePtrFields*)(&LogScope)) + { + Inner.SetActive(), LogScope += LogScope << MemoryScope.Ptr(InPtr); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////// +FMemScopePtr::~FMemScopePtr() +{ +} + +///////////////////////////////////////////////////////////////////////////////// + +/** + * Utility class that manages tracing the specification of unique LLM tags + * and custom name based tags. + */ +class FTagTrace +{ +public: + FTagTrace(FMalloc* InMalloc); + void AnnounceGenericTags() const; + void AnnounceSpecialTags() const; + int32 AnnounceCustomTag(int32 Tag, int32 ParentTag, const ANSICHAR* Display) const; + +private: + struct FTagNameSetEntry + { + std::atomic_int32_t Data; + + int32_t GetKey() const { return Data.load(std::memory_order_relaxed); } + bool GetValue() const { return true; } + bool IsEmpty() const { return Data.load(std::memory_order_relaxed) == 0; } // NAME_None is treated as empty + void SetKeyValue(int32_t Key, bool Value) + { + ZEN_UNUSED(Value); + Data.store(Key, std::memory_order_relaxed); + } + static uint32_t KeyHash(int32_t Key) { return static_cast<uint32>(Key); } + static void ClearEntries(FTagNameSetEntry* Entries, int32_t EntryCount) + { + memset(Entries, 0, EntryCount * sizeof(FTagNameSetEntry)); + } + }; + typedef TGrowOnlyLockFreeHash<FTagNameSetEntry, int32_t, bool> FTagNameSet; + + FTagNameSet AnnouncedNames; + static FMalloc* Malloc; +}; + +FMalloc* FTagTrace::Malloc = nullptr; +static FTagTrace* GTagTrace = nullptr; + +//////////////////////////////////////////////////////////////////////////////// +FTagTrace::FTagTrace(FMalloc* InMalloc) : AnnouncedNames(InMalloc) +{ + Malloc = InMalloc; + AnnouncedNames.Reserve(1024); + AnnounceGenericTags(); + AnnounceSpecialTags(); +} + +//////////////////////////////////////////////////////////////////////////////// +void +FTagTrace::AnnounceGenericTags() const +{ +# define TRACE_TAG_SPEC(Enum, Str, ParentTag) \ + { \ + const uint32_t DisplayLen = (uint32_t)StringLength(Str); \ + UE_TRACE_LOG(Memory, TagSpec, MemAllocChannel, DisplayLen * sizeof(ANSICHAR)) \ + << TagSpec.Tag((int32_t)ELLMTag::Enum) << TagSpec.Parent((int32_t)ParentTag) << TagSpec.Display(Str, DisplayLen); \ + } + LLM_ENUM_GENERIC_TAGS(TRACE_TAG_SPEC); +# undef TRACE_TAG_SPEC +} + +//////////////////////////////////////////////////////////////////////////////// + +void +FTagTrace::AnnounceSpecialTags() const +{ + auto EmitTag = [](const char16_t* DisplayString, int32_t Tag, int32_t ParentTag) { + const uint32_t DisplayLen = (uint32_t)StringLength(DisplayString); + UE_TRACE_LOG(Memory, TagSpec, MemAllocChannel, DisplayLen * sizeof(ANSICHAR)) + << TagSpec.Tag(Tag) << TagSpec.Parent(ParentTag) << TagSpec.Display(DisplayString, DisplayLen); + }; + + EmitTag(u"Trace", TRACE_TAG, -1); +} + +//////////////////////////////////////////////////////////////////////////////// +int32_t +FTagTrace::AnnounceCustomTag(int32_t Tag, int32_t ParentTag, const ANSICHAR* Display) const +{ + const uint32_t DisplayLen = (uint32_t)StringLength(Display); + UE_TRACE_LOG(Memory, TagSpec, MemAllocChannel, DisplayLen * sizeof(ANSICHAR)) + << TagSpec.Tag(Tag) << TagSpec.Parent(ParentTag) << TagSpec.Display(Display, DisplayLen); + return Tag; +} + +} // namespace zen + +#endif // UE_MEMORY_TAGS_TRACE_ENABLED + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// +void +MemoryTrace_InitTags(FMalloc* InMalloc) +{ +#if UE_MEMORY_TAGS_TRACE_ENABLED && UE_TRACE_ENABLED + GTagTrace = (FTagTrace*)InMalloc->Malloc(sizeof(FTagTrace), alignof(FTagTrace)); + new (GTagTrace) FTagTrace(InMalloc); +#else + ZEN_UNUSED(InMalloc); +#endif +} + +//////////////////////////////////////////////////////////////////////////////// +int32_t +MemoryTrace_AnnounceCustomTag(int32_t Tag, int32_t ParentTag, const char* Display) +{ +#if UE_MEMORY_TAGS_TRACE_ENABLED && UE_TRACE_ENABLED + if (GTagTrace) + { + return GTagTrace->AnnounceCustomTag(Tag, ParentTag, Display); + } +#else + ZEN_UNUSED(Tag, ParentTag, Display); +#endif + return -1; +} + +//////////////////////////////////////////////////////////////////////////////// +int32_t +MemoryTrace_GetActiveTag() +{ +#if UE_MEMORY_TAGS_TRACE_ENABLED && UE_TRACE_ENABLED + return GActiveTag; +#else + return -1; +#endif +} + +} // namespace zen diff --git a/src/zencore/memtrack/tracemalloc.h b/src/zencore/memtrack/tracemalloc.h new file mode 100644 index 000000000..54606ac45 --- /dev/null +++ b/src/zencore/memtrack/tracemalloc.h @@ -0,0 +1,24 @@ +// Copyright Epic Games, Inc. All Rights Reserved. +#pragma once + +#include <zencore/memory/fmalloc.h> +#include <zencore/memory/memorytrace.h> + +namespace zen { + +class FTraceMalloc : public FMalloc +{ +public: + FTraceMalloc(FMalloc* InMalloc); + virtual ~FTraceMalloc(); + + virtual void* Malloc(SIZE_T Count, uint32 Alignment) override; + virtual void* Realloc(void* Original, SIZE_T Count, uint32 Alignment) override; + virtual void Free(void* Original) override; + + virtual void OnMallocInitialized() override { WrappedMalloc->OnMallocInitialized(); } + + FMalloc* WrappedMalloc; +}; + +} // namespace zen diff --git a/src/zencore/memtrack/vatrace.cpp b/src/zencore/memtrack/vatrace.cpp new file mode 100644 index 000000000..4dea27f1b --- /dev/null +++ b/src/zencore/memtrack/vatrace.cpp @@ -0,0 +1,361 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "vatrace.h" + +#if PLATFORM_SUPPORTS_TRACE_WIN32_VIRTUAL_MEMORY_HOOKS + +# include <zencore/memory/memorytrace.h> + +# if (NTDDI_VERSION >= NTDDI_WIN10_RS4) +# pragma comment(lib, "mincore.lib") // VirtualAlloc2 +# endif + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// +class FTextSectionEditor +{ +public: + ~FTextSectionEditor(); + template<typename T> + T* Hook(T* Target, T* HookFunction); + +private: + struct FTrampolineBlock + { + FTrampolineBlock* Next; + uint32_t Size; + uint32_t Used; + }; + + static void* GetActualAddress(void* Function); + FTrampolineBlock* AllocateTrampolineBlock(void* Reference); + uint8_t* AllocateTrampoline(void* Reference, unsigned int Size); + void* HookImpl(void* Target, void* HookFunction); + FTrampolineBlock* HeadBlock = nullptr; +}; + +//////////////////////////////////////////////////////////////////////////////// +FTextSectionEditor::~FTextSectionEditor() +{ + for (FTrampolineBlock* Block = HeadBlock; Block != nullptr; Block = Block->Next) + { + DWORD Unused; + VirtualProtect(Block, Block->Size, PAGE_EXECUTE_READ, &Unused); + } + + FlushInstructionCache(GetCurrentProcess(), nullptr, 0); +} + +//////////////////////////////////////////////////////////////////////////////// +void* +FTextSectionEditor::GetActualAddress(void* Function) +{ + // Follow a jmp instruction (0xff/4 only for now) at function and returns + // where it would jmp to. + + uint8_t* Addr = (uint8_t*)Function; + int Offset = unsigned(Addr[0] & 0xf0) == 0x40; // REX prefix + if (Addr[Offset + 0] == 0xff && Addr[Offset + 1] == 0x25) + { + Addr += Offset; + Addr = *(uint8_t**)(Addr + 6 + *(uint32_t*)(Addr + 2)); + } + return Addr; +} + +//////////////////////////////////////////////////////////////////////////////// +FTextSectionEditor::FTrampolineBlock* +FTextSectionEditor::AllocateTrampolineBlock(void* Reference) +{ + static const size_t BlockSize = 0x10000; // 64KB is Windows' canonical granularity + + // Find the start of the main allocation that mapped Reference + MEMORY_BASIC_INFORMATION MemInfo; + VirtualQuery(Reference, &MemInfo, sizeof(MemInfo)); + auto* Ptr = (uint8_t*)(MemInfo.AllocationBase); + + // Step backwards one block at a time and try and allocate that address + while (true) + { + Ptr -= BlockSize; + if (VirtualAlloc(Ptr, BlockSize, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE) != nullptr) + { + break; + } + + uintptr_t Distance = uintptr_t(Reference) - uintptr_t(Ptr); + if (Distance >= 1ull << 31) + { + ZEN_ASSERT(!"Failed to allocate trampoline blocks for memory tracing hooks"); + } + } + + auto* Block = (FTrampolineBlock*)Ptr; + Block->Next = HeadBlock; + Block->Size = BlockSize; + Block->Used = sizeof(FTrampolineBlock); + HeadBlock = Block; + + return Block; +} + +//////////////////////////////////////////////////////////////////////////////// +uint8_t* +FTextSectionEditor::AllocateTrampoline(void* Reference, unsigned int Size) +{ + // Try and find a block that's within 2^31 bytes before Reference + FTrampolineBlock* Block; + for (Block = HeadBlock; Block != nullptr; Block = Block->Next) + { + uintptr_t Distance = uintptr_t(Reference) - uintptr_t(Block); + if (Distance < 1ull << 31) + { + break; + } + } + + // If we didn't find a block then we need to allocate a new one. + if (Block == nullptr) + { + Block = AllocateTrampolineBlock(Reference); + } + + // Allocate space for the trampoline. + uint32_t NextUsed = Block->Used + Size; + if (NextUsed > Block->Size) + { + // Block is full. We could allocate a new block here but as it is not + // expected that so many hooks will be made this path shouldn't happen + ZEN_ASSERT(!"Unable to allocate memory for memory tracing's hooks"); + } + + uint8_t* Out = (uint8_t*)Block + Block->Used; + Block->Used = NextUsed; + + return Out; +} + +//////////////////////////////////////////////////////////////////////////////// +template<typename T> +T* +FTextSectionEditor::Hook(T* Target, T* HookFunction) +{ + return (T*)HookImpl((void*)Target, (void*)HookFunction); +} + +//////////////////////////////////////////////////////////////////////////////// +void* +FTextSectionEditor::HookImpl(void* Target, void* HookFunction) +{ + Target = GetActualAddress(Target); + + // Very rudimentary x86_64 instruction length decoding that only supports op + // code ranges (0x80,0x8b) and (0x50,0x5f). Enough for simple prologues + uint8_t* __restrict Start = (uint8_t*)Target; + const uint8_t* Read = Start; + do + { + Read += (Read[0] & 0xf0) == 0x40; // REX prefix + uint8_t Inst = *Read++; + if (unsigned(Inst - 0x80) < 0x0cu) + { + uint8_t ModRm = *Read++; + Read += ((ModRm & 0300) < 0300) & ((ModRm & 0007) == 0004); // SIB + switch (ModRm & 0300) // Disp[8|32] + { + case 0100: + Read += 1; + break; + case 0200: + Read += 5; + break; + } + Read += (Inst == 0x83); + } + else if (unsigned(Inst - 0x50) >= 0x10u) + { + ZEN_ASSERT(!"Unknown instruction"); + } + } while (Read - Start < 6); + + static const int TrampolineSize = 24; + int PatchSize = int(Read - Start); + uint8_t* TrampolinePtr = AllocateTrampoline(Start, PatchSize + TrampolineSize); + + // Write the trampoline + *(void**)TrampolinePtr = HookFunction; + + uint8_t* PatchJmp = TrampolinePtr + sizeof(void*); + memcpy(PatchJmp, Start, PatchSize); + + PatchJmp += PatchSize; + *PatchJmp = 0xe9; + *(int32_t*)(PatchJmp + 1) = int32_t(intptr_t(Start + PatchSize) - intptr_t(PatchJmp)) - 5; + + // Need to make the text section writeable + DWORD ProtPrev; + uintptr_t ProtBase = uintptr_t(Target) & ~0x0fff; // 0x0fff is mask of VM page size + size_t ProtSize = ((ProtBase + 16 + 0x1000) & ~0x0fff) - ProtBase; // 16 is enough for one x86 instruction + VirtualProtect((void*)ProtBase, ProtSize, PAGE_EXECUTE_READWRITE, &ProtPrev); + + // Patch function to jmp to the hook + uint16_t* HookJmp = (uint16_t*)Target; + HookJmp[0] = 0x25ff; + *(int32_t*)(HookJmp + 1) = int32_t(intptr_t(TrampolinePtr) - intptr_t(HookJmp + 3)); + + // Put the protection back the way it was + VirtualProtect((void*)ProtBase, ProtSize, ProtPrev, &ProtPrev); + + return PatchJmp - PatchSize; +} + +////////////////////////////////////////////////////////////////////////// + +bool FVirtualWinApiHooks::bLight; +LPVOID(WINAPI* FVirtualWinApiHooks::VmAllocOrig)(LPVOID, SIZE_T, DWORD, DWORD); +LPVOID(WINAPI* FVirtualWinApiHooks::VmAllocExOrig)(HANDLE, LPVOID, SIZE_T, DWORD, DWORD); +# if (NTDDI_VERSION >= NTDDI_WIN10_RS4) +PVOID(WINAPI* FVirtualWinApiHooks::VmAlloc2Orig)(HANDLE, PVOID, SIZE_T, ULONG, ULONG, MEM_EXTENDED_PARAMETER*, ULONG); +# else +LPVOID(WINAPI* FVirtualWinApiHooks::VmAlloc2Orig)(HANDLE, LPVOID, SIZE_T, ULONG, ULONG, /*MEM_EXTENDED_PARAMETER* */ void*, ULONG); +# endif +BOOL(WINAPI* FVirtualWinApiHooks::VmFreeOrig)(LPVOID, SIZE_T, DWORD); +BOOL(WINAPI* FVirtualWinApiHooks::VmFreeExOrig)(HANDLE, LPVOID, SIZE_T, DWORD); + +void +FVirtualWinApiHooks::Initialize(bool bInLight) +{ + bLight = bInLight; + + FTextSectionEditor Editor; + + // Note that hooking alloc functions is done last as applying the hook can + // allocate some memory pages. + + VmFreeOrig = Editor.Hook(VirtualFree, &FVirtualWinApiHooks::VmFree); + VmFreeExOrig = Editor.Hook(VirtualFreeEx, &FVirtualWinApiHooks::VmFreeEx); + +# if ZEN_PLATFORM_WINDOWS +# if (NTDDI_VERSION >= NTDDI_WIN10_RS4) + { + VmAlloc2Orig = Editor.Hook(VirtualAlloc2, &FVirtualWinApiHooks::VmAlloc2); + } +# else // NTDDI_VERSION + { + VmAlloc2Orig = nullptr; + HINSTANCE DllInstance; + DllInstance = LoadLibrary(TEXT("kernelbase.dll")); + if (DllInstance != NULL) + { +# pragma warning(push) +# pragma warning(disable : 4191) // 'type cast': unsafe conversion from 'FARPROC' to 'FVirtualWinApiHooks::FnVirtualAlloc2' + VmAlloc2Orig = (FnVirtualAlloc2)GetProcAddress(DllInstance, "VirtualAlloc2"); +# pragma warning(pop) + FreeLibrary(DllInstance); + } + if (VmAlloc2Orig) + { + VmAlloc2Orig = Editor.Hook(VmAlloc2Orig, &FVirtualWinApiHooks::VmAlloc2); + } + } +# endif // NTDDI_VERSION +# endif // PLATFORM_WINDOWS + + VmAllocExOrig = Editor.Hook(VirtualAllocEx, &FVirtualWinApiHooks::VmAllocEx); + VmAllocOrig = Editor.Hook(VirtualAlloc, &FVirtualWinApiHooks::VmAlloc); +} + +//////////////////////////////////////////////////////////////////////////////// +LPVOID WINAPI +FVirtualWinApiHooks::VmAlloc(LPVOID Address, SIZE_T Size, DWORD Type, DWORD Protect) +{ + LPVOID Ret = VmAllocOrig(Address, Size, Type, Protect); + + // Track any reserve for now. Going forward we need events to differentiate reserves/commits and + // corresponding information on frees. + if (Ret != nullptr && ((Type & MEM_RESERVE) || ((Type & MEM_COMMIT) && Address == nullptr))) + { + MemoryTrace_Alloc((uint64_t)Ret, Size, 0, EMemoryTraceRootHeap::SystemMemory); + MemoryTrace_MarkAllocAsHeap((uint64_t)Ret, EMemoryTraceRootHeap::SystemMemory); + } + + return Ret; +} + +//////////////////////////////////////////////////////////////////////////////// +BOOL WINAPI +FVirtualWinApiHooks::VmFree(LPVOID Address, SIZE_T Size, DWORD Type) +{ + if (Type & MEM_RELEASE) + { + MemoryTrace_UnmarkAllocAsHeap((uint64_t)Address, EMemoryTraceRootHeap::SystemMemory); + MemoryTrace_Free((uint64_t)Address, EMemoryTraceRootHeap::SystemMemory); + } + + return VmFreeOrig(Address, Size, Type); +} + +//////////////////////////////////////////////////////////////////////////////// +LPVOID WINAPI +FVirtualWinApiHooks::VmAllocEx(HANDLE Process, LPVOID Address, SIZE_T Size, DWORD Type, DWORD Protect) +{ + LPVOID Ret = VmAllocExOrig(Process, Address, Size, Type, Protect); + + if (Process == GetCurrentProcess() && Ret != nullptr && ((Type & MEM_RESERVE) || ((Type & MEM_COMMIT) && Address == nullptr))) + { + MemoryTrace_Alloc((uint64_t)Ret, Size, 0, EMemoryTraceRootHeap::SystemMemory); + MemoryTrace_MarkAllocAsHeap((uint64_t)Ret, EMemoryTraceRootHeap::SystemMemory); + } + + return Ret; +} + +//////////////////////////////////////////////////////////////////////////////// +BOOL WINAPI +FVirtualWinApiHooks::VmFreeEx(HANDLE Process, LPVOID Address, SIZE_T Size, DWORD Type) +{ + if (Process == GetCurrentProcess() && (Type & MEM_RELEASE)) + { + MemoryTrace_UnmarkAllocAsHeap((uint64_t)Address, EMemoryTraceRootHeap::SystemMemory); + MemoryTrace_Free((uint64_t)Address, EMemoryTraceRootHeap::SystemMemory); + } + + return VmFreeExOrig(Process, Address, Size, Type); +} + +//////////////////////////////////////////////////////////////////////////////// +# if (NTDDI_VERSION >= NTDDI_WIN10_RS4) +PVOID WINAPI +FVirtualWinApiHooks::VmAlloc2(HANDLE Process, + PVOID BaseAddress, + SIZE_T Size, + ULONG Type, + ULONG PageProtection, + MEM_EXTENDED_PARAMETER* ExtendedParameters, + ULONG ParameterCount) +# else +LPVOID WINAPI +FVirtualWinApiHooks::VmAlloc2(HANDLE Process, + LPVOID BaseAddress, + SIZE_T Size, + ULONG Type, + ULONG PageProtection, + /*MEM_EXTENDED_PARAMETER* */ void* ExtendedParameters, + ULONG ParameterCount) +# endif +{ + LPVOID Ret = VmAlloc2Orig(Process, BaseAddress, Size, Type, PageProtection, ExtendedParameters, ParameterCount); + + if (Process == GetCurrentProcess() && Ret != nullptr && ((Type & MEM_RESERVE) || ((Type & MEM_COMMIT) && BaseAddress == nullptr))) + { + MemoryTrace_Alloc((uint64_t)Ret, Size, 0, EMemoryTraceRootHeap::SystemMemory); + MemoryTrace_MarkAllocAsHeap((uint64_t)Ret, EMemoryTraceRootHeap::SystemMemory); + } + + return Ret; +} + +} // namespace zen + +#endif // PLATFORM_SUPPORTS_TRACE_WIN32_VIRTUAL_MEMORY_HOOKS diff --git a/src/zencore/memtrack/vatrace.h b/src/zencore/memtrack/vatrace.h new file mode 100644 index 000000000..59cc7fe97 --- /dev/null +++ b/src/zencore/memtrack/vatrace.h @@ -0,0 +1,61 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> + +#if ZEN_PLATFORM_WINDOWS && !defined(PLATFORM_SUPPORTS_TRACE_WIN32_VIRTUAL_MEMORY_HOOKS) +# define PLATFORM_SUPPORTS_TRACE_WIN32_VIRTUAL_MEMORY_HOOKS 1 +#endif + +#ifndef PLATFORM_SUPPORTS_TRACE_WIN32_VIRTUAL_MEMORY_HOOKS +# define PLATFORM_SUPPORTS_TRACE_WIN32_VIRTUAL_MEMORY_HOOKS 0 +#endif + +#if PLATFORM_SUPPORTS_TRACE_WIN32_VIRTUAL_MEMORY_HOOKS +# include <zencore/windows.h> + +namespace zen { + +class FVirtualWinApiHooks +{ +public: + static void Initialize(bool bInLight); + +private: + FVirtualWinApiHooks(); + static bool bLight; + static LPVOID WINAPI VmAlloc(LPVOID Address, SIZE_T Size, DWORD Type, DWORD Protect); + static LPVOID WINAPI VmAllocEx(HANDLE Process, LPVOID Address, SIZE_T Size, DWORD Type, DWORD Protect); +# if (NTDDI_VERSION >= NTDDI_WIN10_RS4) + static PVOID WINAPI VmAlloc2(HANDLE Process, + PVOID BaseAddress, + SIZE_T Size, + ULONG AllocationType, + ULONG PageProtection, + MEM_EXTENDED_PARAMETER* ExtendedParameters, + ULONG ParameterCount); + static PVOID(WINAPI* VmAlloc2Orig)(HANDLE, PVOID, SIZE_T, ULONG, ULONG, MEM_EXTENDED_PARAMETER*, ULONG); + typedef PVOID(__stdcall* FnVirtualAlloc2)(HANDLE, PVOID, SIZE_T, ULONG, ULONG, MEM_EXTENDED_PARAMETER*, ULONG); +# else + static LPVOID WINAPI VmAlloc2(HANDLE Process, + LPVOID BaseAddress, + SIZE_T Size, + ULONG AllocationType, + ULONG PageProtection, + void* ExtendedParameters, + ULONG ParameterCount); + static LPVOID(WINAPI* VmAlloc2Orig)(HANDLE, LPVOID, SIZE_T, ULONG, ULONG, /*MEM_EXTENDED_PARAMETER* */ void*, ULONG); + typedef LPVOID(__stdcall* FnVirtualAlloc2)(HANDLE, LPVOID, SIZE_T, ULONG, ULONG, /* MEM_EXTENDED_PARAMETER* */ void*, ULONG); +# endif + static BOOL WINAPI VmFree(LPVOID Address, SIZE_T Size, DWORD Type); + static BOOL WINAPI VmFreeEx(HANDLE Process, LPVOID Address, SIZE_T Size, DWORD Type); + static LPVOID(WINAPI* VmAllocOrig)(LPVOID, SIZE_T, DWORD, DWORD); + static LPVOID(WINAPI* VmAllocExOrig)(HANDLE, LPVOID, SIZE_T, DWORD, DWORD); + static BOOL(WINAPI* VmFreeOrig)(LPVOID, SIZE_T, DWORD); + static BOOL(WINAPI* VmFreeExOrig)(HANDLE, LPVOID, SIZE_T, DWORD); +}; + +} // namespace zen + +#endif diff --git a/src/zencore/parallelwork.cpp b/src/zencore/parallelwork.cpp new file mode 100644 index 000000000..d86d5815f --- /dev/null +++ b/src/zencore/parallelwork.cpp @@ -0,0 +1,264 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/parallelwork.h> + +#include <zencore/callstack.h> +#include <zencore/except.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> + +#include <typeinfo> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif // ZEN_WITH_TESTS + +namespace zen { + +ParallelWork::ParallelWork(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag, WorkerThreadPool::EMode Mode) +: m_AbortFlag(AbortFlag) +, m_PauseFlag(PauseFlag) +, m_Mode(Mode) +, m_PendingWork(1) +{ +} + +ParallelWork::~ParallelWork() +{ + try + { + if (!m_DispatchComplete) + { + ZEN_ASSERT(m_PendingWork.Remaining() > 0); + ZEN_WARN( + "ParallelWork disposed without explicit wait for completion, likely caused by an exception, waiting for dispatched threads " + "to complete"); + m_PendingWork.CountDown(); + m_DispatchComplete = true; + } + const bool WaitSucceeded = m_PendingWork.Wait(); + const ptrdiff_t RemainingWork = m_PendingWork.Remaining(); + if (!WaitSucceeded) + { + ZEN_ERROR("ParallelWork::~ParallelWork(): waiting for latch failed, pending work count at {}", RemainingWork); + } + if (RemainingWork != 0) + { + void* Frames[8]; + uint32_t FrameCount = GetCallstack(2, 8, Frames); + CallstackFrames* Callstack = CreateCallstack(FrameCount, Frames); + auto _ = MakeGuard([Callstack]() { FreeCallstack(Callstack); }); + ZEN_WARN("ParallelWork::~ParallelWork(): waited for outstanding work but pending work count is {} instead of 0", RemainingWork); + + uint32_t WaitedMs = 0; + while (m_PendingWork.Remaining() > 0 && WaitedMs < 2000) + { + Sleep(50); + WaitedMs += 50; + } + ptrdiff_t RemainingWorkAfterSafetyWait = m_PendingWork.Remaining(); + if (RemainingWorkAfterSafetyWait != 0) + { + ZEN_ERROR("ParallelWork::~ParallelWork(): safety wait for {} tasks failed, pending work count at {} after {}\n{}", + RemainingWork, + RemainingWorkAfterSafetyWait, + NiceLatencyNs(WaitedMs * 1000000u), + CallstackToString(Callstack, " ")) + } + else + { + ZEN_ERROR("ParallelWork::~ParallelWork(): safety wait for {} tasks completed after {}\n{}", + RemainingWork, + NiceLatencyNs(WaitedMs * 1000000u), + CallstackToString(Callstack, " ")); + } + } + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Exception in ParallelWork::~ParallelWork(): {}", Ex.what()); + } +} + +ParallelWork::ExceptionCallback +ParallelWork::DefaultErrorFunction() +{ + return [&](std::exception_ptr Ex, std::atomic<bool>& AbortFlag) { + m_ErrorLock.WithExclusiveLock([&]() { m_Errors.push_back(Ex); }); + AbortFlag = true; + }; +} + +void +ParallelWork::Wait(int32_t UpdateIntervalMS, UpdateCallback&& UpdateCallback) +{ + ZEN_ASSERT(!m_DispatchComplete); + ZEN_ASSERT(m_PendingWork.Remaining() > 0); + m_PendingWork.CountDown(); + m_DispatchComplete = true; + + while (!m_PendingWork.Wait(UpdateIntervalMS)) + { + UpdateCallback(m_AbortFlag.load(), m_PauseFlag.load(), m_PendingWork.Remaining()); + } + + RethrowErrors(); +} + +void +ParallelWork::Wait() +{ + ZEN_ASSERT(!m_DispatchComplete); + ZEN_ASSERT(m_PendingWork.Remaining() > 0); + m_PendingWork.CountDown(); + m_DispatchComplete = true; + + const bool WaitSucceeded = m_PendingWork.Wait(); + const ptrdiff_t RemainingWork = m_PendingWork.Remaining(); + if (!WaitSucceeded) + { + ZEN_ERROR("ParallelWork::Wait(): waiting for latch failed, pending work count at {}", RemainingWork); + } + else if (RemainingWork != 0) + { + ZEN_ERROR("ParallelWork::Wait(): pending work count at {} after successful wait for latch", RemainingWork); + } + + RethrowErrors(); +} + +void +ParallelWork::RethrowErrors() +{ + if (!m_Errors.empty()) + { + if (m_Errors.size() > 1) + { + ZEN_INFO("Multiple exceptions thrown during ParallelWork execution, dropping the following exceptions:"); + auto It = m_Errors.begin() + 1; + while (It != m_Errors.end()) + { + try + { + std::rethrow_exception(*It); + } + catch (const std::exception& Ex) + { + ZEN_INFO(" {}", Ex.what()); + } + It++; + } + } + std::exception_ptr Ex = m_Errors.front(); + m_Errors.clear(); + std::rethrow_exception(Ex); + } +} + +#if ZEN_WITH_TESTS + +TEST_CASE("parallellwork.nowork") +{ + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + Work.Wait(); +} + +TEST_CASE("parallellwork.basic") +{ + WorkerThreadPool WorkerPool(2); + + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + for (uint32_t I = 0; I < 5; I++) + { + Work.ScheduleWork(WorkerPool, [](std::atomic<bool>& AbortFlag) { CHECK(!AbortFlag); }); + } + Work.Wait(); +} + +TEST_CASE("parallellwork.throws_in_work") +{ + WorkerThreadPool WorkerPool(2); + + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + for (uint32_t I = 0; I < 10; I++) + { + Work.ScheduleWork(WorkerPool, [I](std::atomic<bool>& AbortFlag) { + ZEN_UNUSED(AbortFlag); + if (I > 3) + { + throw std::runtime_error("We throw in async thread"); + } + else + { + Sleep(10); + } + }); + } + CHECK_THROWS_WITH(Work.Wait(), "We throw in async thread"); +} + +TEST_CASE("parallellwork.throws_in_dispatch") +{ + WorkerThreadPool WorkerPool(2); + std::atomic<uint32_t> ExecutedCount; + try + { + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + for (uint32_t I = 0; I < 5; I++) + { + Work.ScheduleWork(WorkerPool, [I, &ExecutedCount](std::atomic<bool>& AbortFlag) { + if (AbortFlag.load()) + { + return; + } + ExecutedCount++; + }); + if (I == 3) + { + throw std::runtime_error("We throw in dispatcher thread"); + } + } + CHECK(false); + } + catch (const std::runtime_error& Ex) + { + CHECK_EQ("We throw in dispatcher thread", std::string(Ex.what())); + CHECK_LE(ExecutedCount.load(), 4); + } +} + +TEST_CASE("parallellwork.limitqueue") +{ + WorkerThreadPool WorkerPool(2); + + std::atomic<bool> AbortFlag; + std::atomic<bool> PauseFlag; + ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); + for (uint32_t I = 0; I < 5; I++) + { + Work.ScheduleWork(WorkerPool, [](std::atomic<bool>& AbortFlag) { + if (AbortFlag.load()) + { + return; + } + Sleep(10); + }); + } + Work.Wait(); +} + +void +parallellwork_forcelink() +{ +} +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index 2d0ec2de6..0b25d14f4 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -14,6 +14,7 @@ #if ZEN_PLATFORM_WINDOWS # include <shellapi.h> # include <Shlobj.h> +# include <TlHelp32.h> # include <zencore/windows.h> #else # include <fcntl.h> @@ -28,6 +29,12 @@ # include <unistd.h> #endif +#if ZEN_PLATFORM_MAC +# include <libproc.h> +# include <sys/types.h> +# include <sys/sysctl.h> +#endif + ZEN_THIRD_PARTY_INCLUDES_START #include <fmt/format.h> ZEN_THIRD_PARTY_INCLUDES_END @@ -35,7 +42,9 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { #if ZEN_PLATFORM_LINUX -const bool bNoZombieChildren = []() { +void +IgnoreChildSignals() +{ // When a child process exits it is put into a zombie state until the parent // collects its result. This doesn't fit the Windows-like model that Zen uses // where there is a less strict familial model and no zombification. Ignoring @@ -46,9 +55,80 @@ const bool bNoZombieChildren = []() { sigemptyset(&Action.sa_mask); Action.sa_handler = SIG_IGN; sigaction(SIGCHLD, &Action, nullptr); - return true; -}(); -#endif +} + +static char +GetPidStatus(int Pid, std::error_code& OutEc) +{ + std::filesystem::path EntryPath = std::filesystem::path("/proc") / fmt::format("{}", Pid); + std::filesystem::path StatPath = EntryPath / "stat"; + if (IsFile(StatPath)) + { + FILE* StatFile = fopen(StatPath.c_str(), "r"); + if (StatFile) + { + char Buffer[5120]; + int Size = fread(Buffer, 1, 5120 - 1, StatFile); + fclose(StatFile); + if (Size > 0) + { + Buffer[Size] = 0; + char* ScanPtr = strrchr(Buffer, ')'); + if (ScanPtr && ScanPtr[1] != '\0') + { + ScanPtr += 2; + char State = *ScanPtr; + return State; + } + } + } + else + { + OutEc = MakeErrorCodeFromLastError(); + } + } + return 0; +} + +bool +IsZombieProcess(int pid, std::error_code& OutEc) +{ + char Status = GetPidStatus(pid, OutEc); + if (OutEc) + { + return false; + } + if (Status == 'Z' || Status == 0) + { + return true; + } + return false; +} + +#endif // ZEN_PLATFORM_LINUX + +#if ZEN_PLATFORM_MAC +bool +IsZombieProcess(int pid, std::error_code& OutEc) +{ + struct kinfo_proc Info; + int Mib[4] = {CTL_KERN, KERN_PROC, KERN_PROC_PID, pid}; + size_t InfoSize = sizeof Info; + + int Res = sysctl(Mib, 4, &Info, &InfoSize, NULL, 0); + if (Res != 0) + { + OutEc = MakeErrorCodeFromLastError(); + return false; + } + if (Info.kp_proc.p_stat == SZOMB) + { + // Zombie process + return true; + } + return false; +} +#endif // ZEN_PLATFORM_MAC ProcessHandle::ProcessHandle() = default; @@ -75,8 +155,9 @@ ProcessHandle::~ProcessHandle() } void -ProcessHandle::Initialize(int Pid) +ProcessHandle::Initialize(int Pid, std::error_code& OutEc) { + OutEc.clear(); ZEN_ASSERT(m_ProcessHandle == nullptr); #if ZEN_PLATFORM_WINDOWS @@ -90,7 +171,21 @@ ProcessHandle::Initialize(int Pid) if (!m_ProcessHandle) { - ThrowLastError(fmt::format("ProcessHandle::Initialize(pid: {}) failed", Pid)); + OutEc = MakeErrorCodeFromLastError(); + } + + m_Pid = Pid; +} + +void +ProcessHandle::Initialize(int Pid) +{ + std::error_code Ec; + Initialize(Pid, Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("ProcessHandle::Initialize(pid: {}) failed", Pid)); } m_Pid = Pid; @@ -106,7 +201,8 @@ ProcessHandle::IsRunning() const GetExitCodeProcess(m_ProcessHandle, &ExitCode); bActive = (ExitCode == STILL_ACTIVE); #elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC - bActive = (kill(pid_t(m_Pid), 0) == 0); + std::error_code _; + bActive = IsProcessRunning(m_Pid, _); #endif return bActive; @@ -118,29 +214,40 @@ ProcessHandle::IsValid() const return (m_ProcessHandle != nullptr); } -void +bool ProcessHandle::Terminate(int ExitCode) { if (!IsRunning()) { - return; + return true; } - bool bSuccess = false; - #if ZEN_PLATFORM_WINDOWS - TerminateProcess(m_ProcessHandle, ExitCode); + BOOL bTerminated = TerminateProcess(m_ProcessHandle, ExitCode); + if (!bTerminated) + { + return false; + } DWORD WaitResult = WaitForSingleObject(m_ProcessHandle, INFINITE); - bSuccess = (WaitResult != WAIT_OBJECT_0); + bool bSuccess = (WaitResult == WAIT_OBJECT_0) || (WaitResult == WAIT_ABANDONED_0); + if (!bSuccess) + { + return false; + } #elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC ZEN_UNUSED(ExitCode); - bSuccess = (kill(m_Pid, SIGKILL) == 0); -#endif - - if (!bSuccess) + int Res = kill(pid_t(m_Pid), SIGKILL); + if (Res != 0) { - // What might go wrong here, and what is meaningful to act on? + int err = errno; + if (err != ESRCH) + { + return false; + } } +#endif + Reset(); + return true; } void @@ -157,7 +264,7 @@ ProcessHandle::Reset() } bool -ProcessHandle::Wait(int TimeoutMs) +ProcessHandle::Wait(int TimeoutMs, std::error_code& OutEc) { using namespace std::literals; @@ -174,20 +281,39 @@ ProcessHandle::Wait(int TimeoutMs) case WAIT_TIMEOUT: return false; + case WAIT_ABANDONED_0: + return true; + case WAIT_FAILED: break; } -#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + + OutEc = MakeErrorCodeFromLastError(); + + return false; +#endif // ZEN_PLATFORM_WINDOWS + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC const int SleepMs = 20; timespec SleepTime = {0, SleepMs * 1000 * 1000}; for (int SleepedTimeMS = 0;; SleepedTimeMS += SleepMs) { int WaitState = 0; - waitpid(m_Pid, &WaitState, WNOHANG | WCONTINUED | WUNTRACED); + if (waitpid(m_Pid, &WaitState, WNOHANG | WCONTINUED | WUNTRACED) != -1) + { + if (WIFEXITED(WaitState)) + { + m_ExitCode = WEXITSTATUS(WaitState); + } + } - if (WIFEXITED(WaitState)) + if (!IsProcessRunning(m_Pid, OutEc)) { - m_ExitCode = WEXITSTATUS(WaitState); + return true; + } + else if (OutEc) + { + return false; } if (kill(m_Pid, 0) < 0) @@ -197,7 +323,13 @@ ProcessHandle::Wait(int TimeoutMs) { return true; } - ThrowSystemError(static_cast<uint32_t>(LastError), "Process::Wait kill failed"sv); + OutEc = MakeErrorCode(LastError); + return false; + } + else if (IsZombieProcess(m_Pid, OutEc)) + { + ZEN_INFO("Found process {} in zombie state, treating as not running", m_Pid); + return true; } if (TimeoutMs >= 0 && SleepedTimeMS >= TimeoutMs) @@ -207,17 +339,28 @@ ProcessHandle::Wait(int TimeoutMs) nanosleep(&SleepTime, nullptr); } -#endif + return false; +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +} - // What might go wrong here, and what is meaningful to act on? - ThrowLastError("Process::Wait failed"sv); +bool +ProcessHandle::Wait(int TimeoutMs) +{ + std::error_code Ec; + if (Wait(TimeoutMs, Ec) && !Ec) + { + return true; + } + else if (Ec) + { + throw std::system_error(Ec, std::string("Process::Wait kill failed")); + } + return false; } int -ProcessHandle::WaitExitCode() +ProcessHandle::GetExitCode() { - Wait(-1); - #if ZEN_PLATFORM_WINDOWS DWORD ExitCode = 0; GetExitCodeProcess(m_ProcessHandle, &ExitCode); @@ -234,6 +377,13 @@ ProcessHandle::WaitExitCode() #endif } +int +ProcessHandle::WaitExitCode() +{ + Wait(-1); + return GetExitCode(); +} + ////////////////////////////////////////////////////////////////////////// #if !ZEN_PLATFORM_WINDOWS || ZEN_WITH_TESTS @@ -295,6 +445,10 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma { CreationFlags |= CREATE_NEW_CONSOLE; } + if (Options.Flags & CreateProcOptions::Flag_NoConsole) + { + CreationFlags |= CREATE_NO_WINDOW; + } const wchar_t* WorkingDir = nullptr; if (Options.WorkingDirectory != nullptr) @@ -440,6 +594,10 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C { CreateProcFlags |= CREATE_NEW_CONSOLE; } + if (Options.Flags & CreateProcOptions::Flag_NoConsole) + { + CreateProcFlags |= CREATE_NO_WINDOW; + } ExtendableWideStringBuilder<256> CommandLineZ; CommandLineZ << CommandLine; @@ -576,12 +734,19 @@ ProcessMonitor::IsRunning() #if ZEN_PLATFORM_WINDOWS DWORD ExitCode = 0; - GetExitCodeProcess(Proc, &ExitCode); - ProcIsActive = (ExitCode == STILL_ACTIVE); - if (!ProcIsActive) + if (Proc) + { + GetExitCodeProcess(Proc, &ExitCode); + ProcIsActive = (ExitCode == STILL_ACTIVE); + if (!ProcIsActive) + { + CloseHandle(Proc); + } + } + else { - CloseHandle(Proc); + ProcIsActive = false; } #else int Pid = int(intptr_t(Proc)); @@ -613,11 +778,8 @@ ProcessMonitor::AddPid(int Pid) ProcessHandle = HandleType(intptr_t(Pid)); #endif - if (ProcessHandle) - { - RwLock::ExclusiveLockScope _(m_Lock); - m_ProcessHandles.push_back(ProcessHandle); - } + RwLock::ExclusiveLockScope _(m_Lock); + m_ProcessHandles.push_back(ProcessHandle); } bool @@ -630,7 +792,7 @@ ProcessMonitor::IsActive() const ////////////////////////////////////////////////////////////////////////// bool -IsProcessRunning(int pid) +IsProcessRunning(int pid, std::error_code& OutEc) { // This function is arguably not super useful, a pid can be re-used // by the OS so holding on to a pid and polling it over some time @@ -642,14 +804,19 @@ IsProcessRunning(int pid) if (!hProc) { DWORD Error = zen::GetLastError(); - if (Error == ERROR_INVALID_PARAMETER) { return false; } - - ThrowSystemError(Error, fmt::format("failed to open process with pid {}", pid)); + if (Error == ERROR_ACCESS_DENIED) + { + // Process is running under other user probably, assume it is running + return true; + } + OutEc = MakeErrorCode(Error); + return false; } + auto _ = MakeGuard([hProc]() { CloseHandle(hProc); }); bool bStillActive = true; DWORD ExitCode = 0; @@ -659,17 +826,59 @@ IsProcessRunning(int pid) } else { - ZEN_WARN("Unable to get exit code from handle for process '{}', treating the process as active", pid); + DWORD Error = GetLastError(); + OutEc = MakeErrorCode(Error); + return false; } - - CloseHandle(hProc); - return bStillActive; #elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC - return (kill(pid_t(pid), 0) == 0); + int Res = kill(pid_t(pid), 0); + if (Res == 0) + { + if (IsZombieProcess(pid, OutEc)) + { + ZEN_INFO("Found process {} in zombie state, treating as not running", pid); + return false; + } + if (OutEc) + { + return false; + } + return true; + } + int Error = errno; + if (Error == ESRCH) // No such process + { + return false; + } + else if (Error == ENOENT) + { + return false; + } + else if (Error == EPERM) + { + return true; // Running under a user we don't have access to, assume it is live + } + else + { + OutEc = MakeErrorCode(Error); + return false; + } #endif } +bool +IsProcessRunning(int pid) +{ + std::error_code Ec; + bool IsRunning = IsProcessRunning(pid, Ec); + if (Ec) + { + ThrowSystemError(Ec.value(), fmt::format("Failed determining if process with pid {} is running", pid)); + } + return IsRunning; +} + int GetCurrentProcessId() { @@ -690,6 +899,189 @@ GetProcessId(CreateProcResult ProcId) #endif } +std::filesystem::path +GetProcessExecutablePath(int Pid, std::error_code& OutEc) +{ +#if ZEN_PLATFORM_WINDOWS + HANDLE ModuleSnapshotHandle = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, (DWORD)Pid); + if (ModuleSnapshotHandle != INVALID_HANDLE_VALUE) + { + auto __ = MakeGuard([&]() { CloseHandle(ModuleSnapshotHandle); }); + MODULEENTRY32 ModuleEntry; + ModuleEntry.dwSize = sizeof(MODULEENTRY32); + if (Module32First(ModuleSnapshotHandle, (LPMODULEENTRY32)&ModuleEntry)) + { + std::filesystem::path ProcessExecutablePath(ModuleEntry.szExePath); + return ProcessExecutablePath; + } + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + OutEc = MakeErrorCodeFromLastError(); + return {}; +#endif // ZEN_PLATFORM_WINDOWS +#if ZEN_PLATFORM_MAC + char Buffer[PROC_PIDPATHINFO_MAXSIZE]; + int Res = proc_pidpath(Pid, Buffer, sizeof(Buffer)); + if (Res > 0) + { + std::filesystem::path ProcessExecutablePath(Buffer); + return ProcessExecutablePath; + } + OutEc = MakeErrorCodeFromLastError(); + return {}; +#endif // ZEN_PLATFORM_MAC +#if ZEN_PLATFORM_LINUX + std::filesystem::path EntryPath = std::filesystem::path("/proc") / fmt::format("{}", Pid); + std::filesystem::path ExeLinkPath = EntryPath / "exe"; + char Link[4096]; + ssize_t BytesRead = readlink(ExeLinkPath.c_str(), Link, sizeof(Link) - 1); + if (BytesRead > 0) + { + Link[BytesRead] = '\0'; + std::filesystem::path ProcessExecutablePath(Link); + return ProcessExecutablePath; + } + OutEc = MakeErrorCodeFromLastError(); + return {}; +#endif // ZEN_PLATFORM_LINUX +} + +std::error_code +FindProcess(const std::filesystem::path& ExecutableImage, ProcessHandle& OutHandle, bool IncludeSelf) +{ +#if ZEN_PLATFORM_WINDOWS + HANDLE ProcessSnapshotHandle = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); + if (ProcessSnapshotHandle == INVALID_HANDLE_VALUE) + { + return MakeErrorCodeFromLastError(); + } + auto _ = MakeGuard([&]() { CloseHandle(ProcessSnapshotHandle); }); + + const DWORD ThisProcessId = ::GetCurrentProcessId(); + + PROCESSENTRY32 Entry; + Entry.dwSize = sizeof(PROCESSENTRY32); + if (Process32First(ProcessSnapshotHandle, (LPPROCESSENTRY32)&Entry)) + { + do + { + if ((IncludeSelf || (Entry.th32ProcessID != ThisProcessId)) && (ExecutableImage.filename() == Entry.szExeFile)) + { + std::error_code Ec; + std::filesystem::path EntryPath = GetProcessExecutablePath(Entry.th32ProcessID, Ec); + if (!Ec) + { + if (EntryPath == ExecutableImage) + { + HANDLE Handle = + OpenProcess(PROCESS_TERMINATE | SYNCHRONIZE | PROCESS_QUERY_INFORMATION, FALSE, Entry.th32ProcessID); + if (Handle == NULL) + { + return MakeErrorCodeFromLastError(); + } + DWORD ExitCode = 0; + GetExitCodeProcess(Handle, &ExitCode); + if (ExitCode == STILL_ACTIVE) + { + OutHandle.Initialize((void*)Handle); + return {}; + } + } + } + } + } while (::Process32Next(ProcessSnapshotHandle, (LPPROCESSENTRY32)&Entry)); + return {}; + } + return MakeErrorCodeFromLastError(); +#endif // ZEN_PLATFORM_WINDOWS +#if ZEN_PLATFORM_MAC + int Mib[4] = {CTL_KERN, KERN_PROC, KERN_PROC_ALL, 0}; + size_t BufferSize = 0; + + struct kinfo_proc* Processes = nullptr; + uint32_t ProcCount = 0; + + const pid_t ThisProcessId = getpid(); + + if (sysctl(Mib, 4, NULL, &BufferSize, NULL, 0) != -1 && BufferSize > 0) + { + struct kinfo_proc* Processes = (struct kinfo_proc*)malloc(BufferSize); + auto _ = MakeGuard([&]() { free(Processes); }); + if (sysctl(Mib, 4, Processes, &BufferSize, NULL, 0) != -1) + { + ProcCount = (uint32_t)(BufferSize / sizeof(struct kinfo_proc)); + char Buffer[PROC_PIDPATHINFO_MAXSIZE]; + for (uint32_t ProcIndex = 0; ProcIndex < ProcCount; ProcIndex++) + { + pid_t Pid = Processes[ProcIndex].kp_proc.p_pid; + if (IncludeSelf || (Pid != ThisProcessId)) + { + std::error_code Ec; + std::filesystem::path EntryPath = GetProcessExecutablePath(Pid, Ec); + if (!Ec) + { + if (EntryPath == ExecutableImage) + { + if (Processes[ProcIndex].kp_proc.p_stat != SZOMB) + { + OutHandle.Initialize(Pid, Ec); + return Ec; + } + } + } + Ec.clear(); + } + } + return {}; + } + } + return MakeErrorCodeFromLastError(); +#endif // ZEN_PLATFORM_MAC +#if ZEN_PLATFORM_LINUX + const pid_t ThisProcessId = getpid(); + + std::vector<uint32_t> RunningPids; + DirectoryContent ProcList; + GetDirectoryContent("/proc", DirectoryContentFlags::IncludeDirs, ProcList); + for (const std::filesystem::path& EntryPath : ProcList.Directories) + { + std::string EntryName = EntryPath.stem(); + std::optional<uint32_t> PidMaybe = ParseInt<uint32_t>(EntryName); + if (PidMaybe.has_value()) + { + if (pid_t Pid = PidMaybe.value(); IncludeSelf || (Pid != ThisProcessId)) + { + RunningPids.push_back(Pid); + } + } + } + + for (uint32_t Pid : RunningPids) + { + std::error_code Ec; + std::filesystem::path EntryPath = GetProcessExecutablePath((int)Pid, Ec); + if (!Ec) + { + if (EntryPath == ExecutableImage) + { + char Status = GetPidStatus(Pid, Ec); + if (!Ec) + { + if (Status && (Status != 'Z')) + { + OutHandle.Initialize(Pid, Ec); + return Ec; + } + } + } + } + Ec.clear(); + } + return {}; +#endif // ZEN_PLATFORM_LINUX +} + #if ZEN_WITH_TESTS void @@ -706,6 +1098,28 @@ TEST_CASE("Process") CHECK(IsProcessRunning(Pid)); } +TEST_CASE("FindProcess") +{ + { + ProcessHandle Process; + std::error_code Ec = FindProcess(GetRunningExecutablePath(), Process, /*IncludeSelf*/ true); + CHECK(!Ec); + CHECK(Process.IsValid()); + } + { + ProcessHandle Process; + std::error_code Ec = FindProcess(GetRunningExecutablePath(), Process, /*IncludeSelf*/ false); + CHECK(!Ec); + CHECK(!Process.IsValid()); + } + { + ProcessHandle Process; + std::error_code Ec = FindProcess("this/does\\not/exist\\123914921929412312312312asdad\\12134.no", Process, /*IncludeSelf*/ false); + CHECK(!Ec); + CHECK(!Process.IsValid()); + } +} + TEST_CASE("BuildArgV") { const char* Words[] = {"one", "two", "three", "four", "five"}; diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp new file mode 100644 index 000000000..00e67dc85 --- /dev/null +++ b/src/zencore/sentryintegration.cpp @@ -0,0 +1,343 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/sentryintegration.h> + +#include <zencore/config.h> +#include <zencore/logging.h> +#include <zencore/session.h> +#include <zencore/uid.h> + +#include <stdarg.h> +#include <stdio.h> + +#if ZEN_PLATFORM_LINUX +# include <pwd.h> +#endif + +#if ZEN_PLATFORM_MAC +# include <pwd.h> +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <spdlog/spdlog.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_USE_SENTRY +# define SENTRY_BUILD_STATIC 1 +ZEN_THIRD_PARTY_INCLUDES_START +# include <sentry.h> +# include <spdlog/sinks/base_sink.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace sentry { + +namespace { + static const std::string DefaultDsn("https://[email protected]/5919284"); +} + +struct SentryAssertImpl : zen::AssertImpl +{ + virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION OnAssert(const char* Filename, + int LineNumber, + const char* FunctionName, + const char* Msg, + const zen::CallstackFrames* Callstack) override; +}; + +class sentry_sink final : public spdlog::sinks::base_sink<spdlog::details::null_mutex> +{ +public: + sentry_sink(); + ~sentry_sink(); + +protected: + void sink_it_(const spdlog::details::log_msg& msg) override; + void flush_() override; +}; + +////////////////////////////////////////////////////////////////////////// + +static constexpr sentry_level_t MapToSentryLevel[spdlog::level::level_enum::n_levels] = {SENTRY_LEVEL_DEBUG, + SENTRY_LEVEL_DEBUG, + SENTRY_LEVEL_INFO, + SENTRY_LEVEL_WARNING, + SENTRY_LEVEL_ERROR, + SENTRY_LEVEL_FATAL, + SENTRY_LEVEL_DEBUG}; + +sentry_sink::sentry_sink() +{ +} +sentry_sink::~sentry_sink() +{ +} + +void +sentry_sink::sink_it_(const spdlog::details::log_msg& msg) +{ + if (msg.level != spdlog::level::err && msg.level != spdlog::level::critical) + { + return; + } + try + { + std::string Message = fmt::format("{}\n{}({}) [{}]", msg.payload, msg.source.filename, msg.source.line, msg.source.funcname); + sentry_value_t event = sentry_value_new_message_event( + /* level */ MapToSentryLevel[msg.level], + /* logger */ nullptr, + /* message */ Message.c_str()); + sentry_event_value_add_stacktrace(event, NULL, 0); + sentry_capture_event(event); + } + catch (const std::exception&) + { + // If our logging with Message formatting fails we do a non-allocating version and just post the msg.payload raw + char TmpBuffer[256]; + size_t MaxCopy = zen::Min<size_t>(msg.payload.size(), size_t(255)); + memcpy(TmpBuffer, msg.payload.data(), MaxCopy); + TmpBuffer[MaxCopy] = '\0'; + sentry_value_t event = sentry_value_new_message_event( + /* level */ SENTRY_LEVEL_ERROR, + /* logger */ nullptr, + /* message */ TmpBuffer); + sentry_event_value_add_stacktrace(event, NULL, 0); + sentry_capture_event(event); + } +} +void +sentry_sink::flush_() +{ +} + +void +SentryAssertImpl::OnAssert(const char* Filename, + int LineNumber, + const char* FunctionName, + const char* Msg, + const zen::CallstackFrames* Callstack) +{ + // Sentry will provide its own callstack + ZEN_UNUSED(Callstack); + try + { + std::string Message = fmt::format("ASSERT {}:({}) [{}]\n\"{}\"", Filename, LineNumber, FunctionName, Msg); + sentry_value_t event = sentry_value_new_message_event( + /* level */ SENTRY_LEVEL_ERROR, + /* logger */ nullptr, + /* message */ Message.c_str()); + sentry_event_value_add_stacktrace(event, NULL, 0); + sentry_capture_event(event); + } + catch (const std::exception&) + { + // If our logging with Message formatting fails we do a non-allocating version and just post the Msg raw + sentry_value_t event = sentry_value_new_message_event( + /* level */ SENTRY_LEVEL_ERROR, + /* logger */ nullptr, + /* message */ Msg); + sentry_event_value_add_stacktrace(event, NULL, 0); + sentry_capture_event(event); + } +} + +} // namespace sentry + +namespace zen { + +# if ZEN_USE_SENTRY +static void +SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[maybe_unused]] void* Userdata) +{ + char LogMessageBuffer[160]; + std::string LogMessage; + const char* MessagePtr = LogMessageBuffer; + + int n = vsnprintf(LogMessageBuffer, sizeof LogMessageBuffer, Message, Args); + + if (n >= int(sizeof LogMessageBuffer)) + { + LogMessage.resize(n + 1); + + n = vsnprintf(LogMessage.data(), LogMessage.size(), Message, Args); + + MessagePtr = LogMessage.c_str(); + } + + switch (Level) + { + case SENTRY_LEVEL_DEBUG: + ZEN_CONSOLE_DEBUG("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_INFO: + ZEN_CONSOLE_INFO("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_WARNING: + ZEN_CONSOLE_WARN("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_ERROR: + ZEN_CONSOLE_ERROR("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_FATAL: + ZEN_CONSOLE_CRITICAL("sentry: {}", MessagePtr); + break; + } +} +# endif + +SentryIntegration::SentryIntegration() +{ +} + +SentryIntegration::~SentryIntegration() +{ + if (m_IsInitialized && m_SentryErrorCode == 0) + { + logging::SetErrorLog(""); + m_SentryAssert.reset(); + sentry_close(); + } +} + +void +SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine) +{ + m_AllowPII = Conf.AllowPII; + + std::string SentryDatabasePath = Conf.DatabasePath; + if (SentryDatabasePath.starts_with("\\\\?\\")) + { + SentryDatabasePath = SentryDatabasePath.substr(4); + } + sentry_options_t* SentryOptions = sentry_options_new(); + + sentry_options_set_dsn(SentryOptions, Conf.Dsn.empty() ? sentry::DefaultDsn.c_str() : Conf.Dsn.c_str()); + sentry_options_set_database_path(SentryOptions, SentryDatabasePath.c_str()); + sentry_options_set_logger(SentryOptions, SentryLogFunction, this); + sentry_options_set_environment(SentryOptions, Conf.Environment.empty() ? "production" : Conf.Environment.c_str()); + + std::string SentryAttachmentsPath = Conf.AttachmentsPath; + if (!SentryAttachmentsPath.empty()) + { + if (SentryAttachmentsPath.starts_with("\\\\?\\")) + { + SentryAttachmentsPath = SentryAttachmentsPath.substr(4); + } + sentry_options_add_attachment(SentryOptions, SentryAttachmentsPath.c_str()); + } + sentry_options_set_release(SentryOptions, ZEN_CFG_VERSION); + + if (Conf.Debug) + { + sentry_options_set_debug(SentryOptions, 1); + } + + m_SentryErrorCode = sentry_init(SentryOptions); + + if (m_SentryErrorCode == 0) + { + sentry_value_t SentryUserObject = sentry_value_new_object(); + + if (m_AllowPII) + { +# if ZEN_PLATFORM_WINDOWS + CHAR Buffer[511 + 1]; + DWORD BufferLength = sizeof(Buffer) / sizeof(CHAR); + BOOL OK = GetUserNameA(Buffer, &BufferLength); + if (OK && BufferLength) + { + m_SentryUserName = std::string(Buffer, BufferLength - 1); + } + BufferLength = sizeof(Buffer) / sizeof(CHAR); + OK = GetComputerNameA(Buffer, &BufferLength); + if (OK && BufferLength) + { + m_SentryHostName = std::string(Buffer, BufferLength); + } + else + { + m_SentryHostName = "unknown"; + } +# endif // ZEN_PLATFORM_WINDOWS + +# if (ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC) + uid_t uid = geteuid(); + struct passwd* pw = getpwuid(uid); + if (pw) + { + m_SentryUserName = std::string(pw->pw_name); + } + else + { + m_SentryUserName = "unknown"; + } + char HostNameBuffer[1023 + 1]; + int err = gethostname(HostNameBuffer, sizeof(HostNameBuffer)); + if (err == 0) + { + m_SentryHostName = std::string(HostNameBuffer); + } + else + { + m_SentryHostName = "unknown"; + } +# endif + m_SentryId = fmt::format("{}@{}", m_SentryUserName, m_SentryHostName); + sentry_value_set_by_key(SentryUserObject, "id", sentry_value_new_string(m_SentryId.c_str())); + sentry_value_set_by_key(SentryUserObject, "username", sentry_value_new_string(m_SentryUserName.c_str())); + sentry_value_set_by_key(SentryUserObject, "ip_address", sentry_value_new_string("{{auto}}")); + } + + sentry_value_set_by_key(SentryUserObject, "cmd", sentry_value_new_string(CommandLine.c_str())); + + const std::string SessionId(GetSessionIdString()); + sentry_value_set_by_key(SentryUserObject, "session", sentry_value_new_string(SessionId.c_str())); + + sentry_set_user(SentryUserObject); + + m_SentryLogger = spdlog::create<sentry::sentry_sink>("sentry"); + logging::SetErrorLog("sentry"); + + m_SentryAssert = std::make_unique<sentry::SentryAssertImpl>(); + } + + m_IsInitialized = true; +} + +void +SentryIntegration::LogStartupInformation() +{ + if (m_IsInitialized) + { + if (m_SentryErrorCode == 0) + { + if (m_AllowPII) + { + ZEN_INFO("sentry initialized, username: '{}', hostname: '{}', id: '{}'", m_SentryUserName, m_SentryHostName, m_SentryId); + } + else + { + ZEN_INFO("sentry initialized with anonymous reports"); + } + } + else + { + ZEN_WARN( + "sentry_init returned failure! (error code: {}) note that sentry expects crashpad_handler to exist alongside the running " + "executable", + m_SentryErrorCode); + } + } +} + +void +SentryIntegration::ClearCaches() +{ + sentry_clear_modulecache(); +} + +} // namespace zen +#endif diff --git a/src/zencore/sharedbuffer.cpp b/src/zencore/sharedbuffer.cpp index 993ca40e6..78efb9d42 100644 --- a/src/zencore/sharedbuffer.cpp +++ b/src/zencore/sharedbuffer.cpp @@ -2,6 +2,7 @@ #include <zencore/except.h> #include <zencore/fmtutils.h> +#include <zencore/memory/memory.h> #include <zencore/sharedbuffer.h> #include <zencore/testing.h> diff --git a/src/zencore/stats.cpp b/src/zencore/stats.cpp index 7c1a9e086..8a424c5ad 100644 --- a/src/zencore/stats.cpp +++ b/src/zencore/stats.cpp @@ -3,9 +3,11 @@ #include "zencore/stats.h" #include <zencore/compactbinarybuilder.h> -#include "zencore/intmath.h" -#include "zencore/thread.h" -#include "zencore/timer.h" +#include <zencore/intmath.h> +#include <zencore/memory/llm.h> +#include <zencore/memory/tagtrace.h> +#include <zencore/thread.h> +#include <zencore/timer.h> #include <cmath> #include <gsl/gsl-lite.hpp> @@ -222,8 +224,10 @@ thread_local xoshiro256 ThreadLocalRng; ////////////////////////////////////////////////////////////////////////// -UniformSample::UniformSample(uint32_t ReservoirSize) : m_Values(ReservoirSize) +UniformSample::UniformSample(uint32_t ReservoirSize) { + ZEN_MEMSCOPE(ELLMTag::Metrics); + m_Values = std::vector<std::atomic<int64_t>>(ReservoirSize); } UniformSample::~UniformSample() @@ -273,6 +277,8 @@ UniformSample::Update(int64_t Value) SampleSnapshot UniformSample::Snapshot() const { + ZEN_MEMSCOPE(ELLMTag::Metrics); + uint64_t ValuesSize = Size(); std::vector<double> Values(ValuesSize); diff --git a/src/zencore/stream.cpp b/src/zencore/stream.cpp index ee97a53c4..a800ce121 100644 --- a/src/zencore/stream.cpp +++ b/src/zencore/stream.cpp @@ -1,7 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include <stdarg.h> -#include <zencore/memory.h> +#include <zencore/memoryview.h> #include <zencore/stream.h> #include <zencore/testing.h> @@ -25,8 +25,9 @@ BinaryWriter::Write(std::initializer_list<const MemoryView> Buffers) } for (const MemoryView& View : Buffers) { - memcpy(m_Buffer.data() + m_Offset, View.GetData(), View.GetSize()); - m_Offset += View.GetSize(); + size_t Size = View.GetSize(); + memcpy(m_Buffer.data() + m_Offset, View.GetData(), Size); + m_Offset += Size; } } diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp index ad6ee78fc..c8c7c2cde 100644 --- a/src/zencore/string.cpp +++ b/src/zencore/string.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <zencore/memory.h> +#include <zencore/memory/memory.h> +#include <zencore/memoryview.h> #include <zencore/string.h> #include <zencore/testing.h> @@ -98,6 +99,20 @@ FilepathFindExtension(const std::string_view& Path, const char* ExtensionToMatch ////////////////////////////////////////////////////////////////////////// +bool +IsValidUtf8(const std::string_view& str) +{ + return utf8::is_valid(begin(str), end(str)); +} + +std::string_view::const_iterator +FindFirstInvalidUtf8Byte(const std::string_view& str) +{ + return utf8::find_invalid(begin(str), end(str)); +} + +////////////////////////////////////////////////////////////////////////// + void Utf8ToWide(const char8_t* Str8, WideStringBuilderBase& OutString) { @@ -468,12 +483,67 @@ template class StringBuilderImpl<char>; template class StringBuilderImpl<wchar_t>; ////////////////////////////////////////////////////////////////////////// + +void +UrlDecode(std::string_view InUrl, StringBuilderBase& OutUrl) +{ + std::string_view::size_type i = 0; + + for (; i != InUrl.size();) + { + char c = InUrl[i]; + + if ((c == '%') && ((i + 2) < InUrl.size())) + { + char hex[2] = {InUrl[i + 1], InUrl[i + 2]}; + uint8_t HexedChar; + if (ParseHexBytes(hex, 2, &HexedChar)) + { + OutUrl.Append(HexedChar); + i += 3; + + continue; + } + } + + OutUrl.Append(c); + ++i; + } +} + +std::string +UrlDecode(std::string_view InUrl) +{ + ExtendableStringBuilder<128> Url; + UrlDecode(InUrl, Url); + + return std::string(Url.ToView()); +} + +////////////////////////////////////////////////////////////////////////// // // Unit tests // #if ZEN_WITH_TESTS +TEST_CASE("url") +{ + using namespace std::literals; + + ExtendableStringBuilder<32> OutUrl; + UrlDecode("http://blah.com/foo?bar=hi%20ho", OutUrl); + CHECK_EQ(OutUrl.ToView(), "http://blah.com/foo?bar=hi ho"sv); + + OutUrl.Reset(); + + UrlDecode("http://blah.com/foo?bar=hi%ho", OutUrl); + CHECK_EQ(OutUrl.ToView(), "http://blah.com/foo?bar=hi%ho"sv); + + CHECK_EQ(UrlDecode("http://blah.com/foo?bar=hi%20ho"), "http://blah.com/foo?bar=hi ho"sv); + CHECK_EQ(UrlDecode("http://blah.com/foo?bar=hi%ho"), "http://blah.com/foo?bar=hi%ho"sv); +} + TEST_CASE("niceNum") { char Buffer[16]; diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp index f51273e0d..f37bdf423 100644 --- a/src/zencore/system.cpp +++ b/src/zencore/system.cpp @@ -4,6 +4,7 @@ #include <zencore/compactbinarybuilder.h> #include <zencore/except.h> +#include <zencore/memory/memory.h> #include <zencore/string.h> #if ZEN_PLATFORM_WINDOWS diff --git a/src/zencore/testutils.cpp b/src/zencore/testutils.cpp index d4c8aeaef..9f50de032 100644 --- a/src/zencore/testutils.cpp +++ b/src/zencore/testutils.cpp @@ -4,6 +4,7 @@ #if ZEN_WITH_TESTS +# include <zencore/filesystem.h> # include <zencore/session.h> # include "zencore/string.h" @@ -19,8 +20,8 @@ CreateTemporaryDirectory() std::error_code Ec; std::filesystem::path DirPath = std::filesystem::temp_directory_path() / GetSessionIdString() / IntNum(++Sequence).c_str(); - std::filesystem::remove_all(DirPath, Ec); - std::filesystem::create_directories(DirPath); + DeleteDirectories(DirPath, Ec); + CreateDirectories(DirPath); return DirPath; } @@ -32,14 +33,14 @@ ScopedTemporaryDirectory::ScopedTemporaryDirectory() : m_RootPath(CreateTemporar ScopedTemporaryDirectory::ScopedTemporaryDirectory(std::filesystem::path Directory) : m_RootPath(Directory) { std::error_code Ec; - std::filesystem::remove_all(Directory, Ec); - std::filesystem::create_directories(Directory); + DeleteDirectories(Directory, Ec); + CreateDirectories(Directory); } ScopedTemporaryDirectory::~ScopedTemporaryDirectory() { std::error_code Ec; - std::filesystem::remove_all(m_RootPath, Ec); + DeleteDirectories(m_RootPath, Ec); } IoBuffer @@ -71,6 +72,26 @@ CreateRandomBlob(uint64_t Size) return Data; }; +IoBuffer +CreateSemiRandomBlob(uint64_t Size) +{ + IoBuffer Result(Size); + const size_t PartCount = (Size / (1u * 1024u * 64)) + 1; + const size_t PartSize = Size / PartCount; + auto Part = CreateRandomBlob(PartSize); + auto Remain = Result.GetMutableView().CopyFrom(Part.GetView()); + while (Remain.GetSize() >= PartSize) + { + Remain = Remain.CopyFrom(Part.GetView()); + } + if (Remain.GetSize() > 0) + { + auto RemainBuffer = CreateRandomBlob(Remain.GetSize()); + Remain.CopyFrom(RemainBuffer.GetView()); + } + return Result; +}; + } // namespace zen #endif // ZEN_WITH_TESTS diff --git a/src/zencore/thread.cpp b/src/zencore/thread.cpp index cb3aced33..abf282467 100644 --- a/src/zencore/thread.cpp +++ b/src/zencore/thread.cpp @@ -80,8 +80,12 @@ SetNameInternal(DWORD thread_id, const char* name) void SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName) { - std::string ThreadNameZ{ThreadName}; - const int ThreadId = GetCurrentThreadId(); + constexpr std::string_view::size_type MaxThreadNameLength = 255; + std::string_view LimitedThreadName = ThreadName.substr(0, MaxThreadNameLength); + StringBuilder<MaxThreadNameLength + 1> ThreadNameZ; + ThreadNameZ << LimitedThreadName; + const int ThreadId = GetCurrentThreadId(); + #if ZEN_WITH_TRACE trace::ThreadRegister(ThreadNameZ.c_str(), /* system id */ ThreadId, /* sort id */ 0); #endif // ZEN_WITH_TRACE @@ -93,7 +97,10 @@ SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName) if (SetThreadDescriptionFunc) { - SetThreadDescriptionFunc(::GetCurrentThread(), Utf8ToWide(ThreadName).c_str()); + WideStringBuilder<MaxThreadNameLength + 1> ThreadNameW; + Utf8ToWide(LimitedThreadName, ThreadNameW); + + SetThreadDescriptionFunc(::GetCurrentThread(), ThreadNameW.c_str()); } // The debugger needs to be around to catch the name in the exception. If @@ -153,10 +160,10 @@ Event::Event() m_EventHandle = CreateEvent(nullptr, bManualReset, bInitialState, nullptr); #else ZEN_UNUSED(bManualReset); - auto* Inner = new EventInner(); + auto* Inner = new EventInner(); + std::unique_lock Lock(Inner->Mutex); Inner->bSet = bInitialState; m_EventHandle = Inner; - std::atomic_thread_fence(std::memory_order_release); #endif } @@ -208,9 +215,8 @@ Event::Close() { std::unique_lock Lock(Inner->Mutex); Inner->bSet.store(true); + m_EventHandle = nullptr; } - m_EventHandle = nullptr; - std::atomic_thread_fence(std::memory_order_release); delete Inner; #endif } @@ -326,6 +332,7 @@ NamedEvent::NamedEvent(std::string_view EventName) Packed |= intptr_t(Fd) & 0xffff'ffff; m_EventHandle = (void*)Packed; #endif + ZEN_ASSERT(m_EventHandle != nullptr); } NamedEvent::~NamedEvent() @@ -348,8 +355,16 @@ NamedEvent::Close() if (flock(Fd, LOCK_EX | LOCK_NB) == 0) { - std::filesystem::path Name = PathFromHandle((void*)(intptr_t(Fd))); - unlink(Name.c_str()); + std::error_code Ec; + std::filesystem::path Name = PathFromHandle((void*)(intptr_t(Fd)), Ec); + if (Ec) + { + ZEN_WARN("Error reported on get file path from handle {} for named event unlink operation, reason '{}'", Fd, Ec.message()); + } + else + { + unlink(Name.c_str()); + } flock(Fd, LOCK_UN | LOCK_NB); close(Fd); @@ -362,20 +377,29 @@ NamedEvent::Close() m_EventHandle = nullptr; } -void +std::error_code NamedEvent::Set() { + ZEN_ASSERT(m_EventHandle != nullptr); #if ZEN_PLATFORM_WINDOWS - SetEvent(m_EventHandle); + if (!SetEvent(m_EventHandle)) + { + return MakeErrorCodeFromLastError(); + } #elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC int Sem = int(intptr_t(m_EventHandle) >> 32); - semctl(Sem, 0, SETVAL, 0); + if (semctl(Sem, 0, SETVAL, 0) == -1) + { + return MakeErrorCodeFromLastError(); + } #endif + return {}; } bool NamedEvent::Wait(int TimeoutMs) { + ZEN_ASSERT(m_EventHandle != nullptr); #if ZEN_PLATFORM_WINDOWS const DWORD Timeout = (TimeoutMs < 0) ? INFINITE : TimeoutMs; @@ -461,7 +485,7 @@ NamedMutex::Create(std::string_view MutexName) ExtendableStringBuilder<64> Name; Name << "/tmp/" << MutexName; - int Inner = open(Name.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666); + int Inner = open(Name.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, geteuid() == 0 ? 0766 : 0666); if (Inner < 0) { return false; @@ -584,12 +608,12 @@ TEST_CASE("NamedEvent") ReadyEvent.Set(); NamedEvent TestEvent(Name); - TestEvent.Wait(1000); + TestEvent.Wait(100); }); ReadyEvent.Wait(); - zen::Sleep(100); + zen::Sleep(10); TestEvent.Set(); Waiter.join(); @@ -597,7 +621,7 @@ TEST_CASE("NamedEvent") // Manual reset property for (uint32_t i = 0; i < 8; ++i) { - bool bEventSet = TestEvent.Wait(100); + bool bEventSet = TestEvent.Wait(10); CHECK(bEventSet); } } diff --git a/src/zencore/timer.cpp b/src/zencore/timer.cpp index 1655e912d..95536cb26 100644 --- a/src/zencore/timer.cpp +++ b/src/zencore/timer.cpp @@ -12,9 +12,20 @@ # include <unistd.h> #endif +#define GTSPS_IMPLEMENTATION +#include "GetTimeSinceProcessStart.h" + namespace zen { uint64_t +GetTimeSinceProcessStart() +{ + double TimeInSeconds = ::GetTimeSinceProcessStart(); + + return uint64_t(TimeInSeconds * 1000); +} + +uint64_t GetHifreqTimerValue() { uint64_t Timestamp; diff --git a/src/zencore/trace.cpp b/src/zencore/trace.cpp index f7e4c4b68..fe8fb9a5d 100644 --- a/src/zencore/trace.cpp +++ b/src/zencore/trace.cpp @@ -4,10 +4,92 @@ # include <zencore/config.h> # include <zencore/zencore.h> +# include <zencore/commandline.h> +# include <zencore/string.h> +# include <zencore/logging.h> # define TRACE_IMPLEMENT 1 # include <zencore/trace.h> +# include <zencore/memory/fmalloc.h> +# include <zencore/memory/memorytrace.h> + +namespace zen { + +void +TraceConfigure(const TraceOptions& Options) +{ + // Configure channels based on command line options + + using namespace std::literals; + + std::function<void(const std::string_view&)> ProcessChannelList; + + auto ProcessTraceArg = [&](const std::string_view& Arg) { + if (Arg == "default"sv) + { + ProcessChannelList("cpu,log"sv); + } + else if (Arg == "memory"sv) + { + ProcessChannelList("memtag,memalloc,callstack,module"sv); + } + else if (Arg == "memory_light"sv) + { + ProcessChannelList("memtag,memalloc"sv); + } + else if (Arg == "memtag"sv) + { + // memtag actually traces to the memalloc channel + ProcessChannelList("memalloc"sv); + } + else + { + // Presume that the argument is a trace channel name + + StringBuilder<128> AnsiChannel; + AnsiChannel << Arg; + + const bool IsEnabled = trace::ToggleChannel(AnsiChannel.c_str(), true); + + if (IsEnabled == false) + { + // Logging here could be iffy, but we might want some other feedback mechanism here + // to indicate to users that they're not getting what they might be expecting + } + } + }; + + ProcessChannelList = [&](const std::string_view& OptionArgs) { IterateCommaSeparatedValue(OptionArgs, ProcessTraceArg); }; + + if (Options.Channels.empty()) + { + ProcessTraceArg("default"sv); + } + else + { + ProcessChannelList(Options.Channels); + } + + if (Options.Host.size()) + { + trace::SendTo(Options.Host.c_str()); + } + else if (Options.File.size()) + { + trace::WriteTo(Options.File.c_str()); + } + +# if ZEN_WITH_MEMTRACK + FMalloc* TraceMalloc = MemoryTrace_Create(GMalloc, Options); + if (TraceMalloc != GMalloc) + { + GMalloc = TraceMalloc; + MemoryTrace_Initialize(); + } +# endif +} + void TraceInit(std::string_view ProgramName) { @@ -38,6 +120,14 @@ TraceInit(std::string_view ProgramName) # endif CommandLineString, ZEN_CFG_VERSION_BUILD_STRING); + + atexit([] { +# if ZEN_WITH_MEMTRACK + zen::MemoryTrace_Shutdown(); +# endif + trace::Update(); + TraceShutdown(); + }); } void @@ -53,30 +143,9 @@ IsTracing() return trace::IsTracing(); } -void -TraceStart(std::string_view ProgramName, const char* HostOrPath, TraceType Type) -{ - TraceInit(ProgramName); - switch (Type) - { - case TraceType::Network: - trace::SendTo(HostOrPath); - break; - - case TraceType::File: - trace::WriteTo(HostOrPath); - break; - - case TraceType::None: - break; - } - trace::ToggleChannel("cpu", true); -} - bool TraceStop() { - trace::ToggleChannel("cpu", false); if (trace::Stop()) { return true; @@ -84,4 +153,57 @@ TraceStop() return false; } +bool +GetTraceOptionsFromCommandline(TraceOptions& OutOptions) +{ + bool HasOptions = false; + +# if ZEN_WITH_TRACE + using namespace std::literals; + + auto MatchesArg = [](std::string_view Option, std::string_view Arg) -> std::optional<std::string_view> { + if (Arg.starts_with(Option)) + { + std::string_view::value_type DelimChar = Arg[Option.length()]; + if (DelimChar == ' ' || DelimChar == '=') + { + return Arg.substr(Option.size() + 1); + } + } + return {}; + }; + + constexpr std::string_view TraceOption = "--trace"sv; + constexpr std::string_view TraceHostOption = "--tracehost"sv; + constexpr std::string_view TraceFileOption = "--tracefile"sv; + + std::function<void(const std::string_view&)> ProcessArg = [&](const std::string_view& Arg) { + if (auto Host = MatchesArg(TraceHostOption, Arg); Host.has_value()) + { + OutOptions.Host = Host.value(); + HasOptions = true; + } + else if (auto File = MatchesArg(TraceFileOption, Arg); File.has_value()) + { + OutOptions.File = File.value(); + HasOptions = true; + } + else if (auto Channels = MatchesArg(TraceOption, Arg); Channels.has_value()) + { + if (!OutOptions.Channels.empty()) + { + OutOptions.Channels = ","sv; + } + OutOptions.Channels += Channels.value(); + HasOptions = true; + } + }; + + IterateCommandlineArgs(ProcessArg); +# endif // ZEN_WITH_TRACE + return HasOptions; +} + +} // namespace zen + #endif // ZEN_WITH_TRACE diff --git a/src/zencore/uid.cpp b/src/zencore/uid.cpp index 0f04d70ac..d7636f2ad 100644 --- a/src/zencore/uid.cpp +++ b/src/zencore/uid.cpp @@ -83,11 +83,41 @@ Oid::FromHexString(const std::string_view String) } Oid -Oid::FromMemory(const void* Ptr) +Oid::TryFromHexString(const std::string_view String, const Oid& Default) { + if (String.length() != StringLength) + { + return Default; + } + Oid Id; - memcpy(Id.OidBits, Ptr, sizeof Id); - return Id; + + if (ParseHexBytes(String.data(), String.size(), reinterpret_cast<uint8_t*>(Id.OidBits))) + { + return Id; + } + else + { + return Default; + } +} + +bool +Oid::TryParse(std::string_view Str, Oid& Id) +{ + using namespace std::literals; + + if (Str.size() == Oid::StringLength) + { + return ParseHexBytes(Str.data(), Str.size(), reinterpret_cast<uint8_t*>(Id.OidBits)); + } + + if (Str.starts_with("0x"sv)) + { + return TryParse(Str.substr(2), Id); + } + + return false; } void @@ -97,6 +127,14 @@ Oid::ToString(char OutString[StringLength]) const OutString[StringLength] = '\0'; } +std::string +Oid::ToString() const +{ + char OutString[StringLength + 1]; + ToString(OutString); + return std::string(OutString, StringLength); +} + StringBuilderBase& Oid::ToString(StringBuilderBase& OutString) const { @@ -108,6 +146,14 @@ Oid::ToString(StringBuilderBase& OutString) const return OutString; } +Oid +Oid::FromMemory(const void* Ptr) +{ + Oid Id; + memcpy(Id.OidBits, Ptr, sizeof Id); + return Id; +} + #if ZEN_WITH_TESTS TEST_CASE("Oid") diff --git a/src/zencore/windows.cpp b/src/zencore/windows.cpp index 76d8ab445..d02fcd35e 100644 --- a/src/zencore/windows.cpp +++ b/src/zencore/windows.cpp @@ -9,6 +9,19 @@ namespace zen::windows { +bool +IsRunningOnWine() +{ + HMODULE NtDll = GetModuleHandleA("ntdll.dll"); + + if (NtDll) + { + return !!GetProcAddress(NtDll, "wine_get_version"); + } + + return false; +} + FileMapping::FileMapping(_In_ FileMapping& orig) { m_pData = NULL; diff --git a/src/zencore/workthreadpool.cpp b/src/zencore/workthreadpool.cpp index 16b2310ff..e241c0de8 100644 --- a/src/zencore/workthreadpool.cpp +++ b/src/zencore/workthreadpool.cpp @@ -3,7 +3,9 @@ #include <zencore/workthreadpool.h> #include <zencore/blockingqueue.h> +#include <zencore/except.h> #include <zencore/logging.h> +#include <zencore/scopeguard.h> #include <zencore/string.h> #include <zencore/testing.h> #include <zencore/thread.h> @@ -12,6 +14,10 @@ #include <thread> #include <vector> +ZEN_THIRD_PARTY_INCLUDES_START +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #define ZEN_USE_WINDOWS_THREADPOOL 1 #if ZEN_PLATFORM_WINDOWS && ZEN_USE_WINDOWS_THREADPOOL @@ -40,34 +46,54 @@ namespace { struct WorkerThreadPool::Impl { + const int m_ThreadCount = 0; PTP_POOL m_ThreadPool = nullptr; PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; TP_CALLBACK_ENVIRON m_CallbackEnvironment; PTP_WORK m_Work = nullptr; - std::string m_WorkerThreadBaseName; - std::atomic<int> m_WorkerThreadCounter{0}; + std::string m_WorkerThreadBaseName; + std::atomic<size_t> m_WorkerThreadCounter{0}; + std::atomic<int> m_FreeWorkerCount{0}; - RwLock m_QueueLock; + mutable RwLock m_QueueLock; std::deque<Ref<IWork>> m_WorkQueue; - Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_WorkerThreadBaseName(WorkerThreadBaseName) + Impl(int InThreadCount, std::string_view WorkerThreadBaseName) + : m_ThreadCount(InThreadCount) + , m_WorkerThreadBaseName(WorkerThreadBaseName) + , m_FreeWorkerCount(m_ThreadCount) { // Thread pool setup m_ThreadPool = CreateThreadpool(NULL); + if (m_ThreadPool == NULL) + { + ThrowLastError("CreateThreadpool failed"); + } - SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount); - SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2); + if (!SetThreadpoolThreadMinimum(m_ThreadPool, (DWORD)m_ThreadCount)) + { + ThrowLastError("SetThreadpoolThreadMinimum failed"); + } + SetThreadpoolThreadMaximum(m_ThreadPool, (DWORD)m_ThreadCount); InitializeThreadpoolEnvironment(&m_CallbackEnvironment); m_CleanupGroup = CreateThreadpoolCleanupGroup(); + if (m_CleanupGroup == NULL) + { + ThrowLastError("CreateThreadpoolCleanupGroup failed"); + } SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); m_Work = CreateThreadpoolWork(&WorkCallback, this, &m_CallbackEnvironment); + if (m_Work == NULL) + { + ThrowLastError("CreateThreadpoolWork failed"); + } } ~Impl() @@ -77,12 +103,29 @@ struct WorkerThreadPool::Impl CloseThreadpool(m_ThreadPool); } - void ScheduleWork(Ref<IWork> Work) + [[nodiscard]] Ref<IWork> ScheduleWork(Ref<IWork> Work, WorkerThreadPool::EMode Mode) { - m_QueueLock.WithExclusiveLock([&] { m_WorkQueue.push_back(std::move(Work)); }); + if (Mode == WorkerThreadPool::EMode::DisableBacklog) + { + if (m_FreeWorkerCount <= 0) + { + return Work; + } + RwLock::ExclusiveLockScope _(m_QueueLock); + const int QueuedCount = gsl::narrow<int>(m_WorkQueue.size()); + if (QueuedCount >= m_FreeWorkerCount) + { + return Work; + } + m_WorkQueue.push_back(std::move(Work)); + } + else + { + m_QueueLock.WithExclusiveLock([&] { m_WorkQueue.push_back(std::move(Work)); }); + } SubmitThreadpoolWork(m_Work); + return {}; } - [[nodiscard]] size_t PendingWorkItemCount() const { return 0; } static VOID CALLBACK WorkCallback(_Inout_ PTP_CALLBACK_INSTANCE Instance, _Inout_opt_ PVOID Context, _Inout_ PTP_WORK Work) { @@ -93,10 +136,13 @@ struct WorkerThreadPool::Impl void DoWork() { + m_FreeWorkerCount--; + auto _ = MakeGuard([&]() { m_FreeWorkerCount++; }); + if (!t_IsThreadNamed) { t_IsThreadNamed = true; - const int ThreadIndex = ++m_WorkerThreadCounter; + const size_t ThreadIndex = ++m_WorkerThreadCounter; zen::ExtendableStringBuilder<128> ThreadName; ThreadName << m_WorkerThreadBaseName << "_" << ThreadIndex; SetCurrentThreadName(ThreadName); @@ -105,7 +151,7 @@ struct WorkerThreadPool::Impl Ref<IWork> WorkFromQueue; { - RwLock::ExclusiveLockScope _{m_QueueLock}; + RwLock::ExclusiveLockScope __{m_QueueLock}; WorkFromQueue = std::move(m_WorkQueue.front()); m_WorkQueue.pop_front(); } @@ -125,20 +171,25 @@ struct WorkerThreadPool::ThreadStartInfo struct WorkerThreadPool::Impl { + const int m_ThreadCount = 0; void WorkerThreadFunction(ThreadStartInfo Info); std::string m_WorkerThreadBaseName; std::vector<std::thread> m_WorkerThreads; BlockingQueue<Ref<IWork>> m_WorkQueue; + std::atomic<int> m_FreeWorkerCount{0}; - Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_WorkerThreadBaseName(WorkerThreadBaseName) + Impl(int InThreadCount, std::string_view WorkerThreadBaseName) + : m_ThreadCount(InThreadCount) + , m_WorkerThreadBaseName(WorkerThreadBaseName) + , m_FreeWorkerCount(m_ThreadCount) { # if ZEN_WITH_TRACE trace::ThreadGroupBegin(m_WorkerThreadBaseName.c_str()); # endif - zen::Latch WorkerLatch{InThreadCount}; + zen::Latch WorkerLatch{m_ThreadCount}; - for (int i = 0; i < InThreadCount; ++i) + for (int i = 0; i < m_ThreadCount; ++i) { m_WorkerThreads.emplace_back(&Impl::WorkerThreadFunction, this, ThreadStartInfo{i + 1, &WorkerLatch}); } @@ -165,8 +216,23 @@ struct WorkerThreadPool::Impl m_WorkerThreads.clear(); } - void ScheduleWork(Ref<IWork> Work) { m_WorkQueue.Enqueue(std::move(Work)); } - [[nodiscard]] size_t PendingWorkItemCount() const { return m_WorkQueue.Size(); } + [[nodiscard]] Ref<IWork> ScheduleWork(Ref<IWork> Work, WorkerThreadPool::EMode Mode) + { + if (Mode == WorkerThreadPool::EMode::DisableBacklog) + { + if (m_FreeWorkerCount <= 0) + { + return Work; + } + const int QueuedCount = gsl::narrow<int>(m_WorkQueue.Size()); + if (QueuedCount >= m_FreeWorkerCount) + { + return Work; + } + } + m_WorkQueue.Enqueue(std::move(Work)); + return {}; + } }; void @@ -181,15 +247,23 @@ WorkerThreadPool::Impl::WorkerThreadFunction(ThreadStartInfo Info) Ref<IWork> Work; if (m_WorkQueue.WaitAndDequeue(Work)) { + m_FreeWorkerCount--; + auto _ = MakeGuard([&]() { m_FreeWorkerCount++; }); + try { ZEN_TRACE_CPU_FLUSH("AsyncWork"); Work->Execute(); + Work = {}; } - catch (std::exception& e) + catch (const AssertException& Ex) { - Work->m_Exception = std::current_exception(); - + Work = {}; + ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); + } + catch (const std::exception& e) + { + Work = {}; ZEN_WARN("Caught exception in worker thread: {}", e.what()); } } @@ -221,42 +295,38 @@ WorkerThreadPool::~WorkerThreadPool() } void -WorkerThreadPool::ScheduleWork(Ref<IWork> Work) +WorkerThreadPool::ScheduleWork(Ref<IWork> Work, EMode Mode) { if (m_Impl) { - m_Impl->ScheduleWork(std::move(Work)); - } - else - { - try + if (Work = m_Impl->ScheduleWork(std::move(Work), Mode); !Work) { - ZEN_TRACE_CPU_FLUSH("SyncWork"); - Work->Execute(); + return; } - catch (std::exception& e) - { - Work->m_Exception = std::current_exception(); + } - ZEN_WARN("Caught exception when executing worker synchronously: {}", e.what()); - } + try + { + ZEN_TRACE_CPU_FLUSH("SyncWork"); + Work->Execute(); + Work = {}; + } + catch (const AssertException& Ex) + { + Work = {}; + ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); + } + catch (const std::exception& e) + { + Work = {}; + ZEN_WARN("Caught exception when executing worker synchronously: {}", e.what()); } } void -WorkerThreadPool::ScheduleWork(std::function<void()>&& Work) +WorkerThreadPool::ScheduleWork(std::function<void()>&& Work, EMode Mode) { - ScheduleWork(Ref<IWork>(new detail::LambdaWork(Work))); -} - -[[nodiscard]] size_t -WorkerThreadPool::PendingWorkItemCount() const -{ - if (m_Impl) - { - return m_Impl->PendingWorkItemCount(); - } - return 0; + ScheduleWork(Ref<IWork>(new detail::LambdaWork(std::move(Work))), Mode); } ////////////////////////////////////////////////////////////////////////// @@ -274,9 +344,10 @@ TEST_CASE("threadpool.basic") { WorkerThreadPool Threadpool{1}; - auto Future42 = Threadpool.EnqueueTask(std::packaged_task<int()>{[] { return 42; }}); - auto Future99 = Threadpool.EnqueueTask(std::packaged_task<int()>{[] { return 99; }}); - auto FutureThrow = Threadpool.EnqueueTask(std::packaged_task<void()>{[] { throw std::runtime_error("meep!"); }}); + auto Future42 = Threadpool.EnqueueTask(std::packaged_task<int()>{[] { return 42; }}, WorkerThreadPool::EMode::EnableBacklog); + auto Future99 = Threadpool.EnqueueTask(std::packaged_task<int()>{[] { return 99; }}, WorkerThreadPool::EMode::EnableBacklog); + auto FutureThrow = Threadpool.EnqueueTask(std::packaged_task<void()>{[] { throw std::runtime_error("meep!"); }}, + WorkerThreadPool::EMode::EnableBacklog); CHECK_EQ(Future42.get(), 42); CHECK_EQ(Future99.get(), 99); diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua index 5420464fa..b3a33e052 100644 --- a/src/zencore/xmake.lua +++ b/src/zencore/xmake.lua @@ -3,6 +3,7 @@ target('zencore') set_kind("static") set_group("libs") + add_options("zentrace", "zenmimalloc", "zenrpmalloc") add_headerfiles("**.h") add_configfiles("include/zencore/config.h.in") on_load(function (target) @@ -12,10 +13,27 @@ target('zencore') end) set_configdir("include/zencore") add_files("**.cpp") + add_files("trace.cpp", {unity_ignored = true }) + + if has_config("zenrpmalloc") then + set_languages("c17", "cxx20") + if is_os("windows") then + add_cflags("/experimental:c11atomics") + end + add_defines("RPMALLOC_FIRST_CLASS_HEAPS=1", "ENABLE_STATISTICS=1", "ENABLE_OVERRIDE=0") + add_files("$(projectdir)/thirdparty/rpmalloc/rpmalloc.c") + end + + if has_config("zenmimalloc") then + add_packages("vcpkg::mimalloc") + end + add_includedirs("include", {public=true}) + add_includedirs("$(projectdir)/thirdparty/GetTimeSinceProcessStart") add_includedirs("$(projectdir)/thirdparty/utfcpp/source") add_includedirs("$(projectdir)/thirdparty/Oodle/include") add_includedirs("$(projectdir)/thirdparty/trace", {public=true}) + add_includedirs("$(projectdir)/thirdparty/rpmalloc") if is_os("windows") then add_linkdirs("$(projectdir)/thirdparty/Oodle/lib/Win64") add_links("oo2core_win64") @@ -27,17 +45,18 @@ target('zencore') add_linkdirs("$(projectdir)/thirdparty/Oodle/lib/Mac_x64") add_links("oo2coremac64") end - add_options("zentrace") add_deps("zenbase") add_packages( "vcpkg::blake3", "vcpkg::json11", - "vcpkg::mimalloc", + "vcpkg::ryml", + "vcpkg::c4core", "vcpkg::openssl", -- required for crypto "vcpkg::spdlog") add_packages( "vcpkg::doctest", + "vcpkg::eastl", "vcpkg::fmt", "vcpkg::gsl-lite", "vcpkg::lz4", @@ -45,12 +64,34 @@ target('zencore') {public=true} ) + if has_config("zensentry") then + add_packages("vcpkg::sentry-native") + + if is_plat("windows") then + add_links("dbghelp", "winhttp", "version") -- for Sentry + end + + if is_plat("linux") then + -- As sentry_native uses symbols from breakpad_client, the latter must + -- be specified after the former with GCC-like toolchains. xmake however + -- is unaware of this and simply globs files from vcpkg's output. The + -- line below forces breakpad_client to be to the right of sentry_native + add_syslinks("breakpad_client") + end + + if is_plat("macosx") then + add_syslinks("bsm") + end + + end + if is_plat("linux") then add_syslinks("rt") end if is_plat("windows") then add_syslinks("Advapi32") + add_syslinks("Dbghelp") add_syslinks("Shell32") add_syslinks("User32") add_syslinks("crypt32") diff --git a/src/zencore/xxhash.cpp b/src/zencore/xxhash.cpp index 450131d19..6d1050531 100644 --- a/src/zencore/xxhash.cpp +++ b/src/zencore/xxhash.cpp @@ -47,4 +47,55 @@ XXH3_128::ToHexString(StringBuilderBase& OutBuilder) const return OutBuilder; } +////////////////////////////////////////////////////////////////////////// +// +// Unit tests +// + +#if ZEN_WITH_TESTS + +void +xxhash_forcelink() +{ +} + +TEST_CASE("XXH3_128") +{ + using namespace std::literals; + + const std::string_view ShortString{"1234"}; + const std::string_view LongString{ + "1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" + "1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" + "1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890"}; + + SUBCASE("short_deprecated") + { + XXH3_128Stream_deprecated x; + x.Append(ShortString.data(), ShortString.size()); + const XXH3_128 Hash = x.GetHash(); + CHECK(Hash == XXH3_128::FromHexString("0d44dd7fde8ea2b4ba961e1a26f71f21"sv)); + } + + SUBCASE("short") + { + XXH3_128Stream x; + x.Append(ShortString.data(), ShortString.size()); + const XXH3_128 Hash = x.GetHash(); + CHECK(Hash == XXH3_128::FromHexString("9a4dea864648af82823c8c03e6dd2202"sv)); + CHECK(Hash == XXH3_128::HashMemory(ShortString.data(), ShortString.size())); + } + + SUBCASE("long") + { + XXH3_128Stream x; + x.Append(LongString.data(), LongString.size()); + const XXH3_128 Hash = x.GetHash(); + CHECK(Hash == XXH3_128::FromHexString("fbd5e72f7a5894590d1ef49dfcc58b7d"sv)); + CHECK(Hash == XXH3_128::HashMemory(LongString.data(), LongString.size())); + } +} + +#endif + } // namespace zen diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp index d0acac608..b78991918 100644 --- a/src/zencore/zencore.cpp +++ b/src/zencore/zencore.cpp @@ -6,12 +6,10 @@ # include <zencore/windows.h> #endif -#if ZEN_PLATFORM_LINUX -# include <pthread.h> -#endif - #include <zencore/assertfmt.h> +#include <zencore/basicfile.h> #include <zencore/blake3.h> +#include <zencore/callstack.h> #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> #include <zencore/compactbinarypackage.h> @@ -23,8 +21,9 @@ #include <zencore/iobuffer.h> #include <zencore/jobqueue.h> #include <zencore/logging.h> -#include <zencore/memory.h> +#include <zencore/memoryview.h> #include <zencore/mpscqueue.h> +#include <zencore/parallelwork.h> #include <zencore/process.h> #include <zencore/sha1.h> #include <zencore/stats.h> @@ -39,6 +38,10 @@ #include <atomic> +namespace zen { +extern void xxhash_forcelink(); +} + namespace zen::assert { void @@ -55,10 +58,120 @@ ExecAssertFmt(const char* Filename, int LineNumber, const char* FunctionName, st namespace zen { +AssertException::AssertException(const AssertException& Rhs) noexcept : _Mybase(Rhs), _Callstack(CloneCallstack(Rhs._Callstack)) +{ +} + +AssertException::AssertException(AssertException&& Rhs) noexcept : _Mybase(Rhs), _Callstack(Rhs._Callstack) +{ + Rhs._Callstack = nullptr; +} + +AssertException::~AssertException() noexcept +{ + FreeCallstack(_Callstack); +} + +AssertException& +AssertException::operator=(const AssertException& Rhs) noexcept +{ + _Mybase::operator=(Rhs); + + CallstackFrames* Callstack = CloneCallstack(Rhs._Callstack); + std::swap(_Callstack, Callstack); + FreeCallstack(Callstack); + return *this; +} + +std::string +AssertException::FullDescription() const noexcept +{ + if (_Callstack) + { + return fmt::format("'{}'\n{}", what(), CallstackToString(_Callstack, " ")); + } + return what(); +} + +AssertImpl::AssertImpl() : NextAssertImpl(nullptr) +{ + AssertImpl** WriteAssertPtr = &CurrentAssertImpl; + while (*WriteAssertPtr) + { + WriteAssertPtr = &(*WriteAssertPtr)->NextAssertImpl; + } + *WriteAssertPtr = this; +} + +AssertImpl::~AssertImpl() +{ + AssertImpl** WriteAssertPtr = &CurrentAssertImpl; + while ((*WriteAssertPtr) != this) + { + ZEN_ASSERT((*WriteAssertPtr) != nullptr); + WriteAssertPtr = &(*WriteAssertPtr)->NextAssertImpl; + } + *WriteAssertPtr = NextAssertImpl; +} + +void +AssertImpl::ExecAssert(const char* Filename, int LineNumber, const char* FunctionName, const char* Msg) +{ + constexpr int SkipFrameCount = 2; + constexpr int FrameCount = 8; + uint8_t CallstackBuffer[CallstackRawMemorySize(SkipFrameCount, FrameCount)]; + CallstackFrames* Callstack = GetCallstackRaw(&CallstackBuffer[0], SkipFrameCount, FrameCount); + + AssertImpl* AssertImpl = CurrentAssertImpl; + while (AssertImpl) + { + try + { + AssertImpl->OnAssert(Filename, LineNumber, FunctionName, Msg, Callstack); + } + catch (const std::exception&) + { + // Just keep exception silent - we don't want exception thrown from assert callbacks + } + AssertImpl = AssertImpl->NextAssertImpl; + } + ThrowAssertException(Filename, LineNumber, FunctionName, Msg, Callstack); +} +void +AssertImpl::OnAssert(const char* Filename, int LineNumber, const char* FunctionName, const char* Msg, const CallstackFrames* Callstack) +{ + ZEN_UNUSED(FunctionName); + + fmt::basic_memory_buffer<char, 2048> Message; + auto Appender = fmt::appender(Message); + fmt::format_to(Appender, "{}({}): ZEN_ASSERT({})\n{}", Filename, LineNumber, Msg, CallstackToString(Callstack, " ")); + Message.push_back('\0'); + + // We use direct ZEN_LOG here instead of ZEN_ERROR as we don't care about *this* code location in the log + ZEN_LOG(Log(), zen::logging::level::Err, "{}", Message.data()); + zen::logging::FlushLogging(); +} + +void +AssertImpl::ThrowAssertException(const char* Filename, + int LineNumber, + const char* FunctionName, + const char* Msg, + const CallstackFrames* Callstack) +{ + ZEN_UNUSED(FunctionName); + fmt::basic_memory_buffer<char, 2048> Message; + auto Appender = fmt::appender(Message); + fmt::format_to(Appender, "{}({}): {}", Filename, LineNumber, Msg); + Message.push_back('\0'); + + throw AssertException(Message.data(), CloneCallstack(Callstack)); +} + void refcount_forcelink(); +AssertImpl* AssertImpl::CurrentAssertImpl = nullptr; AssertImpl AssertImpl::DefaultAssertImpl; -AssertImpl* AssertImpl::CurrentAssertImpl = &AssertImpl::DefaultAssertImpl; ////////////////////////////////////////////////////////////////////////// @@ -115,11 +228,16 @@ IsApplicationExitRequested() return s_ApplicationExitRequested; } -void +bool RequestApplicationExit(int ExitCode) { - s_ApplicationExitCode = ExitCode; - s_ApplicationExitRequested = true; + bool Expected = false; + if (s_ApplicationExitRequested.compare_exchange_weak(Expected, true)) + { + s_ApplicationExitCode = ExitCode; + return true; + } + return false; } int @@ -132,7 +250,9 @@ ApplicationExitCode() void zencore_forcelinktests() { + zen::basicfile_forcelink(); zen::blake3_forcelink(); + zen::callstack_forcelink(); zen::compositebuffer_forcelink(); zen::compress_forcelink(); zen::crypto_forcelink(); @@ -143,6 +263,7 @@ zencore_forcelinktests() zen::logging_forcelink(); zen::memory_forcelink(); zen::mpscqueue_forcelink(); + zen::parallellwork_forcelink(); zen::process_forcelink(); zen::refcount_forcelink(); zen::sha1_forcelink(); @@ -155,7 +276,10 @@ zencore_forcelinktests() zen::uson_forcelink(); zen::usonbuilder_forcelink(); zen::usonpackage_forcelink(); + zen::cbjson_forcelink(); + zen::cbyaml_forcelink(); zen::workthreadpool_forcelink(); + xxhash_forcelink(); } } // namespace zen @@ -167,35 +291,34 @@ TEST_SUITE_BEGIN("core.assert"); TEST_CASE("Assert.Default") { - bool A = true; - bool B = false; - CHECK_THROWS_WITH(ZEN_ASSERT(A == B), "A == B"); + bool A = true; + bool B = false; + std::string Expected = fmt::format("{}({}): {}", __FILE__, __LINE__ + 1, "A == B"); + CHECK_THROWS_WITH(ZEN_ASSERT(A == B), Expected.c_str()); } TEST_CASE("Assert.Format") { - bool A = true; - bool B = false; - CHECK_THROWS_WITH(ZEN_ASSERT_FORMAT(A == B, "{} == {}", A, B), "assert(A == B) failed: true == false"); + bool A = true; + bool B = false; + std::string Expected = fmt::format("{}({}): {}", __FILE__, __LINE__ + 1, "assert(A == B) failed: true == false"); + CHECK_THROWS_WITH(ZEN_ASSERT_FORMAT(A == B, "{} == {}", A, B), Expected.c_str()); } TEST_CASE("Assert.Custom") { struct MyAssertImpl : AssertImpl { - ZEN_FORCENOINLINE ZEN_DEBUG_SECTION MyAssertImpl() : PrevAssertImpl(CurrentAssertImpl) { CurrentAssertImpl = this; } - virtual ZEN_FORCENOINLINE ZEN_DEBUG_SECTION ~MyAssertImpl() { CurrentAssertImpl = PrevAssertImpl; } - virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION OnAssert(const char* Filename, - int LineNumber, - const char* FunctionName, - const char* Msg) + virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION + OnAssert(const char* Filename, int LineNumber, const char* FunctionName, const char* Msg, const CallstackFrames* Callstack) { + ZEN_UNUSED(Callstack); AssertFileName = Filename; Line = LineNumber; FuncName = FunctionName; Message = Msg; } - AssertImpl* PrevAssertImpl; + AssertImpl* PrevAssertImpl = nullptr; const char* AssertFileName = nullptr; int Line = -1; @@ -206,13 +329,48 @@ TEST_CASE("Assert.Custom") MyAssertImpl MyAssert; bool A = true; bool B = false; - CHECK_THROWS_WITH(ZEN_ASSERT(A == B), "A == B"); + CHECK_THROWS_WITH(ZEN_ASSERT(A == B), std::string(fmt::format("{}({}): {}", __FILE__, __LINE__, "A == B")).c_str()); CHECK(MyAssert.AssertFileName != nullptr); CHECK(MyAssert.Line != -1); CHECK(MyAssert.FuncName != nullptr); CHECK(strcmp(MyAssert.Message, "A == B") == 0); } +TEST_CASE("Assert.Callstack") +{ + try + { + ZEN_ASSERT(false); + } + catch (const AssertException& Assert) + { + ZEN_INFO("Assert failed: {}", Assert.what()); + CHECK(Assert._Callstack->FrameCount > 0); + CHECK(Assert._Callstack->Frames != nullptr); + ZEN_INFO("Callstack:\n{}", CallstackToString(Assert._Callstack)); + } + + WorkerThreadPool Pool(1); + auto Task = Pool.EnqueueTask(std::packaged_task<int()>{[] { + ZEN_ASSERT(false); + return 1; + }}, + WorkerThreadPool::EMode::EnableBacklog); + + try + { + (void)Task.get(); + CHECK(false); + } + catch (const AssertException& Assert) + { + ZEN_INFO("Assert in future: {}", Assert.what()); + CHECK(Assert._Callstack->FrameCount > 0); + CHECK(Assert._Callstack->Frames != nullptr); + ZEN_INFO("Callstack:\n{}", CallstackToString(Assert._Callstack)); + } +} + TEST_SUITE_END(); #endif |