diff options
| author | Dan Engelbrecht <[email protected]> | 2025-01-08 13:49:56 +0100 |
|---|---|---|
| committer | Dan Engelbrecht <[email protected]> | 2025-01-08 13:49:56 +0100 |
| commit | 995aec217bbb26c9c2a701cc77edb067ffbf8d36 (patch) | |
| tree | 2da2d3fd806547bd9f38bc190514abbf9fdb6361 /src/zenutil/service.cpp | |
| parent | check if service is already installed before attempting install (diff) | |
| download | zen-995aec217bbb26c9c2a701cc77edb067ffbf8d36.tar.xz zen-995aec217bbb26c9c2a701cc77edb067ffbf8d36.zip | |
add ServiceLevel for service processes: User, AllUsers and Service
Diffstat (limited to 'src/zenutil/service.cpp')
| -rw-r--r-- | src/zenutil/service.cpp | 391 |
1 files changed, 356 insertions, 35 deletions
diff --git a/src/zenutil/service.cpp b/src/zenutil/service.cpp index fd96af0c8..44aa50494 100644 --- a/src/zenutil/service.cpp +++ b/src/zenutil/service.cpp @@ -3,8 +3,14 @@ #include <zenutil/service.h> #include <zencore/except.h> +#include <zencore/process.h> #include <zencore/scopeguard.h> #include <zencore/zencore.h> +#include <zenutil/zenserverprocess.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> @@ -19,6 +25,46 @@ namespace zen { using namespace std::literals; +namespace { + bool SplitExecutableAndArgs(const std::wstring& ExeAndArgs, std::filesystem::path& OutExecutablePath, std::string& OutArguments) + { + if (ExeAndArgs.size()) + { + if (ExeAndArgs[0] == '"') + { + std::wstring::size_type ExecutableEnd = ExeAndArgs.find('"', 1); + if (ExecutableEnd == std::wstring::npos) + { + OutExecutablePath = ExeAndArgs; + return true; + } + else + { + OutExecutablePath = ExeAndArgs.substr(0, ExecutableEnd + 1); + OutArguments = WideToUtf8(ExeAndArgs.substr(ExecutableEnd + 1 + ExeAndArgs[ExecutableEnd + 1] == ' ' ? 1 : 0)); + return true; + } + } + else + { + std::wstring::size_type ExecutableEnd = ExeAndArgs.find(' ', 1); + if (ExecutableEnd == std::wstring::npos) + { + OutExecutablePath = ExeAndArgs; + return true; + } + else + { + OutExecutablePath = ExeAndArgs.substr(0, ExecutableEnd); + OutArguments = WideToUtf8(ExeAndArgs.substr(ExecutableEnd + 1)); + return true; + } + } + } + return false; + } +} // namespace + #if ZEN_PLATFORM_MAC namespace { @@ -202,6 +248,44 @@ namespace { #endif // ZEN_PLATFORM_MAC std::string_view +ToString(ServiceLevel Level) +{ + switch (Level) + { + case ServiceLevel::CurrentUser: + return "Current User"sv; + case ServiceLevel::AllUsers: + return "All Users"sv; + case ServiceLevel::SystemService: + return "System Service"sv; + default: + ZEN_ASSERT(false); + return ""sv; + } +} + +std::optional<ServiceLevel> +FromString(const std::string& Level) +{ + if (StrCaseCompare(Level.c_str(), std::string(ToString(ServiceLevel::CurrentUser)).c_str()) == 0 || + StrCaseCompare(Level.c_str(), "CurrentUser") == 0) + { + return ServiceLevel::CurrentUser; + } + if (StrCaseCompare(Level.c_str(), std::string(ToString(ServiceLevel::AllUsers)).c_str()) == 0 || + StrCaseCompare(Level.c_str(), "AllUsers") == 0) + { + return ServiceLevel::AllUsers; + } + if (StrCaseCompare(Level.c_str(), std::string(ToString(ServiceLevel::SystemService)).c_str()) == 0 || + StrCaseCompare(Level.c_str(), "SystemService") == 0) + { + return ServiceLevel::SystemService; + } + return {}; +} + +std::string_view ToString(ServiceStatus Status) { switch (Status) @@ -231,7 +315,36 @@ ToString(ServiceStatus Status) #if ZEN_PLATFORM_WINDOWS std::error_code -InstallService(std::string_view ServiceName, const ServiceSpec& Spec) +InstallRunService(std::string_view ServiceName, const ServiceSpec& Spec) +{ + HKEY RegKey = NULL; + if (LSTATUS Status = RegOpenKey(Spec.ServiceLevel == ServiceLevel::AllUsers ? HKEY_LOCAL_MACHINE : HKEY_CURRENT_USER, + L"Software\\Microsoft\\Windows\\CurrentVersion\\Run", + &RegKey); + Status == ERROR_SUCCESS) + { + auto _ = MakeGuard([&]() { RegCloseKey(RegKey); }); + std::wstring PathIncludingArgs = Spec.ExecutablePath.wstring(); + if (!Spec.CommandLineOptions.empty()) + { + PathIncludingArgs += (L" " + Utf8ToWide(Spec.CommandLineOptions)); + } + if (Status = RegSetValueEx(RegKey, + Utf8ToWide(ServiceName).c_str(), + 0, + REG_SZ, + (BYTE*)PathIncludingArgs.c_str(), + DWORD((PathIncludingArgs.length() * 2) + 1)); + Status == ERROR_SUCCESS) + { + return {}; + } + } + return MakeErrorCodeFromLastError(); +} + +std::error_code +InstallSystemService(std::string_view ServiceName, const ServiceSpec& Spec) { // Get a handle to the SCM database. @@ -300,7 +413,40 @@ InstallService(std::string_view ServiceName, const ServiceSpec& Spec) } std::error_code -UninstallService(std::string_view ServiceName) +InstallService(std::string_view ServiceName, const ServiceSpec& Spec) +{ + if (Spec.ServiceLevel == ServiceLevel::SystemService) + { + return InstallSystemService(ServiceName, Spec); + } + return InstallRunService(ServiceName, Spec); +} + +std::error_code +UninstallRunService(bool AllUsers, std::string_view ServiceName) +{ + HKEY RegKey = NULL; + if (LSTATUS Status = + RegOpenKey(AllUsers ? HKEY_LOCAL_MACHINE : HKEY_CURRENT_USER, L"Software\\Microsoft\\Windows\\CurrentVersion\\Run", &RegKey); + Status == ERROR_SUCCESS) + { + auto _ = MakeGuard([&]() { RegCloseKey(RegKey); }); + TCHAR Value[4096]; + DWORD Type; + DWORD Size = sizeof(Value); + if (ERROR_SUCCESS == RegQueryValueEx(RegKey, Utf8ToWide(ServiceName).c_str(), 0, &Type, (BYTE*)Value, &Size)) + { + if (Status = RegDeleteValue(RegKey, Utf8ToWide(ServiceName).c_str()); Status == ERROR_SUCCESS) + { + return std::error_code{}; + } + } + } + return MakeErrorCodeFromLastError(); +} + +std::error_code +UninstallSystemService(std::string_view ServiceName) { // Get a handle to the SCM database. SC_HANDLE schSCManager = OpenSCManager(NULL, // local computer @@ -345,7 +491,57 @@ UninstallService(std::string_view ServiceName) } std::error_code -QueryInstalledService(std::string_view ServiceName, ServiceInfo& OutInfo) +UninstallService(std::string_view ServiceName, ServiceLevel Level) +{ + switch (Level) + { + case ServiceLevel::CurrentUser: + return UninstallRunService(/*AllUsers*/ false, ServiceName); + case ServiceLevel::AllUsers: + return UninstallRunService(/*AllUsers*/ true, ServiceName); + case ServiceLevel::SystemService: + return UninstallSystemService(ServiceName); + default: + ZEN_ASSERT(false); + return {}; + } +} + +bool +QueryRunServiceStatus(bool AllUsers, std::string_view ServiceName, ServiceInfo& OutInfo) +{ + HKEY RegKey = NULL; + if (ERROR_SUCCESS == + RegOpenKey(AllUsers ? HKEY_LOCAL_MACHINE : HKEY_CURRENT_USER, L"Software\\Microsoft\\Windows\\CurrentVersion\\Run", &RegKey)) + { + auto _ = MakeGuard([&]() { RegCloseKey(RegKey); }); + TCHAR Value[4096]; + DWORD Type; + DWORD Size = sizeof(Value); + if (ERROR_SUCCESS == RegQueryValueEx(RegKey, Utf8ToWide(ServiceName).c_str(), 0, &Type, (BYTE*)Value, &Size)) + { + OutInfo.Spec.ServiceLevel = AllUsers ? ServiceLevel::AllUsers : ServiceLevel::CurrentUser; + std::wstring PathIncludingArgs(Value); + + (void)SplitExecutableAndArgs(PathIncludingArgs, OutInfo.Spec.ExecutablePath, OutInfo.Spec.CommandLineOptions); + + OutInfo.Spec.DisplayName = ServiceName; + OutInfo.Spec.Description = ""; + OutInfo.Status = ServiceStatus::Stopped; + ProcessHandle Process; + std::error_code Ec = FindProcess(OutInfo.Spec.ExecutablePath, Process); + if (!Ec) + { + OutInfo.Status = ServiceStatus::Running; + } + return true; + } + } + return false; +} + +std::error_code +QuerySystemServiceStatus(std::string_view ServiceName, ServiceInfo& OutInfo) { // Get a handle to the SCM database. SC_HANDLE schSCManager = OpenSCManager(NULL, // local computer @@ -389,36 +585,7 @@ QueryInstalledService(std::string_view ServiceName, ServiceInfo& OutInfo) } std::wstring BinaryWithArguments(ServiceConfig->lpBinaryPathName); - if (BinaryWithArguments.size()) - { - if (BinaryWithArguments[0] == '"') - { - std::wstring::size_type ExecutableEnd = BinaryWithArguments.find('"', 1); - if (ExecutableEnd == std::wstring::npos) - { - OutInfo.Spec.ExecutablePath = BinaryWithArguments; - } - else - { - OutInfo.Spec.ExecutablePath = BinaryWithArguments.substr(0, ExecutableEnd + 1); - OutInfo.Spec.CommandLineOptions = - WideToUtf8(BinaryWithArguments.substr(ExecutableEnd + 1 + BinaryWithArguments[ExecutableEnd + 1] == ' ' ? 1 : 0)); - } - } - else - { - std::wstring::size_type ExecutableEnd = BinaryWithArguments.find(' ', 1); - if (ExecutableEnd == std::wstring::npos) - { - OutInfo.Spec.ExecutablePath = BinaryWithArguments; - } - else - { - OutInfo.Spec.ExecutablePath = BinaryWithArguments.substr(0, ExecutableEnd); - OutInfo.Spec.CommandLineOptions = WideToUtf8(BinaryWithArguments.substr(ExecutableEnd + 1)); - } - } - } + (void)SplitExecutableAndArgs(BinaryWithArguments, OutInfo.Spec.ExecutablePath, OutInfo.Spec.CommandLineOptions); OutInfo.Spec.DisplayName = WideToUtf8(ServiceConfig->lpDisplayName); SERVICE_STATUS ServiceStatus; @@ -480,7 +647,54 @@ QueryInstalledService(std::string_view ServiceName, ServiceInfo& OutInfo) } std::error_code -StartService(std::string_view ServiceName) +QueryInstalledService(std::string_view ServiceName, ServiceInfo& OutInfo) +{ + if (QueryRunServiceStatus(/*AllUsers*/ false, ServiceName, OutInfo)) + { + return {}; + } + if (QueryRunServiceStatus(/*AllUsers*/ true, ServiceName, OutInfo)) + { + return {}; + } + return QuerySystemServiceStatus(ServiceName, OutInfo); +} + +std::error_code +StartRunService(bool AllUsers, std::string_view ServiceName) +{ + HKEY RegKey = NULL; + if (ERROR_SUCCESS == + RegOpenKey(AllUsers ? HKEY_LOCAL_MACHINE : HKEY_CURRENT_USER, L"Software\\Microsoft\\Windows\\CurrentVersion\\Run", &RegKey)) + { + auto _ = MakeGuard([&]() { RegCloseKey(RegKey); }); + TCHAR Value[4096]; + DWORD Type; + DWORD Size = sizeof(Value); + if (ERROR_SUCCESS == RegQueryValueEx(RegKey, Utf8ToWide(ServiceName).c_str(), 0, &Type, (BYTE*)Value, &Size)) + { + std::wstring PathIncludingArgs(Value); + + std::filesystem::path ExecutablePath; + std::string CommandLineOptions; + if (SplitExecutableAndArgs(PathIncludingArgs, ExecutablePath, CommandLineOptions)) + { + ProcessHandle Proc; + Proc.Initialize(CreateProc(ExecutablePath, WideToUtf8(PathIncludingArgs), {.Flags = CreateProcOptions::Flag_NoConsole})); + if (Proc.IsValid()) + { + return {}; + } + MakeErrorCode(ERROR_PATH_NOT_FOUND); + } + MakeErrorCode(ERROR_INVALID_PARAMETER); + } + } + return MakeErrorCodeFromLastError(); +} + +std::error_code +StartSystemService(std::string_view ServiceName) { // Get a handle to the SCM database. SC_HANDLE schSCManager = OpenSCManager(NULL, // local computer @@ -520,7 +734,96 @@ StartService(std::string_view ServiceName) } std::error_code -StopService(std::string_view ServiceName) +StartService(std::string_view ServiceName, ServiceLevel Level) +{ + switch (Level) + { + case ServiceLevel::CurrentUser: + return StartRunService(/*AllUsers*/ false, ServiceName); + case ServiceLevel::AllUsers: + return StartRunService(/*AllUsers*/ true, ServiceName); + case ServiceLevel::SystemService: + return StartSystemService(ServiceName); + default: + ZEN_ASSERT(false); + return {}; + } +} + +std::error_code +StopRunService(bool AllUsers, std::string_view ServiceName) +{ + HKEY RegKey = NULL; + if (ERROR_SUCCESS == + RegOpenKey(AllUsers ? HKEY_LOCAL_MACHINE : HKEY_CURRENT_USER, L"Software\\Microsoft\\Windows\\CurrentVersion\\Run", &RegKey)) + { + auto _ = MakeGuard([&]() { RegCloseKey(RegKey); }); + TCHAR Value[4096]; + DWORD Type; + DWORD Size = sizeof(Value); + if (ERROR_SUCCESS == RegQueryValueEx(RegKey, Utf8ToWide(ServiceName).c_str(), 0, &Type, (BYTE*)Value, &Size)) + { + std::wstring PathIncludingArgs(Value); + + std::filesystem::path ExecutablePath; + std::string CommandLineOptions; + if (SplitExecutableAndArgs(PathIncludingArgs, ExecutablePath, CommandLineOptions)) + { + ProcessHandle Proc; + std::error_code Ec = FindProcess(ExecutablePath, Proc); + if (Ec) + { + return Ec; + } + else + { + // This is hacky and checks if the running service is a zenserver instance and tries to shut down using the shutdown + // event + ExtendableStringBuilder<32> ChildShutdownEventName; + ZenServerState State; + if (State.InitializeReadOnly()) + { + State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) { + if (Entry.Pid == gsl::narrow<uint32_t>(Proc.Pid())) + { + ChildShutdownEventName << "Zen_" << Entry.EffectiveListenPort; + ChildShutdownEventName << "_Shutdown"; + } + }); + if (ChildShutdownEventName.Size() > 0) + { + NamedEvent Event(ChildShutdownEventName); + Ec = Event.Set(); + if (Ec) + { + return Ec; + } + return {}; + } + } + // This only works for a running process that does not already have a console attached - zenserver does have one + // hence the attempt to shut down using event above + if (AttachConsole(Proc.Pid())) + { + if (SetConsoleCtrlHandler(NULL, TRUE)) + { + if (GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0)) + { + return {}; + } + } + } + return MakeErrorCodeFromLastError(); + } + } + return MakeErrorCode(ERROR_INVALID_PARAMETER); + } + } + return MakeErrorCodeFromLastError(); +} + +std::error_code +StopSystemService(std::string_view ServiceName) { // Get a handle to the SCM database. SC_HANDLE schSCManager = OpenSCManager(NULL, // local computer @@ -558,6 +861,24 @@ StopService(std::string_view ServiceName) return {}; } + +std::error_code +StopService(std::string_view ServiceName, ServiceLevel Level) +{ + switch (Level) + { + case ServiceLevel::CurrentUser: + return StopRunService(/*AllUsers*/ false, ServiceName); + case ServiceLevel::AllUsers: + return StopRunService(/*AllUsers*/ true, ServiceName); + case ServiceLevel::SystemService: + return StopSystemService(ServiceName); + default: + ZEN_ASSERT(false); + return {}; + } +} + #else # if 0 |