aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2021-08-10 22:58:02 -0400
committerRapptz <[email protected]>2021-08-10 23:00:24 -0400
commit6e6c8a7b2810747222a938c7fe3e466c2994b23f (patch)
tree4890629c984158d6b206b5b7c40f967f8d7f7fbf
parentClarify StageInstance.discoverable_disabled documentation (diff)
downloaddiscord.py-6e6c8a7b2810747222a938c7fe3e466c2994b23f.tar.xz
discord.py-6e6c8a7b2810747222a938c7fe3e466c2994b23f.zip
Refactor Client.run to use asyncio.run
This also adds asynchronous context manager support to allow for idiomatic asyncio usage for the lower-level counterpart. At first I wanted to remove Client.run but I figured that a lot of beginners would have been confused or not enjoyed the verbosity of the newer approach of using async-with.
-rw-r--r--discord/client.py200
-rw-r--r--discord/http.py3
2 files changed, 100 insertions, 103 deletions
diff --git a/discord/client.py b/discord/client.py
index cdfc6bdb..4279a36b 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -26,10 +26,24 @@ from __future__ import annotations
import asyncio
import logging
-import signal
import sys
import traceback
-from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
+from typing import (
+ Any,
+ Callable,
+ Coroutine,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ TYPE_CHECKING,
+ Tuple,
+ TypeVar,
+ Type,
+ Union,
+)
import aiohttp
@@ -68,6 +82,7 @@ if TYPE_CHECKING:
from .message import Message
from .member import Member
from .voice_client import VoiceProtocol
+ from types import TracebackType
__all__ = (
'Client',
@@ -78,36 +93,8 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
log: logging.Logger = logging.getLogger(__name__)
-def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
- tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
-
- if not tasks:
- return
-
- log.info('Cleaning up after %d tasks.', len(tasks))
- for task in tasks:
- task.cancel()
-
- loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
- log.info('All tasks finished cancelling.')
-
- for task in tasks:
- if task.cancelled():
- continue
- if task.exception() is not None:
- loop.call_exception_handler({
- 'message': 'Unhandled exception during Client.run shutdown.',
- 'exception': task.exception(),
- 'task': task
- })
-
-def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
- try:
- _cancel_tasks(loop)
- loop.run_until_complete(loop.shutdown_asyncgens())
- finally:
- log.info('Closing the event loop.')
- loop.close()
+C = TypeVar('C', bound='Client')
+
class Client:
r"""Represents a client connection that connects to Discord.
@@ -200,6 +187,7 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
+
def __init__(
self,
*,
@@ -207,7 +195,8 @@ class Client:
**options: Any,
):
self.ws: DiscordWebSocket = None # type: ignore
- self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
+ # this is filled in later
+ self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count')
@@ -216,14 +205,16 @@ class Client:
proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True)
- self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
+ self.http: HTTPClient = HTTPClient(
+ connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop
+ )
self._handlers: Dict[str, Callable] = {
- 'ready': self._handle_ready
+ 'ready': self._handle_ready,
}
self._hooks: Dict[str, Callable] = {
- 'before_identify': self._call_before_identify_hook
+ 'before_identify': self._call_before_identify_hook,
}
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
@@ -244,8 +235,9 @@ class Client:
return self.ws
def _get_state(self, **options: Any) -> ConnectionState:
- return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
- hooks=self._hooks, http=self.http, loop=self.loop, **options)
+ return ConnectionState(
+ dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options
+ )
def _handle_ready(self) -> None:
self._ready.set()
@@ -343,7 +335,9 @@ class Client:
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set()
- async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
+ async def _run_event(
+ self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
+ ) -> None:
try:
await coro(*args, **kwargs)
except asyncio.CancelledError:
@@ -354,7 +348,9 @@ class Client:
except asyncio.CancelledError:
pass
- def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
+ def _schedule_event(
+ self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
+ ) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
@@ -466,7 +462,8 @@ class Client:
"""
log.info('logging in using static token')
-
+ self.loop = loop = asyncio.get_running_loop()
+ self._connection.loop = loop
data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data)
@@ -512,12 +509,14 @@ class Client:
self.dispatch('disconnect')
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
continue
- except (OSError,
- HTTPException,
- GatewayNotFound,
- ConnectionClosed,
- aiohttp.ClientError,
- asyncio.TimeoutError) as exc:
+ except (
+ OSError,
+ HTTPException,
+ GatewayNotFound,
+ ConnectionClosed,
+ aiohttp.ClientError,
+ asyncio.TimeoutError,
+ ) as exc:
self.dispatch('disconnect')
if not reconnect:
@@ -558,6 +557,22 @@ class Client:
"""|coro|
Closes the connection to Discord.
+
+ Instead of calling this directly, it is recommended to use the asynchronous context
+ manager to allow resources to be cleaned up automatically:
+
+ .. code-block:: python3
+
+ async def main():
+ async with Client() as client:
+ await client.login(token)
+ await client.connect()
+
+ asyncio.run(main())
+
+
+ .. versionchanged:: 2.0
+ The client can now be closed with an asynchronous context manager
"""
if self._closed:
return
@@ -589,36 +604,47 @@ class Client:
self._connection.clear()
self.http.recreate()
+ async def __aenter__(self: C) -> C:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
+ await self.close()
+
async def start(self, token: str, *, reconnect: bool = True) -> None:
"""|coro|
- A shorthand coroutine for :meth:`login` + :meth:`connect`.
+ A shorthand function equivalent to the following:
- Raises
- -------
- TypeError
- An unexpected keyword argument was received.
+ .. code-block:: python3
+
+ async with client:
+ await client.login(token)
+ await client.connect()
+
+ This closes the client when it returns.
"""
- await self.login(token)
- await self.connect(reconnect=reconnect)
+ try:
+ await self.login(token)
+ await self.connect(reconnect=reconnect)
+ finally:
+ await self.close()
def run(self, *args: Any, **kwargs: Any) -> None:
- """A blocking call that abstracts away the event loop
+ """A convenience blocking call that abstracts away the event loop
initialisation from you.
If you want more control over the event loop then this
function should not be used. Use :meth:`start` coroutine
or :meth:`connect` + :meth:`login`.
- Roughly Equivalent to: ::
+ Equivalent to: ::
- try:
- loop.run_until_complete(start(*args, **kwargs))
- except KeyboardInterrupt:
- loop.run_until_complete(close())
- # cancel all tasks lingering
- finally:
- loop.close()
+ asyncio.run(bot.start(*args, **kwargs))
.. warning::
@@ -626,41 +652,7 @@ class Client:
is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns.
"""
- loop = self.loop
-
- try:
- loop.add_signal_handler(signal.SIGINT, lambda: loop.stop())
- loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
- except NotImplementedError:
- pass
-
- async def runner():
- try:
- await self.start(*args, **kwargs)
- finally:
- if not self.is_closed():
- await self.close()
-
- def stop_loop_on_completion(f):
- loop.stop()
-
- future = asyncio.ensure_future(runner(), loop=loop)
- future.add_done_callback(stop_loop_on_completion)
- try:
- loop.run_forever()
- except KeyboardInterrupt:
- log.info('Received signal to terminate bot and event loop.')
- finally:
- future.remove_done_callback(stop_loop_on_completion)
- log.info('Cleaning up tasks.')
- _cleanup_loop(loop)
-
- if not future.cancelled():
- try:
- return future.result()
- except KeyboardInterrupt:
- # I am unsure why this gets raised here but suppress it anyway
- return None
+ asyncio.run(self.start(*args, **kwargs))
# properties
@@ -973,8 +965,10 @@ class Client:
future = self.loop.create_future()
if check is None:
+
def _check(*args):
return True
+
check = _check
ev = event.lower()
@@ -1083,7 +1077,7 @@ class Client:
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
- after: SnowflakeTime = None
+ after: SnowflakeTime = None,
) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
@@ -1163,7 +1157,7 @@ class Client:
"""
code = utils.resolve_template(code)
data = await self.http.get_template(code)
- return Template(data=data, state=self._connection) # type: ignore
+ return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id: int) -> Guild:
"""|coro|
@@ -1284,7 +1278,9 @@ class Client:
# Invite management
- async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
+ async def fetch_invite(
+ self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
+ ) -> Invite:
"""|coro|
Gets an :class:`.Invite` from a discord.gg URL or ID.
@@ -1520,7 +1516,7 @@ class Client:
"""
data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore
- return cls(state=self._connection, data=data) # type: ignore
+ return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro|
diff --git a/discord/http.py b/discord/http.py
index b186782f..5a31928b 100644
--- a/discord/http.py
+++ b/discord/http.py
@@ -167,7 +167,7 @@ class HTTPClient:
loop: Optional[asyncio.AbstractEventLoop] = None,
unsync_clock: bool = True
) -> None:
- self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
+ self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop # filled in static_login
self.connector = connector
self.__session: aiohttp.ClientSession = MISSING # filled in static_login
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
@@ -371,6 +371,7 @@ class HTTPClient:
async def static_login(self, token: str) -> user.User:
# Necessary to get aiohttp to stop complaining about session creation
+ self.loop = asyncio.get_running_loop()
self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse)
old_token = self.token
self.token = token