aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/transports/dlltransport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/transports/dlltransport.cpp')
-rw-r--r--src/zenhttp/transports/dlltransport.cpp81
1 files changed, 70 insertions, 11 deletions
diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp
index e09e62ec5..dd05f6f0c 100644
--- a/src/zenhttp/transports/dlltransport.cpp
+++ b/src/zenhttp/transports/dlltransport.cpp
@@ -21,18 +21,31 @@ namespace zen {
//////////////////////////////////////////////////////////////////////////
+class DllTransportLogger : public TransportLogger, public RefCounted
+{
+public:
+ DllTransportLogger(std::string_view PluginName);
+ virtual ~DllTransportLogger() = default;
+
+ void LogMessage(LogLevel Level, const char* Message) override;
+
+private:
+ std::string m_PluginName;
+};
+
struct LoadedDll
{
std::string Name;
std::filesystem::path LoadedFromPath;
+ DllTransportLogger* Logger = nullptr;
Ref<TransportPlugin> Plugin;
};
class DllTransportPluginImpl : public DllTransportPlugin, RefCounted
{
public:
- DllTransportPluginImpl();
- ~DllTransportPluginImpl();
+ DllTransportPluginImpl() = default;
+ ~DllTransportPluginImpl() = default;
virtual uint32_t AddRef() const override;
virtual uint32_t Release() const override;
@@ -51,12 +64,27 @@ private:
std::vector<LoadedDll> m_Transports;
};
-DllTransportPluginImpl::DllTransportPluginImpl()
+DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginName(PluginName)
{
}
-DllTransportPluginImpl::~DllTransportPluginImpl()
+void
+DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message)
{
+ logging::level::LogLevel Level;
+ // clang-format off
+ switch (PluginLogLevel)
+ {
+ case LogLevel::Trace: Level = logging::level::Trace; break;
+ case LogLevel::Debug: Level = logging::level::Debug; break;
+ case LogLevel::Info: Level = logging::level::Info; break;
+ case LogLevel::Warn: Level = logging::level::Warn; break;
+ case LogLevel::Err: Level = logging::level::Err; break;
+ case LogLevel::Critical: Level = logging::level::Critical; break;
+ default: Level = logging::level::Off; break;
+ }
+ // clang-format on
+ ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message)
}
uint32_t
@@ -109,6 +137,7 @@ DllTransportPluginImpl::Shutdown()
try
{
Transport.Plugin->Shutdown();
+ Transport.Logger->Release();
}
catch (const std::exception&)
{
@@ -148,8 +177,15 @@ DllTransportPluginImpl::LoadDll(std::string_view Name)
{
RwLock::ExclusiveLockScope _(m_Lock);
- ExtendableStringBuilder<128> DllPath;
- DllPath << Name << ".dll";
+ ExtendableStringBuilder<1024> DllPath;
+ DllPath << Name;
+ if (!Name.ends_with(".dll"))
+ {
+ DllPath << ".dll";
+ }
+
+ std::string FileName = std::filesystem::path(DllPath.c_str()).filename().replace_extension().string();
+
HMODULE DllHandle = LoadLibraryA(DllPath.c_str());
if (!DllHandle)
@@ -159,24 +195,47 @@ DllTransportPluginImpl::LoadDll(std::string_view Name)
throw std::system_error(Ec, fmt::format("failed to load transport DLL from '{}'", DllPath));
}
- TransportPlugin* CreateTransportPlugin();
+ PfnGetTransportPluginVersion GetVersion = (PfnGetTransportPluginVersion)GetProcAddress(DllHandle, "GetTransportPluginVersion");
+ PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin");
- PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin");
+ uint32_t APIVersion = 0;
+ uint32_t PluginVersion = 0;
- if (!CreatePlugin)
+ if (GetVersion)
+ {
+ GetVersion(&APIVersion, &PluginVersion);
+ }
+
+ const bool bValidApiVersion = APIVersion == kTransportApiVersion;
+
+ if (!GetVersion || !CreatePlugin || !bValidApiVersion)
{
std::error_code Ec = MakeErrorCodeFromLastError();
FreeLibrary(DllHandle);
- throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath));
+ if (GetVersion && !bValidApiVersion)
+ {
+ throw std::system_error(
+ Ec,
+ fmt::format("invalid API version '{}' detected in transport DLL loaded from '{}', supported API version '{}'",
+ APIVersion,
+ DllPath,
+ kTransportApiVersion));
+ }
+ else
+ {
+ throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath));
+ }
}
LoadedDll NewDll;
NewDll.Name = Name;
NewDll.LoadedFromPath = DllPath.c_str();
- NewDll.Plugin = CreatePlugin();
+ NewDll.Logger = new DllTransportLogger(FileName);
+ NewDll.Logger->AddRef();
+ NewDll.Plugin = CreatePlugin(NewDll.Logger);
m_Transports.emplace_back(std::move(NewDll));
}