diff options
Diffstat (limited to 'src/zenhttp/clients')
| -rw-r--r-- | src/zenhttp/clients/asynchttpclient.cpp | 225 |
1 files changed, 187 insertions, 38 deletions
diff --git a/src/zenhttp/clients/asynchttpclient.cpp b/src/zenhttp/clients/asynchttpclient.cpp index fea89d47d..bdf7f160c 100644 --- a/src/zenhttp/clients/asynchttpclient.cpp +++ b/src/zenhttp/clients/asynchttpclient.cpp @@ -68,7 +68,7 @@ struct AsyncHttpClient::Impl , m_OwnedIoContext(std::make_unique<asio::io_context>()) , m_IoContext(*m_OwnedIoContext) , m_Strand(asio::make_strand(m_IoContext)) - , m_PollTimer(m_Strand) + , m_Timer(m_Strand) { Init(); m_WorkGuard.emplace(m_IoContext.get_executor()); @@ -91,7 +91,7 @@ struct AsyncHttpClient::Impl , m_Log(logging::Get(Settings.LogCategory)) , m_IoContext(IoContext) , m_Strand(asio::make_strand(m_IoContext)) - , m_PollTimer(m_Strand) + , m_Timer(m_Strand) { Init(); } @@ -107,7 +107,18 @@ struct AsyncHttpClient::Impl asio::post(m_Strand, [this, &Done]() { m_ShuttingDown = true; - m_PollTimer.cancel(); + m_Timer.cancel(); + + // Release all tracked sockets (don't close — curl owns the fds). + for (auto& [Fd, Info] : m_Sockets) + { + if (Info->Socket.is_open()) + { + Info->Socket.cancel(); + Info->Socket.release(); + } + } + m_Sockets.clear(); for (auto& [Handle, Ctx] : m_Transfers) { @@ -155,6 +166,8 @@ struct AsyncHttpClient::Impl throw std::runtime_error("curl_multi_init failed"); } + SetupMultiCallbacks(); + if (m_Settings.SessionId == Oid::Zero) { m_SessionId = std::string(GetSessionIdString()); @@ -308,45 +321,187 @@ struct AsyncHttpClient::Impl [Cb = std::move(Stolen->Callback), Response = std::move(ErrorResponse)]() mutable { Cb(std::move(Response)); }); return; } + } + + // ── Socket-action integration ─────────────────────────────────────── + // + // curl_multi drives I/O via two callbacks: + // - SocketCallback: curl tells us which sockets to watch for read/write + // - TimerCallback: curl tells us when to fire a timeout + // + // On each socket event or timeout we call curl_multi_socket_action(), + // then drain completed transfers via curl_multi_info_read(). + + // Per-socket state: wraps the native fd in an ASIO socket for async_wait. + struct SocketInfo + { + asio::ip::tcp::socket Socket; + int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT + + explicit SocketInfo(asio::io_context& IoContext) : Socket(IoContext) {} + }; - SchedulePoll(); + // Static thunks registered with curl_multi ──────────────────────────── + + static int CurlSocketCallback(CURL* Easy, curl_socket_t Fd, int Action, void* UserPtr, void* SocketPtr) + { + ZEN_UNUSED(Easy); + auto* Self = static_cast<Impl*>(UserPtr); + Self->OnCurlSocket(Fd, Action, static_cast<SocketInfo*>(SocketPtr)); + return 0; + } + + static int CurlTimerCallback(CURLM* Multi, long TimeoutMs, void* UserPtr) + { + ZEN_UNUSED(Multi); + auto* Self = static_cast<Impl*>(UserPtr); + Self->OnCurlTimer(TimeoutMs); + return 0; + } + + void SetupMultiCallbacks() + { + curl_multi_setopt(m_Multi, CURLMOPT_SOCKETFUNCTION, CurlSocketCallback); + curl_multi_setopt(m_Multi, CURLMOPT_SOCKETDATA, this); + curl_multi_setopt(m_Multi, CURLMOPT_TIMERFUNCTION, CurlTimerCallback); + curl_multi_setopt(m_Multi, CURLMOPT_TIMERDATA, this); } - // ── Poll timer ────────────────────────────────────────────────────── + // Called by curl when socket watch state changes ────────────────────── - void SchedulePoll() + void OnCurlSocket(curl_socket_t Fd, int Action, SocketInfo* Info) { - if (m_PollScheduled || m_ShuttingDown) + if (Action == CURL_POLL_REMOVE) { + if (Info) + { + // Cancel pending async_wait ops before releasing the fd. + // curl owns the fd, so we must release() rather than close(). + Info->Socket.cancel(); + if (Info->Socket.is_open()) + { + Info->Socket.release(); + } + m_Sockets.erase(Fd); + } return; } - m_PollScheduled = true; - // Poll at a fixed interval. curl_multi_timeout() returns the time until - // the next internal timeout event (e.g. connect timeout), which can be - // much longer than ideal for detecting socket readiness. A fixed short - // interval ensures we detect connection completion and data arrival - // promptly. - static constexpr long PollIntervalMs = 10; - - m_PollTimer.expires_after(std::chrono::milliseconds(PollIntervalMs)); - m_PollTimer.async_wait([this](const asio::error_code& Ec) { - m_PollScheduled = false; - if (Ec || m_ShuttingDown) + if (!Info) + { + // New socket — wrap the native fd in an ASIO socket. + auto [It, Inserted] = m_Sockets.emplace(Fd, std::make_unique<SocketInfo>(m_IoContext)); + Info = It->second.get(); + + asio::error_code Ec; + // Determine protocol from the fd (v4 vs v6). Default to v4. + Info->Socket.assign(asio::ip::tcp::v4(), Fd, Ec); + if (Ec) + { + // Try v6 as fallback + Info->Socket.assign(asio::ip::tcp::v6(), Fd, Ec); + } + if (Ec) { + ZEN_WARN("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message()); + m_Sockets.erase(Fd); return; } - OnPoll(); - }); + + curl_multi_assign(m_Multi, Fd, Info); + } + + Info->WatchFlags = Action; + SetSocketWatch(Fd, Info); + } + + void SetSocketWatch(curl_socket_t Fd, SocketInfo* Info) + { + // Cancel any pending wait before issuing a new one. + Info->Socket.cancel(); + + if (Info->WatchFlags & CURL_POLL_IN) + { + Info->Socket.async_wait(asio::socket_base::wait_read, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + OnSocketReady(Fd, CURL_CSELECT_IN); + })); + } + + if (Info->WatchFlags & CURL_POLL_OUT) + { + Info->Socket.async_wait(asio::socket_base::wait_write, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + OnSocketReady(Fd, CURL_CSELECT_OUT); + })); + } } - void OnPoll() + void OnSocketReady(curl_socket_t Fd, int CurlAction) { - ZEN_TRACE_CPU("AsyncHttpClient::OnPoll"); + ZEN_TRACE_CPU("AsyncHttpClient::OnSocketReady"); int StillRunning = 0; - curl_multi_perform(m_Multi, &StillRunning); + curl_multi_socket_action(m_Multi, Fd, CurlAction, &StillRunning); + CheckCompleted(); + + // Re-arm the watch if the socket is still tracked. + auto It = m_Sockets.find(Fd); + if (It != m_Sockets.end()) + { + SetSocketWatch(Fd, It->second.get()); + } + } + + // Called by curl when it wants a timeout ────────────────────────────── + + void OnCurlTimer(long TimeoutMs) + { + m_Timer.cancel(); + + if (TimeoutMs < 0) + { + // curl says "no timeout needed" + return; + } - // Collect completed transfers + if (TimeoutMs == 0) + { + // curl wants immediate action — run it directly on the strand. + asio::post(m_Strand, [this]() { + if (m_ShuttingDown) + { + return; + } + int StillRunning = 0; + curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); + CheckCompleted(); + }); + return; + } + + m_Timer.expires_after(std::chrono::milliseconds(TimeoutMs)); + m_Timer.async_wait(asio::bind_executor(m_Strand, [this](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + ZEN_TRACE_CPU("AsyncHttpClient::OnTimeout"); + int StillRunning = 0; + curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); + CheckCompleted(); + })); + } + + // Drain completed transfers from curl_multi ────────────────────────── + + void CheckCompleted() + { int MsgsLeft = 0; CURLMsg* Msg = nullptr; while ((Msg = curl_multi_info_read(m_Multi, &MsgsLeft)) != nullptr) @@ -364,7 +519,6 @@ struct AsyncHttpClient::Impl auto It = m_Transfers.find(Handle); if (It == m_Transfers.end()) { - // Should not happen, but be safe ReleaseHandle(Handle); continue; } @@ -374,11 +528,6 @@ struct AsyncHttpClient::Impl CompleteTransfer(Handle, Result, std::move(Ctx)); } - - if (StillRunning > 0) - { - SchedulePoll(); - } } void CompleteTransfer(CURL* Handle, CURLcode CurlResult, std::unique_ptr<TransferContext> Ctx) @@ -683,13 +832,13 @@ struct AsyncHttpClient::Impl std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard; std::thread m_IoThread; - // curl_multi - CURLM* m_Multi = nullptr; - std::unordered_map<CURL*, std::unique_ptr<TransferContext>> m_Transfers; - std::vector<CURL*> m_HandlePool; - asio::steady_timer m_PollTimer; - bool m_PollScheduled = false; - bool m_ShuttingDown = false; + // curl_multi and socket-action state + CURLM* m_Multi = nullptr; + std::unordered_map<CURL*, std::unique_ptr<TransferContext>> m_Transfers; + std::vector<CURL*> m_HandlePool; + std::unordered_map<curl_socket_t, std::unique_ptr<SocketInfo>> m_Sockets; + asio::steady_timer m_Timer; + bool m_ShuttingDown = false; // Access token cache RwLock m_AccessTokenLock; |