diff options
Diffstat (limited to 'src/zenutil/zenserverprocess.cpp')
| -rw-r--r-- | src/zenutil/zenserverprocess.cpp | 311 |
1 files changed, 305 insertions, 6 deletions
diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index d271edb93..36a8b9c6f 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -449,6 +449,30 @@ ZenServerState::ZenServerEntry::IsReady() const return (Flags.load() & static_cast<uint16_t>(FlagsEnum::kIsReady)) != 0; } +void +ZenServerState::ZenServerEntry::SignalHasInstanceInfo() +{ + Flags |= uint16_t(FlagsEnum::kHasInstanceInfo); +} + +bool +ZenServerState::ZenServerEntry::HasInstanceInfo() const +{ + return (Flags.load() & static_cast<uint16_t>(FlagsEnum::kHasInstanceInfo)) != 0; +} + +void +ZenServerState::ZenServerEntry::SignalNoNetwork() +{ + Flags |= uint16_t(FlagsEnum::kNoNetwork); +} + +bool +ZenServerState::ZenServerEntry::IsNoNetwork() const +{ + return (Flags.load() & static_cast<uint16_t>(FlagsEnum::kNoNetwork)) != 0; +} + bool ZenServerState::ZenServerEntry::AddSponsorProcess(uint32_t PidToAdd, uint64_t Timeout) { @@ -492,6 +516,222 @@ ZenServerState::ZenServerEntry::AddSponsorProcess(uint32_t PidToAdd, uint64_t Ti } ////////////////////////////////////////////////////////////////////////// +// ZenServerInstanceInfo +////////////////////////////////////////////////////////////////////////// + +static constexpr size_t kInstanceInfoSize = 4096; + +ZenServerInstanceInfo::ZenServerInstanceInfo() = default; + +ZenServerInstanceInfo::~ZenServerInstanceInfo() +{ +#if ZEN_PLATFORM_WINDOWS + if (m_Data) + { + UnmapViewOfFile(m_Data); + } + if (m_hMapFile) + { + CloseHandle(m_hMapFile); + } +#else + if (m_Data != nullptr) + { + munmap(m_Data, kInstanceInfoSize); + } + if (m_hMapFile != nullptr) + { + int Fd = int(intptr_t(m_hMapFile)); + close(Fd); + } + if (m_IsOwner) + { + std::string Name = MakeName(m_SessionId); + shm_unlink(Name.c_str()); + } +#endif + m_Data = nullptr; +} + +std::string +ZenServerInstanceInfo::MakeName(const Oid& SessionId) +{ +#if ZEN_PLATFORM_WINDOWS + return fmt::format("Global\\ZenInstance_{}", SessionId); +#else + // macOS limits shm_open names to ~31 chars (PSHMNAMLEN), so keep this short. + // "/ZenI_" (6) + 24 hex = 30 chars, within the limit. + return fmt::format("/ZenI_{}", SessionId); +#endif +} + +void +ZenServerInstanceInfo::Create(const Oid& SessionId, const InstanceInfoData& Data) +{ + m_SessionId = SessionId; + m_IsOwner = true; + + // Serialize the data to compact binary + CbObjectWriter Cbo; + if (!Data.UnixSocketPath.empty()) + { + Cbo << "unix_socket" << PathToUtf8(Data.UnixSocketPath); + } + CbObject Payload = Cbo.Save(); + + MemoryView PayloadView = Payload.GetView(); + uint32_t PayloadSize = gsl::narrow<uint32_t>(PayloadView.GetSize()); + + std::string Name = MakeName(SessionId); + +#if ZEN_PLATFORM_WINDOWS + zenutil::AnyUserSecurityAttributes Attrs; + + std::wstring WideName(Name.begin(), Name.end()); + + HANDLE hMap = + CreateFileMappingW(INVALID_HANDLE_VALUE, Attrs.Attributes(), PAGE_READWRITE, 0, DWORD(kInstanceInfoSize), WideName.c_str()); + + if (hMap == NULL) + { + // Fall back to Local namespace + std::string LocalName = fmt::format("Local\\ZenInstance_{}", SessionId); + std::wstring WideLocalName(LocalName.begin(), LocalName.end()); + hMap = CreateFileMappingW(INVALID_HANDLE_VALUE, + Attrs.Attributes(), + PAGE_READWRITE, + 0, + DWORD(kInstanceInfoSize), + WideLocalName.c_str()); + } + + if (hMap == NULL) + { + ThrowLastError("Could not create instance info shared memory"); + } + + void* pBuf = MapViewOfFile(hMap, FILE_MAP_ALL_ACCESS, 0, 0, DWORD(kInstanceInfoSize)); + if (pBuf == NULL) + { + CloseHandle(hMap); + ThrowLastError("Could not map instance info shared memory"); + } +#else + int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC, 0666); + if (Fd < 0) + { + ThrowLastError("Could not create instance info shared memory"); + } + fchmod(Fd, 0666); + + if (ftruncate(Fd, kInstanceInfoSize) < 0) + { + close(Fd); + shm_unlink(Name.c_str()); + ThrowLastError("Could not resize instance info shared memory"); + } + + void* pBuf = mmap(nullptr, kInstanceInfoSize, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0); + if (pBuf == MAP_FAILED) + { + close(Fd); + shm_unlink(Name.c_str()); + ThrowLastError("Could not map instance info shared memory"); + } + + void* hMap = reinterpret_cast<void*>(intptr_t(Fd)); +#endif + + m_hMapFile = hMap; + m_Data = reinterpret_cast<uint8_t*>(pBuf); + + // Write payload: [uint32_t size][compact binary bytes] + memcpy(m_Data, &PayloadSize, sizeof PayloadSize); + if (PayloadSize > 0) + { + memcpy(m_Data + sizeof(uint32_t), PayloadView.GetData(), PayloadSize); + } +} + +bool +ZenServerInstanceInfo::OpenReadOnly(const Oid& SessionId) +{ + m_SessionId = SessionId; + + std::string Name = MakeName(SessionId); + +#if ZEN_PLATFORM_WINDOWS + std::wstring WideName(Name.begin(), Name.end()); + + HANDLE hMap = OpenFileMappingW(FILE_MAP_READ, FALSE, WideName.c_str()); + if (hMap == NULL) + { + // Fall back to Local namespace + std::string LocalName = fmt::format("Local\\ZenInstance_{}", SessionId); + std::wstring WideLocalName(LocalName.begin(), LocalName.end()); + hMap = OpenFileMappingW(FILE_MAP_READ, FALSE, WideLocalName.c_str()); + } + + if (hMap == NULL) + { + return false; + } + + void* pBuf = MapViewOfFile(hMap, FILE_MAP_READ, 0, 0, DWORD(kInstanceInfoSize)); + if (pBuf == NULL) + { + CloseHandle(hMap); + return false; + } +#else + int Fd = shm_open(Name.c_str(), O_RDONLY | O_CLOEXEC, 0666); + if (Fd < 0) + { + return false; + } + + void* pBuf = mmap(nullptr, kInstanceInfoSize, PROT_READ, MAP_SHARED, Fd, 0); + if (pBuf == MAP_FAILED) + { + close(Fd); + return false; + } + + void* hMap = reinterpret_cast<void*>(intptr_t(Fd)); +#endif + + m_hMapFile = hMap; + m_Data = reinterpret_cast<uint8_t*>(pBuf); + m_IsOwner = false; + + return true; +} + +InstanceInfoData +ZenServerInstanceInfo::Read() const +{ + InstanceInfoData Result; + + if (m_Data == nullptr) + { + return Result; + } + + uint32_t PayloadSize = 0; + memcpy(&PayloadSize, m_Data, sizeof PayloadSize); + + if (PayloadSize == 0 || PayloadSize > kInstanceInfoSize - sizeof(uint32_t)) + { + return Result; + } + + CbObject Payload = CbObject::Clone(m_Data + sizeof(uint32_t)); + Result.UnixSocketPath = Payload["unix_socket"].AsU8String(); + + return Result; +} + +////////////////////////////////////////////////////////////////////////// std::atomic<int> ZenServerTestCounter{0}; @@ -503,6 +743,39 @@ ZenServerEnvironment::~ZenServerEnvironment() { } +ZenServerEnvironment::ZenServerEnvironment(ZenServerEnvironment&& Other) +: m_ProgramBaseDir(std::move(Other.m_ProgramBaseDir)) +, m_ChildProcessBaseDir(std::move(Other.m_ChildProcessBaseDir)) +, m_IsInitialized(Other.m_IsInitialized) +, m_IsTestInstance(Other.m_IsTestInstance) +, m_IsHubInstance(Other.m_IsHubInstance) +, m_PassthroughOutput(Other.m_PassthroughOutput) +, m_ServerClass(std::move(Other.m_ServerClass)) +, m_NextPortNumber(Other.m_NextPortNumber.load()) +{ +} + +ZenServerEnvironment::ZenServerEnvironment(EStorageTag, std::filesystem::path ProgramBaseDir) +{ + Initialize(ProgramBaseDir); +} + +ZenServerEnvironment::ZenServerEnvironment(EHubTag, + std::filesystem::path ProgramBaseDir, + std::filesystem::path TestBaseDir, + std::string_view ServerClass) +{ + InitializeForHub(ProgramBaseDir, TestBaseDir, ServerClass); +} + +ZenServerEnvironment::ZenServerEnvironment(ETestTag, + std::filesystem::path ProgramBaseDir, + std::filesystem::path TestBaseDir, + std::string_view ServerClass) +{ + InitializeForTest(ProgramBaseDir, TestBaseDir, ServerClass); +} + void ZenServerEnvironment::Initialize(std::filesystem::path ProgramBaseDir) { @@ -792,6 +1065,8 @@ ToString(ZenServerInstance::ServerMode Mode) return "storage"sv; case ZenServerInstance::ServerMode::kHubServer: return "hub"sv; + case ZenServerInstance::ServerMode::kComputeServer: + return "compute"sv; default: return "invalid"sv; } @@ -815,6 +1090,10 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, { CommandLine << " hub"; } + else if (m_ServerMode == ServerMode::kComputeServer) + { + CommandLine << " compute"; + } CommandLine << " --child-id " << ChildEventName; @@ -836,10 +1115,18 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); const std::filesystem::path Executable = m_ServerExecutablePath.empty() ? (BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL) : m_ServerExecutablePath; - const std::filesystem::path OutputPath = - OpenConsole ? std::filesystem::path{} : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log"); - CreateProcOptions CreateOptions = {.WorkingDirectory = &CurrentDirectory, .Flags = CreationFlags, .StdoutFile = OutputPath}; - CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); + const std::filesystem::path OutputPath = (OpenConsole || m_Env.IsPassthroughOutput()) + ? std::filesystem::path{} + : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log"); + CreateProcOptions CreateOptions = { + .WorkingDirectory = &CurrentDirectory, + .Flags = CreationFlags, + .StdoutFile = OutputPath, +#if ZEN_PLATFORM_WINDOWS + .AssignToJob = m_JobObject, +#endif + }; + CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); #if ZEN_PLATFORM_WINDOWS if (!ChildPid) { @@ -848,6 +1135,12 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, { ZEN_DEBUG("Regular spawn failed - spawning elevated server"); CreateOptions.Flags |= CreateProcOptions::Flag_Elevated; + // ShellExecuteEx (used by the elevated path) does not support job object assignment + if (CreateOptions.AssignToJob) + { + ZEN_WARN("Elevated process spawn does not support job object assignment; child will not be auto-terminated on parent exit"); + CreateOptions.AssignToJob = nullptr; + } ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); } else @@ -941,7 +1234,8 @@ ZenServerInstance::SpawnServer(int BasePort, std::string_view AdditionalServerAr CommandLine << " " << AdditionalServerArgs; } - SpawnServerInternal(ChildId, CommandLine, !IsTest, WaitTimeoutMs); + const bool OpenConsole = !IsTest && !m_Env.IsHubEnvironment(); + SpawnServerInternal(ChildId, CommandLine, OpenConsole, WaitTimeoutMs); } void @@ -1187,6 +1481,10 @@ MakeLockFilePayload(const LockFileInfo& Info) CbObjectWriter Cbo; Cbo << "pid" << Info.Pid << "data" << PathToUtf8(Info.DataDir) << "port" << Info.EffectiveListenPort << "session_id" << Info.SessionId << "ready" << Info.Ready << "executable" << PathToUtf8(Info.ExecutablePath); + if (!Info.UnixSocketPath.empty()) + { + Cbo << "unix_socket" << PathToUtf8(Info.UnixSocketPath); + } return Cbo.Save(); } LockFileInfo @@ -1199,6 +1497,7 @@ ReadLockFilePayload(const CbObject& Payload) Info.Ready = Payload["ready"].AsBool(); Info.DataDir = Payload["data"].AsU8String(); Info.ExecutablePath = Payload["executable"].AsU8String(); + Info.UnixSocketPath = Payload["unix_socket"].AsU8String(); return Info; } @@ -1228,7 +1527,7 @@ ValidateLockFileInfo(const LockFileInfo& Info, std::string& OutReason) OutReason = fmt::format("session id ({}) is not valid", Info.SessionId); return false; } - if (Info.EffectiveListenPort == 0) + if (Info.EffectiveListenPort == 0 && Info.UnixSocketPath.empty()) { OutReason = fmt::format("listen port ({}) is not valid", Info.EffectiveListenPort); return false; |