aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2021-08-11 02:16:22 -0400
committerRapptz <[email protected]>2021-08-11 02:16:22 -0400
commit08a4db396118aeda6205ff56c8c8fc565fc338fc (patch)
tree828a83ef13e5c5caae894cdd16498a9c742fc6e2
parentRefactor Client.run to use asyncio.run (diff)
downloaddiscord.py-08a4db396118aeda6205ff56c8c8fc565fc338fc.tar.xz
discord.py-08a4db396118aeda6205ff56c8c8fc565fc338fc.zip
Revert "Refactor Client.run to use asyncio.run"
This reverts commit 6e6c8a7b2810747222a938c7fe3e466c2994b23f.
-rw-r--r--discord/client.py200
-rw-r--r--discord/http.py3
2 files changed, 103 insertions, 100 deletions
diff --git a/discord/client.py b/discord/client.py
index 4279a36b..cdfc6bdb 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -26,24 +26,10 @@ from __future__ import annotations
import asyncio
import logging
+import signal
import sys
import traceback
-from typing import (
- Any,
- Callable,
- Coroutine,
- Dict,
- Generator,
- Iterable,
- List,
- Optional,
- Sequence,
- TYPE_CHECKING,
- Tuple,
- TypeVar,
- Type,
- Union,
-)
+from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
import aiohttp
@@ -82,7 +68,6 @@ if TYPE_CHECKING:
from .message import Message
from .member import Member
from .voice_client import VoiceProtocol
- from types import TracebackType
__all__ = (
'Client',
@@ -93,8 +78,36 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
log: logging.Logger = logging.getLogger(__name__)
-C = TypeVar('C', bound='Client')
-
+def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
+ tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
+
+ if not tasks:
+ return
+
+ log.info('Cleaning up after %d tasks.', len(tasks))
+ for task in tasks:
+ task.cancel()
+
+ loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
+ log.info('All tasks finished cancelling.')
+
+ for task in tasks:
+ if task.cancelled():
+ continue
+ if task.exception() is not None:
+ loop.call_exception_handler({
+ 'message': 'Unhandled exception during Client.run shutdown.',
+ 'exception': task.exception(),
+ 'task': task
+ })
+
+def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
+ try:
+ _cancel_tasks(loop)
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ finally:
+ log.info('Closing the event loop.')
+ loop.close()
class Client:
r"""Represents a client connection that connects to Discord.
@@ -187,7 +200,6 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
-
def __init__(
self,
*,
@@ -195,8 +207,7 @@ class Client:
**options: Any,
):
self.ws: DiscordWebSocket = None # type: ignore
- # this is filled in later
- self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop
+ self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count')
@@ -205,16 +216,14 @@ class Client:
proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True)
- self.http: HTTPClient = HTTPClient(
- connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop
- )
+ self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
self._handlers: Dict[str, Callable] = {
- 'ready': self._handle_ready,
+ 'ready': self._handle_ready
}
self._hooks: Dict[str, Callable] = {
- 'before_identify': self._call_before_identify_hook,
+ 'before_identify': self._call_before_identify_hook
}
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
@@ -235,9 +244,8 @@ class Client:
return self.ws
def _get_state(self, **options: Any) -> ConnectionState:
- return ConnectionState(
- dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options
- )
+ return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
+ hooks=self._hooks, http=self.http, loop=self.loop, **options)
def _handle_ready(self) -> None:
self._ready.set()
@@ -335,9 +343,7 @@ class Client:
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set()
- async def _run_event(
- self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
- ) -> None:
+ async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
try:
await coro(*args, **kwargs)
except asyncio.CancelledError:
@@ -348,9 +354,7 @@ class Client:
except asyncio.CancelledError:
pass
- def _schedule_event(
- self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
- ) -> asyncio.Task:
+ def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
@@ -462,8 +466,7 @@ class Client:
"""
log.info('logging in using static token')
- self.loop = loop = asyncio.get_running_loop()
- self._connection.loop = loop
+
data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data)
@@ -509,14 +512,12 @@ class Client:
self.dispatch('disconnect')
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
continue
- except (
- OSError,
- HTTPException,
- GatewayNotFound,
- ConnectionClosed,
- aiohttp.ClientError,
- asyncio.TimeoutError,
- ) as exc:
+ except (OSError,
+ HTTPException,
+ GatewayNotFound,
+ ConnectionClosed,
+ aiohttp.ClientError,
+ asyncio.TimeoutError) as exc:
self.dispatch('disconnect')
if not reconnect:
@@ -557,22 +558,6 @@ class Client:
"""|coro|
Closes the connection to Discord.
-
- Instead of calling this directly, it is recommended to use the asynchronous context
- manager to allow resources to be cleaned up automatically:
-
- .. code-block:: python3
-
- async def main():
- async with Client() as client:
- await client.login(token)
- await client.connect()
-
- asyncio.run(main())
-
-
- .. versionchanged:: 2.0
- The client can now be closed with an asynchronous context manager
"""
if self._closed:
return
@@ -604,47 +589,36 @@ class Client:
self._connection.clear()
self.http.recreate()
- async def __aenter__(self: C) -> C:
- return self
-
- async def __aexit__(
- self,
- exc_type: Optional[Type[BaseException]],
- exc_value: Optional[BaseException],
- traceback: Optional[TracebackType],
- ) -> None:
- await self.close()
-
async def start(self, token: str, *, reconnect: bool = True) -> None:
"""|coro|
- A shorthand function equivalent to the following:
+ A shorthand coroutine for :meth:`login` + :meth:`connect`.
- .. code-block:: python3
-
- async with client:
- await client.login(token)
- await client.connect()
-
- This closes the client when it returns.
+ Raises
+ -------
+ TypeError
+ An unexpected keyword argument was received.
"""
- try:
- await self.login(token)
- await self.connect(reconnect=reconnect)
- finally:
- await self.close()
+ await self.login(token)
+ await self.connect(reconnect=reconnect)
def run(self, *args: Any, **kwargs: Any) -> None:
- """A convenience blocking call that abstracts away the event loop
+ """A blocking call that abstracts away the event loop
initialisation from you.
If you want more control over the event loop then this
function should not be used. Use :meth:`start` coroutine
or :meth:`connect` + :meth:`login`.
- Equivalent to: ::
+ Roughly Equivalent to: ::
- asyncio.run(bot.start(*args, **kwargs))
+ try:
+ loop.run_until_complete(start(*args, **kwargs))
+ except KeyboardInterrupt:
+ loop.run_until_complete(close())
+ # cancel all tasks lingering
+ finally:
+ loop.close()
.. warning::
@@ -652,7 +626,41 @@ class Client:
is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns.
"""
- asyncio.run(self.start(*args, **kwargs))
+ loop = self.loop
+
+ try:
+ loop.add_signal_handler(signal.SIGINT, lambda: loop.stop())
+ loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
+ except NotImplementedError:
+ pass
+
+ async def runner():
+ try:
+ await self.start(*args, **kwargs)
+ finally:
+ if not self.is_closed():
+ await self.close()
+
+ def stop_loop_on_completion(f):
+ loop.stop()
+
+ future = asyncio.ensure_future(runner(), loop=loop)
+ future.add_done_callback(stop_loop_on_completion)
+ try:
+ loop.run_forever()
+ except KeyboardInterrupt:
+ log.info('Received signal to terminate bot and event loop.')
+ finally:
+ future.remove_done_callback(stop_loop_on_completion)
+ log.info('Cleaning up tasks.')
+ _cleanup_loop(loop)
+
+ if not future.cancelled():
+ try:
+ return future.result()
+ except KeyboardInterrupt:
+ # I am unsure why this gets raised here but suppress it anyway
+ return None
# properties
@@ -965,10 +973,8 @@ class Client:
future = self.loop.create_future()
if check is None:
-
def _check(*args):
return True
-
check = _check
ev = event.lower()
@@ -1077,7 +1083,7 @@ class Client:
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
- after: SnowflakeTime = None,
+ after: SnowflakeTime = None
) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
@@ -1157,7 +1163,7 @@ class Client:
"""
code = utils.resolve_template(code)
data = await self.http.get_template(code)
- return Template(data=data, state=self._connection) # type: ignore
+ return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id: int) -> Guild:
"""|coro|
@@ -1278,9 +1284,7 @@ class Client:
# Invite management
- async def fetch_invite(
- self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
- ) -> Invite:
+ async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
"""|coro|
Gets an :class:`.Invite` from a discord.gg URL or ID.
@@ -1516,7 +1520,7 @@ class Client:
"""
data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore
- return cls(state=self._connection, data=data) # type: ignore
+ return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro|
diff --git a/discord/http.py b/discord/http.py
index 5a31928b..b186782f 100644
--- a/discord/http.py
+++ b/discord/http.py
@@ -167,7 +167,7 @@ class HTTPClient:
loop: Optional[asyncio.AbstractEventLoop] = None,
unsync_clock: bool = True
) -> None:
- self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop # filled in static_login
+ self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self.connector = connector
self.__session: aiohttp.ClientSession = MISSING # filled in static_login
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
@@ -371,7 +371,6 @@ class HTTPClient:
async def static_login(self, token: str) -> user.User:
# Necessary to get aiohttp to stop complaining about session creation
- self.loop = asyncio.get_running_loop()
self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse)
old_token = self.token
self.token = token