// Copyright Epic Games, Inc. All Rights Reserved. #include "dlltransport.h" #include #include #include #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_WITH_PLUGINS namespace zen { ////////////////////////////////////////////////////////////////////////// class DllTransportLogger : public TransportLogger, public RefCounted { public: DllTransportLogger(std::string_view PluginName); virtual ~DllTransportLogger() = default; void LogMessage(LogLevel Level, const char* Message) override; private: std::string m_PluginName; }; struct LoadedDll { std::string Name; std::filesystem::path LoadedFromPath; DllTransportLogger* Logger = nullptr; Ref Plugin; }; class DllTransportPluginImpl : public DllTransportPlugin, RefCounted { public: DllTransportPluginImpl() = default; ~DllTransportPluginImpl() = default; virtual uint32_t AddRef() const override; virtual uint32_t Release() const override; virtual void Configure(const char* OptionTag, const char* OptionValue) override; virtual void Initialize(TransportServer* ServerInterface) override; virtual void Shutdown() override; virtual const char* GetDebugName() override; virtual bool IsAvailable() override; virtual bool LoadDll(std::string_view Name) override; virtual void ConfigureDll(std::string_view Name, const char* OptionTag, const char* OptionValue) override; private: TransportServer* m_ServerInterface = nullptr; RwLock m_Lock; std::vector m_Transports; }; DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginName(PluginName) { } void DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message) { logging::level::LogLevel Level; // clang-format off switch (PluginLogLevel) { case LogLevel::Trace: Level = logging::level::Trace; break; case LogLevel::Debug: Level = logging::level::Debug; break; case LogLevel::Info: Level = logging::level::Info; break; case LogLevel::Warn: Level = logging::level::Warn; break; case LogLevel::Err: Level = logging::level::Err; break; case LogLevel::Critical: Level = logging::level::Critical; break; default: Level = logging::level::Off; break; } // clang-format on ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message) } uint32_t DllTransportPluginImpl::AddRef() const { return RefCounted::AddRef(); } uint32_t DllTransportPluginImpl::Release() const { return RefCounted::Release(); } void DllTransportPluginImpl::Configure(const char* OptionTag, const char* OptionValue) { // No configuration options ZEN_UNUSED(OptionTag, OptionValue); } void DllTransportPluginImpl::Initialize(TransportServer* ServerIface) { m_ServerInterface = ServerIface; RwLock::ExclusiveLockScope _(m_Lock); for (LoadedDll& Transport : m_Transports) { try { Transport.Plugin->Initialize(ServerIface); } catch (const std::exception&) { // TODO: report } } } void DllTransportPluginImpl::Shutdown() { RwLock::ExclusiveLockScope _(m_Lock); for (LoadedDll& Transport : m_Transports) { try { Transport.Plugin->Shutdown(); Transport.Logger->Release(); } catch (const std::exception&) { // TODO: report } } } const char* DllTransportPluginImpl::GetDebugName() { return nullptr; } bool DllTransportPluginImpl::IsAvailable() { return true; } void DllTransportPluginImpl::ConfigureDll(std::string_view Name, const char* OptionTag, const char* OptionValue) { RwLock::ExclusiveLockScope _(m_Lock); for (auto& Transport : m_Transports) { if (Transport.Name == Name) { Transport.Plugin->Configure(OptionTag, OptionValue); } } } bool DllTransportPluginImpl::LoadDll(std::string_view Name) { RwLock::ExclusiveLockScope _(m_Lock); WideStringBuilder<1024> DllPath; DllPath << Name; if (!Name.ends_with(".dll")) { DllPath << ".dll"; } std::string FileName = std::filesystem::path(DllPath.c_str()).filename().replace_extension().string(); std::string Path = std::filesystem::path(DllPath.c_str()).parent_path().string(); SetDllDirectoryW(Utf8ToWide(Path).c_str()); HMODULE DllHandle = LoadLibraryW(DllPath.c_str()); if (!DllHandle) { ZEN_WARN("Failed to load transport DLL from '{}' due to '{}'", WideToUtf8(DllPath), GetLastErrorAsString()) return false; } PfnGetTransportPluginVersion GetVersion = (PfnGetTransportPluginVersion)GetProcAddress(DllHandle, "GetTransportPluginVersion"); PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin"); uint32_t APIVersion = 0; uint32_t PluginVersion = 0; if (GetVersion) { GetVersion(&APIVersion, &PluginVersion); } const bool bValidApiVersion = APIVersion == kTransportApiVersion; if (!GetVersion || !CreatePlugin || !bValidApiVersion) { std::error_code Ec = MakeErrorCodeFromLastError(); FreeLibrary(DllHandle); if (GetVersion && !bValidApiVersion) { ZEN_WARN("Failed to load transport DLL from '{}' due to invalid API version {}, supported API version is {}", WideToUtf8(DllPath), APIVersion, kTransportApiVersion) } else { ZEN_WARN("Failed to load transport DLL from '{}' due to not finding GetTransportPluginVersion or CreateTransportPlugin", WideToUtf8(DllPath)) } return false; } LoadedDll NewDll; NewDll.Name = Name; NewDll.LoadedFromPath = DllPath.c_str(); NewDll.Logger = new DllTransportLogger(FileName); NewDll.Logger->AddRef(); NewDll.Plugin = CreatePlugin(NewDll.Logger); m_Transports.emplace_back(std::move(NewDll)); return true; } DllTransportPlugin* CreateDllTransportPlugin() { return new DllTransportPluginImpl; } } // namespace zen #endif