// Copyright Epic Games, Inc. All Rights Reserved. #include "zenutil/zenserverprocess.h" #include #include #include #include #include #include #include #include #include #include ////////////////////////////////////////////////////////////////////////// namespace zen { namespace zenutil { class SecurityAttributes { public: inline SECURITY_ATTRIBUTES* Attributes() { return &m_Attributes; } protected: SECURITY_ATTRIBUTES m_Attributes{}; SECURITY_DESCRIPTOR m_Sd{}; }; // Security attributes which allows any user access class AnyUserSecurityAttributes : public SecurityAttributes { public: AnyUserSecurityAttributes() { m_Attributes.nLength = sizeof m_Attributes; m_Attributes.bInheritHandle = false; // Disable inheritance const BOOL Success = InitializeSecurityDescriptor(&m_Sd, SECURITY_DESCRIPTOR_REVISION); if (Success) { if (!SetSecurityDescriptorDacl(&m_Sd, TRUE, (PACL)NULL, FALSE)) { zen::ThrowLastError("SetSecurityDescriptorDacl failed", std::source_location::current()); } m_Attributes.lpSecurityDescriptor = &m_Sd; } } }; } // namespace zenutil ////////////////////////////////////////////////////////////////////////// ZenServerState::ZenServerState() { } ZenServerState::~ZenServerState() { if (m_OurEntry) { // Clean up our entry now that we're leaving m_OurEntry->Reset(); m_OurEntry = nullptr; } if (m_Data) { UnmapViewOfFile(m_Data); m_Data = nullptr; } if (m_hMapFile) { CloseHandle(m_hMapFile); } } void ZenServerState::Initialize() { // TODO: there's a small chance of a race here, this logic could be tightened up with a mutex to // ensure only a single process at a time creates the mapping if (HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap")) { m_hMapFile = hMap; } else { // Security attributes to enable any user to access state zenutil::AnyUserSecurityAttributes Attrs; hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file Attrs.Attributes(), // allow anyone to access PAGE_READWRITE, // read/write access 0, // maximum object size (high-order DWORD) m_MaxEntryCount * sizeof(ZenServerEntry), // maximum object size (low-order DWORD) L"Global\\ZenMap"); // name of mapping object if (hMap == NULL) { zen::ThrowLastError("Could not open or create file mapping object for Zen server state"); } m_hMapFile = hMap; } void* pBuf = MapViewOfFile(m_hMapFile, // handle to map object FILE_MAP_ALL_ACCESS, // read/write permission 0, // offset high 0, // offset low m_MaxEntryCount * sizeof(ZenServerEntry)); if (pBuf == NULL) { zen::ThrowLastError("Could not map view of Zen server state"); } m_Data = reinterpret_cast(pBuf); m_IsReadOnly = false; } bool ZenServerState::InitializeReadOnly() { if (HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap")) { m_hMapFile = hMap; } else { return false; } void* pBuf = MapViewOfFile(m_hMapFile, // handle to map object FILE_MAP_READ, // read permission 0, // offset high 0, // offset low m_MaxEntryCount * sizeof(ZenServerEntry)); if (pBuf == NULL) { zen::ThrowLastError("Could not map view of Zen server state"); } m_Data = reinterpret_cast(pBuf); return true; } ZenServerState::ZenServerEntry* ZenServerState::Lookup(int ListenPort) { for (int i = 0; i < m_MaxEntryCount; ++i) { if (m_Data[i].ListenPort == ListenPort) { return &m_Data[i]; } } return nullptr; } ZenServerState::ZenServerEntry* ZenServerState::Register(int ListenPort) { if (m_Data == nullptr) { return nullptr; } // Allocate an entry int Pid = zen::GetCurrentProcessId(); for (int i = 0; i < m_MaxEntryCount; ++i) { ZenServerEntry& Entry = m_Data[i]; if (Entry.ListenPort.load(std::memory_order::memory_order_relaxed) == 0) { uint16_t Expected = 0; if (Entry.ListenPort.compare_exchange_strong(Expected, uint16_t(ListenPort))) { // Successfully allocated entry m_OurEntry = &Entry; Entry.Pid = Pid; Entry.Flags = 0; const zen::Oid SesId = zen::GetSessionId(); memcpy(Entry.SessionId, &SesId, sizeof SesId); return &Entry; } } } return nullptr; } void ZenServerState::Sweep() { if (m_Data == nullptr) { return; } ZEN_ASSERT(m_IsReadOnly == false); for (int i = 0; i < m_MaxEntryCount; ++i) { ZenServerEntry& Entry = m_Data[i]; if (Entry.ListenPort) { if (zen::IsProcessRunning(Entry.Pid) == false) { ZEN_DEBUG("Sweep - pid {} not running, reclaiming entry (port {})", Entry.Pid, Entry.ListenPort); Entry.Reset(); } } } } void ZenServerState::Snapshot(std::function&& Callback) { if (m_Data == nullptr) { return; } for (int i = 0; i < m_MaxEntryCount; ++i) { ZenServerEntry& Entry = m_Data[i]; if (Entry.ListenPort) { Callback(Entry); } } } void ZenServerState::ZenServerEntry::Reset() { Pid = 0; ListenPort = 0; Flags = 0; } void ZenServerState::ZenServerEntry::SignalShutdownRequest() { Flags |= uint16_t(FlagsEnum::kShutdownPlease); } bool ZenServerState::ZenServerEntry::AddSponsorProcess(uint32_t PidToAdd) { for (std::atomic& PidEntry : SponsorPids) { if (PidEntry.load(std::memory_order::memory_order_relaxed) == 0) { uint32_t Expected = 0; if (PidEntry.compare_exchange_strong(Expected, uint16_t(PidToAdd))) { // Success! return true; } } else if (PidEntry.load(std::memory_order::memory_order_relaxed) == PidToAdd) { // Success, the because pid is already in the list return true; } } return false; } ////////////////////////////////////////////////////////////////////////// std::atomic TestCounter{0}; ZenServerEnvironment::ZenServerEnvironment() { } ZenServerEnvironment::~ZenServerEnvironment() { } void ZenServerEnvironment::Initialize(std::filesystem::path ProgramBaseDir) { m_ProgramBaseDir = ProgramBaseDir; ZEN_DEBUG("Program base dir is '{}'", ProgramBaseDir); m_IsInitialized = true; } void ZenServerEnvironment::InitializeForTest(std::filesystem::path ProgramBaseDir, std::filesystem::path TestBaseDir) { m_ProgramBaseDir = ProgramBaseDir; m_TestBaseDir = TestBaseDir; ZEN_INFO("Program base dir is '{}'", ProgramBaseDir); ZEN_INFO("Cleaning test base dir '{}'", TestBaseDir); zen::DeleteDirectories(TestBaseDir.c_str()); m_IsTestInstance = true; m_IsInitialized = true; } std::filesystem::path ZenServerEnvironment::CreateNewTestDir() { using namespace std::literals; zen::ExtendableWideStringBuilder<256> TestDir; TestDir << "test"sv << int64_t(++TestCounter); std::filesystem::path TestPath = m_TestBaseDir / TestDir.c_str(); ZEN_INFO("Creating new test dir @ '{}'", TestPath); zen::CreateDirectories(TestPath.c_str()); return TestPath; } std::filesystem::path ZenServerEnvironment::GetTestRootDir(std::string_view Path) { std::filesystem::path Root = m_ProgramBaseDir.parent_path().parent_path(); std::filesystem::path Relative{Path}; return Root / Relative; } ////////////////////////////////////////////////////////////////////////// std::atomic ChildIdCounter{0}; ZenServerInstance::ZenServerInstance(ZenServerEnvironment& TestEnvironment) : m_Env(TestEnvironment) { ZEN_ASSERT(TestEnvironment.IsInitialized()); } ZenServerInstance::~ZenServerInstance() { Shutdown(); } void ZenServerInstance::SignalShutdown() { m_ShutdownEvent.Set(); } void ZenServerInstance::Shutdown() { if (m_Process.IsValid()) { if (m_Terminate) { ZEN_INFO("Terminating zenserver process"); m_Process.Terminate(111); } else { SignalShutdown(); m_Process.Wait(); m_Process.Reset(); } } } void ZenServerInstance::SpawnServer(int BasePort, std::string_view AdditionalServerArgs) { ZEN_ASSERT(!m_Process.IsValid()); // Only spawn once const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); const std::filesystem::path Executable = BaseDir / "zenserver.exe"; const int MyPid = _getpid(); const int ChildId = ++ChildIdCounter; zen::ExtendableStringBuilder<32> ChildEventName; ChildEventName << "Zen_Child_" << ChildId; zen::NamedEvent ChildEvent{ChildEventName}; CreateShutdownEvent(BasePort); zen::ExtendableStringBuilder<32> LogId; LogId << "Zen" << ChildId; zen::ExtendableWideStringBuilder<512> CommandLine; CommandLine << "\""; CommandLine.Append(Executable.c_str()); CommandLine << "\""; const bool IsTest = m_Env.IsTestEnvironment(); if (IsTest) { if (!m_OwnerPid.has_value()) { m_OwnerPid = MyPid; } CommandLine << " --test --log-id " << LogId; } if (m_OwnerPid.has_value()) { CommandLine << " --owner-pid " << m_OwnerPid.value(); } CommandLine << " --child-id " << ChildEventName; if (BasePort) { CommandLine << " --port " << BasePort; m_BasePort = BasePort; } if (!m_TestDir.empty()) { CommandLine << " --data-dir "; CommandLine << m_TestDir.c_str(); } if (m_MeshEnabled) { CommandLine << " --mesh"; } if (!AdditionalServerArgs.empty()) { CommandLine << " " << AdditionalServerArgs; } std::filesystem::path CurrentDirectory = std::filesystem::current_path(); ZEN_DEBUG("Spawning server '{}'", LogId); PROCESS_INFORMATION ProcessInfo{}; STARTUPINFO StartupInfo{.cb = sizeof(STARTUPINFO)}; DWORD CreationFlags = 0; if (!IsTest) { CreationFlags |= CREATE_NEW_CONSOLE; } HANDLE hProcess = NULL; { const bool InheritHandles = false; void* Environment = nullptr; LPSECURITY_ATTRIBUTES ProcessAttributes = nullptr; LPSECURITY_ATTRIBUTES ThreadAttributes = nullptr; BOOL Success = CreateProcessW(Executable.c_str(), (LPWSTR)CommandLine.c_str(), ProcessAttributes, ThreadAttributes, InheritHandles, CreationFlags, Environment, CurrentDirectory.c_str(), &StartupInfo, &ProcessInfo); if (Success) { hProcess = ProcessInfo.hProcess; CloseHandle(ProcessInfo.hThread); } else { DWORD WinError = ::GetLastError(); if (WinError == ERROR_ELEVATION_REQUIRED) { // Try launching elevated process ZEN_DEBUG("Regular spawn failed - spawning elevated server"); SHELLEXECUTEINFO ShellExecuteInfo; ZeroMemory(&ShellExecuteInfo, sizeof(ShellExecuteInfo)); ShellExecuteInfo.cbSize = sizeof(ShellExecuteInfo); ShellExecuteInfo.fMask = SEE_MASK_UNICODE | SEE_MASK_NOCLOSEPROCESS; ShellExecuteInfo.lpFile = Executable.c_str(); ShellExecuteInfo.lpVerb = TEXT("runas"); ShellExecuteInfo.nShow = SW_SHOW; ShellExecuteInfo.lpParameters = CommandLine.c_str(); if (::ShellExecuteEx(&ShellExecuteInfo)) { WinError = NO_ERROR; hProcess = ShellExecuteInfo.hProcess; } } if (WinError != NO_ERROR) { std::error_code err(WinError, std::system_category()); ZEN_ERROR("Server spawn failed: {}", err.message()); throw std::system_error(err, "failed to create server process"); } } } ZEN_DEBUG("Server '{}' spawned OK", LogId); if (IsTest) { m_Process.Initialize(hProcess); } else { CloseHandle(hProcess); } m_ReadyEvent = std::move(ChildEvent); } void ZenServerInstance::CreateShutdownEvent(int BasePort) { zen::ExtendableStringBuilder<32> ChildShutdownEventName; ChildShutdownEventName << "Zen_" << BasePort; ChildShutdownEventName << "_Shutdown"; zen::NamedEvent ChildShutdownEvent{ChildShutdownEventName}; m_ShutdownEvent = std::move(ChildShutdownEvent); } void ZenServerInstance::AttachToRunningServer(int BasePort) { ZenServerState State; if (!State.InitializeReadOnly()) { // TODO: return success/error code instead? throw std::runtime_error("No zen state found"); } const ZenServerState::ZenServerEntry* Entry = nullptr; if (BasePort) { Entry = State.Lookup(BasePort); } else { State.Snapshot([&](const ZenServerState::ZenServerEntry& InEntry) { Entry = &InEntry; }); } if (!Entry) { // TODO: return success/error code instead? throw std::runtime_error("No server found"); } m_Process.Initialize(Entry->Pid); CreateShutdownEvent(BasePort); } void ZenServerInstance::Detach() { if (m_Process.IsValid()) { m_Process.Reset(); m_ShutdownEvent.Close(); } } void ZenServerInstance::WaitUntilReady() { while (m_ReadyEvent.Wait(100) == false) { if (!m_Process.IsRunning() || !m_Process.IsValid()) { return; } } } bool ZenServerInstance::WaitUntilReady(int Timeout) { return m_ReadyEvent.Wait(Timeout); } std::string ZenServerInstance::GetBaseUri() const { ZEN_ASSERT(m_BasePort); using namespace fmt::literals; return "http://localhost:{}"_format(m_BasePort); } } // namespace zen