aboutsummaryrefslogtreecommitdiff
path: root/discord/client.py
diff options
context:
space:
mode:
authorJosh <[email protected]>2021-06-30 10:02:19 +1000
committerGitHub <[email protected]>2021-06-29 20:02:19 -0400
commit7601d6cec3acc4a94727a8a5fd1a5817641e1bae (patch)
tree703001dadb8bbb9bea70616014da96a844d0c263 /discord/client.py
parentAdd examples for how to use views (diff)
downloaddiscord.py-7601d6cec3acc4a94727a8a5fd1a5817641e1bae.tar.xz
discord.py-7601d6cec3acc4a94727a8a5fd1a5817641e1bae.zip
[typing] Type-hint client.py
Diffstat (limited to 'discord/client.py')
-rw-r--r--discord/client.py205
1 files changed, 118 insertions, 87 deletions
diff --git a/discord/client.py b/discord/client.py
index 64c9a678..2b3c3e17 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -29,7 +29,7 @@ import logging
import signal
import sys
import traceback
-from typing import Any, Generator, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
import aiohttp
@@ -38,12 +38,13 @@ from .invite import Invite
from .template import Template
from .widget import Widget
from .guild import Guild
+from .emoji import Emoji
from .channel import _channel_factory
from .enums import ChannelType
from .mentions import AllowedMentions
from .errors import *
from .enums import Status, VoiceRegion
-from .flags import ApplicationFlags
+from .flags import ApplicationFlags, Intents
from .gateway import *
from .activity import BaseActivity, create_activity
from .voice_client import VoiceClient
@@ -58,16 +59,24 @@ from .appinfo import AppInfo
from .ui.view import View
from .stage_instance import StageInstance
+if TYPE_CHECKING:
+ from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
+ from .channel import DMChannel
+ from .user import ClientUser
+ from .message import Message
+ from .member import Member
+ from .voice_client import VoiceProtocol
+
__all__ = (
'Client',
)
-if TYPE_CHECKING:
- from .abc import SnowflakeTime
+Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
-log = logging.getLogger(__name__)
-def _cancel_tasks(loop):
+log: logging.Logger = logging.getLogger(__name__)
+
+def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
@@ -90,7 +99,7 @@ def _cancel_tasks(loop):
'task': task
})
-def _cleanup_loop(loop):
+def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
@@ -116,7 +125,7 @@ class Client:
The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations.
Defaults to ``None``, in which case the default event loop is used via
:func:`asyncio.get_event_loop()`.
- connector: :class:`aiohttp.BaseConnector`
+ connector: Optional[:class:`aiohttp.BaseConnector`]
The connector to use for connection pooling.
proxy: Optional[:class:`str`]
Proxy URL.
@@ -181,31 +190,36 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
- def __init__(self, *, loop=None, **options):
- self.ws = None
- self.loop = asyncio.get_event_loop() if loop is None else loop
- self._listeners = {}
- self.shard_id = options.get('shard_id')
- self.shard_count = options.get('shard_count')
-
- connector = options.pop('connector', None)
- proxy = options.pop('proxy', None)
- proxy_auth = options.pop('proxy_auth', None)
- unsync_clock = options.pop('assume_unsync_clock', True)
- self.http = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
-
- self._handlers = {
+ def __init__(
+ self,
+ *,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ **options: Any,
+ ):
+ self.ws: DiscordWebSocket = None # type: ignore
+ self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
+ self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
+ self.shard_id: Optional[int] = options.get('shard_id')
+ self.shard_count: Optional[int] = options.get('shard_count')
+
+ connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
+ proxy: Optional[str] = options.pop('proxy', None)
+ proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
+ unsync_clock: bool = options.pop('assume_unsync_clock', True)
+ self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
+
+ self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready
}
- self._hooks = {
+ self._hooks: Dict[str, Callable] = {
'before_identify': self._call_before_identify_hook
}
- self._connection = self._get_state(**options)
+ self._connection: ConnectionState = self._get_state(**options)
self._connection.shard_count = self.shard_count
- self._closed = False
- self._ready = asyncio.Event()
+ self._closed: bool = False
+ self._ready: asyncio.Event = asyncio.Event()
self._connection._get_websocket = self._get_websocket
self._connection._get_client = lambda: self
@@ -215,18 +229,18 @@ class Client:
# internals
- def _get_websocket(self, guild_id=None, *, shard_id=None):
+ def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
return self.ws
- def _get_state(self, **options):
+ def _get_state(self, **options: Any) -> ConnectionState:
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options)
- def _handle_ready(self):
+ def _handle_ready(self) -> None:
self._ready.set()
@property
- def latency(self):
+ def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This could be referred to as the Discord WebSocket protocol latency.
@@ -234,7 +248,7 @@ class Client:
ws = self.ws
return float('nan') if not ws else ws.latency
- def is_ws_ratelimited(self):
+ def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members
@@ -247,22 +261,22 @@ class Client:
return False
@property
- def user(self):
+ def user(self) -> Optional[ClientUser]:
"""Optional[:class:`.ClientUser`]: Represents the connected client. ``None`` if not logged in."""
return self._connection.user
@property
- def guilds(self):
+ def guilds(self) -> List[Guild]:
"""List[:class:`.Guild`]: The guilds that the connected client is a member of."""
return self._connection.guilds
@property
- def emojis(self):
+ def emojis(self) -> List[Emoji]:
"""List[:class:`.Emoji`]: The emojis that the connected client has."""
return self._connection.emojis
@property
- def cached_messages(self):
+ def cached_messages(self) -> Sequence[Message]:
"""Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached.
.. versionadded:: 1.1
@@ -270,7 +284,7 @@ class Client:
return utils.SequenceProxy(self._connection._messages or [])
@property
- def private_channels(self):
+ def private_channels(self) -> List[PrivateChannel]:
"""List[:class:`.abc.PrivateChannel`]: The private channels that the connected client is participating on.
.. note::
@@ -281,7 +295,7 @@ class Client:
return self._connection.private_channels
@property
- def voice_clients(self):
+ def voice_clients(self) -> List[VoiceProtocol]:
"""List[:class:`.VoiceProtocol`]: Represents a list of voice connections.
These are usually :class:`.VoiceClient` instances.
@@ -289,7 +303,7 @@ class Client:
return self._connection.voice_clients
@property
- def application_id(self):
+ def application_id(self) -> Optional[int]:
"""Optional[:class:`int`]: The client's application ID.
If this is not passed via ``__init__`` then this is retrieved
@@ -306,11 +320,11 @@ class Client:
"""
return self._connection.application_flags # type: ignore
- def is_ready(self):
+ def is_ready(self) -> bool:
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set()
- async def _run_event(self, coro, event_name, *args, **kwargs):
+ async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
try:
await coro(*args, **kwargs)
except asyncio.CancelledError:
@@ -321,12 +335,12 @@ class Client:
except asyncio.CancelledError:
pass
- def _schedule_event(self, coro, event_name, *args, **kwargs):
+ def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
- def dispatch(self, event, *args, **kwargs):
+ def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
log.debug('Dispatching event %s', event)
method = 'on_' + event
@@ -366,7 +380,7 @@ class Client:
else:
self._schedule_event(coro, method, *args, **kwargs)
- async def on_error(self, event_method, *args, **kwargs):
+ async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None:
"""|coro|
The default error handler provided by the client.
@@ -380,13 +394,13 @@ class Client:
# hooks
- async def _call_before_identify_hook(self, shard_id, *, initial=False):
+ async def _call_before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None:
# This hook is an internal hook that actually calls the public one.
# It allows the library to have its own hook without stepping on the
# toes of those who need to override their own hook.
await self.before_identify_hook(shard_id, initial=initial)
- async def before_identify_hook(self, shard_id, *, initial=False):
+ async def before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None:
"""|coro|
A hook that is called before IDENTIFYing a session. This is useful
@@ -410,7 +424,7 @@ class Client:
# login state management
- async def login(self, token):
+ async def login(self, token: str) -> None:
"""|coro|
Logs in the client with the specified credentials.
@@ -435,7 +449,7 @@ class Client:
log.info('logging in using static token')
await self.http.static_login(token.strip())
- async def connect(self, *, reconnect=True):
+ async def connect(self, *, reconnect: bool = True) -> None:
"""|coro|
Creates a websocket connection and lets the websocket listen
@@ -519,7 +533,7 @@ class Client:
# This is apparently what the official Discord client does.
ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id)
- async def close(self):
+ async def close(self) -> None:
"""|coro|
Closes the connection to Discord.
@@ -531,7 +545,7 @@ class Client:
for voice in self.voice_clients:
try:
- await voice.disconnect()
+ await voice.disconnect(force=True)
except Exception:
# if an error happens during disconnects, disregard it.
pass
@@ -542,7 +556,7 @@ class Client:
await self.http.close()
self._ready.clear()
- def clear(self):
+ def clear(self) -> None:
"""Clears the internal state of the bot.
After this, the bot can be considered "re-opened", i.e. :meth:`is_closed`
@@ -554,7 +568,7 @@ class Client:
self._connection.clear()
self.http.recreate()
- async def start(self, token, *, reconnect=True):
+ async def start(self, token: str, *, reconnect: bool = True) -> None:
"""|coro|
A shorthand coroutine for :meth:`login` + :meth:`connect`.
@@ -567,7 +581,7 @@ class Client:
await self.login(token)
await self.connect(reconnect=reconnect)
- def run(self, *args, **kwargs):
+ def run(self, *args: Any, **kwargs: Any) -> None:
"""A blocking call that abstracts away the event loop
initialisation from you.
@@ -629,19 +643,19 @@ class Client:
# properties
- def is_closed(self):
+ def is_closed(self) -> bool:
""":class:`bool`: Indicates if the websocket connection is closed."""
return self._closed
@property
- def activity(self):
+ def activity(self) -> Optional[BaseActivity]:
"""Optional[:class:`.BaseActivity`]: The activity being used upon
logging in.
"""
return create_activity(self._connection._activity)
@activity.setter
- def activity(self, value):
+ def activity(self, value: Optional[BaseActivity]) -> None:
if value is None:
self._connection._activity = None
elif isinstance(value, BaseActivity):
@@ -650,7 +664,7 @@ class Client:
raise TypeError('activity must derive from BaseActivity.')
@property
- def allowed_mentions(self):
+ def allowed_mentions(self) -> Optional[AllowedMentions]:
"""Optional[:class:`~discord.AllowedMentions`]: The allowed mention configuration.
.. versionadded:: 1.4
@@ -658,14 +672,14 @@ class Client:
return self._connection.allowed_mentions
@allowed_mentions.setter
- def allowed_mentions(self, value):
+ def allowed_mentions(self, value: Optional[AllowedMentions]) -> None:
if value is None or isinstance(value, AllowedMentions):
self._connection.allowed_mentions = value
else:
raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}')
@property
- def intents(self):
+ def intents(self) -> Intents:
""":class:`~discord.Intents`: The intents configured for this connection.
.. versionadded:: 1.5
@@ -675,11 +689,11 @@ class Client:
# helpers/getters
@property
- def users(self):
+ def users(self) -> List[User]:
"""List[:class:`~discord.User`]: Returns a list of all the users the bot can see."""
return list(self._connection._users.values())
- def get_channel(self, id):
+ def get_channel(self, id: int) -> Optional[Union[GuildChannel, PrivateChannel]]:
"""Returns a channel with the given ID.
Parameters
@@ -716,7 +730,7 @@ class Client:
if isinstance(channel, StageChannel):
return channel.instance
- def get_guild(self, id):
+ def get_guild(self, id) -> Optional[Guild]:
"""Returns a guild with the given ID.
Parameters
@@ -731,7 +745,7 @@ class Client:
"""
return self._connection._get_guild(id)
- def get_user(self, id):
+ def get_user(self, id) -> Optional[User]:
"""Returns a user with the given ID.
Parameters
@@ -746,7 +760,7 @@ class Client:
"""
return self._connection.get_user(id)
- def get_emoji(self, id):
+ def get_emoji(self, id) -> Optional[Emoji]:
"""Returns an emoji with the given ID.
Parameters
@@ -761,7 +775,7 @@ class Client:
"""
return self._connection.get_emoji(id)
- def get_all_channels(self):
+ def get_all_channels(self) -> Generator[GuildChannel, None, None]:
"""A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'.
This is equivalent to: ::
@@ -785,7 +799,7 @@ class Client:
for guild in self.guilds:
yield from guild.channels
- def get_all_members(self):
+ def get_all_members(self) -> Generator[Member, None, None]:
"""Returns a generator with every :class:`.Member` the client can see.
This is equivalent to: ::
@@ -804,14 +818,20 @@ class Client:
# listeners/waiters
- async def wait_until_ready(self):
+ async def wait_until_ready(self) -> None:
"""|coro|
Waits until the client's internal cache is all ready.
"""
await self._ready.wait()
- def wait_for(self, event, *, check=None, timeout=None):
+ def wait_for(
+ self,
+ event: str,
+ *,
+ check: Optional[Callable[..., bool]] = None,
+ timeout: Optional[float] = None,
+ ) -> Any:
"""|coro|
Waits for a WebSocket event to be dispatched.
@@ -911,7 +931,7 @@ class Client:
# event registration
- def event(self, coro):
+ def event(self, coro: Coro) -> Coro:
"""A decorator that registers an event to listen to.
You can find more info about the events on the :ref:`documentation below <discord-api-events>`.
@@ -940,7 +960,13 @@ class Client:
log.debug('%s has successfully been registered as an event', coro.__name__)
return coro
- async def change_presence(self, *, activity=None, status=None, afk=False):
+ async def change_presence(
+ self,
+ *,
+ activity: Optional[BaseActivity] = None,
+ status: Optional[Status] = None,
+ afk: bool = False,
+ ):
"""|coro|
Changes the client's presence.
@@ -972,16 +998,15 @@ class Client:
"""
if status is None:
- status = 'online'
- status_enum = Status.online
+ status_str = 'online'
+ status = Status.online
elif status is Status.offline:
- status = 'invisible'
- status_enum = Status.offline
+ status_str = 'invisible'
+ status = Status.offline
else:
- status_enum = status
- status = str(status)
+ status_str = str(status)
- await self.ws.change_presence(activity=activity, status=status, afk=afk)
+ await self.ws.change_presence(activity=activity, status=status_str, afk=afk)
for guild in self._connection.guilds:
me = guild.me
@@ -993,11 +1018,17 @@ class Client:
else:
me.activities = ()
- me.status = status_enum
+ me.status = status
# Guild stuff
- def fetch_guilds(self, *, limit: int = 100, before: SnowflakeTime = None, after: SnowflakeTime = None) -> List[Guild]:
+ def fetch_guilds(
+ self,
+ *,
+ limit: Optional[int] = 100,
+ before: SnowflakeTime = None,
+ after: SnowflakeTime = None
+ ) -> List[Guild]:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
.. note::
@@ -1052,7 +1083,7 @@ class Client:
"""
return GuildIterator(self, limit=limit, before=before, after=after)
- async def fetch_template(self, code):
+ async def fetch_template(self, code: Union[Template, str]) -> Template:
"""|coro|
Gets a :class:`.Template` from a discord.new URL or code.
@@ -1078,7 +1109,7 @@ class Client:
data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore
- async def fetch_guild(self, guild_id):
+ async def fetch_guild(self, guild_id: int) -> Guild:
"""|coro|
Retrieves a :class:`.Guild` from an ID.
@@ -1112,7 +1143,7 @@ class Client:
data = await self.http.get_guild(guild_id)
return Guild(data=data, state=self._connection)
- async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None, *, code: str = None):
+ async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None, *, code: str = None) -> Guild:
"""|coro|
Creates a :class:`.Guild`.
@@ -1259,7 +1290,7 @@ class Client:
# Miscellaneous stuff
- async def fetch_widget(self, guild_id):
+ async def fetch_widget(self, guild_id: int) -> Widget:
"""|coro|
Gets a :class:`.Widget` from a guild ID.
@@ -1289,7 +1320,7 @@ class Client:
return Widget(state=self._connection, data=data)
- async def application_info(self):
+ async def application_info(self) -> AppInfo:
"""|coro|
Retrieves the bot's application information.
@@ -1309,7 +1340,7 @@ class Client:
data['rpc_origins'] = None
return AppInfo(self._connection, data)
- async def fetch_user(self, user_id):
+ async def fetch_user(self, user_id: int) -> User:
"""|coro|
Retrieves a :class:`~discord.User` based on their ID.
@@ -1340,7 +1371,7 @@ class Client:
data = await self.http.get_user(user_id)
return User(state=self._connection, data=data)
- async def fetch_channel(self, channel_id):
+ async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel]:
"""|coro|
Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID.
@@ -1382,7 +1413,7 @@ class Client:
return channel
- async def fetch_webhook(self, webhook_id):
+ async def fetch_webhook(self, webhook_id: int) -> Webhook:
"""|coro|
Retrieves a :class:`.Webhook` with the specified ID.
@@ -1404,7 +1435,7 @@ class Client:
data = await self.http.get_webhook(webhook_id)
return Webhook.from_state(data, state=self._connection)
- async def create_dm(self, user):
+ async def create_dm(self, user: Snowflake) -> DMChannel:
"""|coro|
Creates a :class:`.DMChannel` with this user.