diff options
Diffstat (limited to 'discord')
| -rw-r--r-- | discord/client.py | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/discord/client.py b/discord/client.py index 569d17c0..2d029371 100644 --- a/discord/client.py +++ b/discord/client.py @@ -29,7 +29,7 @@ import logging import signal import sys import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union import aiohttp @@ -206,6 +206,7 @@ class Client: loop: Optional[asyncio.AbstractEventLoop] = None, **options: Any, ): + # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} @@ -682,7 +683,8 @@ class Client: if value is None: self._connection._activity = None elif isinstance(value, BaseActivity): - self._connection._activity = value.to_dict() + # ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any] + self._connection._activity = value.to_dict() # type: ignore else: raise TypeError('activity must derive from BaseActivity.') @@ -716,8 +718,8 @@ class Client: """List[:class:`~discord.User`]: Returns a list of all the users the bot can see.""" return list(self._connection._users.values()) - def get_channel(self, id: int) -> Optional[Union[GuildChannel, PrivateChannel]]: - """Returns a channel with the given ID. + def get_channel(self, id: int) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]: + """Returns a channel or thread with the given ID. Parameters ----------- @@ -726,7 +728,7 @@ class Client: Returns -------- - Optional[Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`]] + Optional[Union[:class:`.abc.GuildChannel`, :class:`.Thread`, :class:`.abc.PrivateChannel`]] The returned channel or ``None`` if not found. """ return self._connection.get_channel(id) @@ -1473,11 +1475,14 @@ class Client: raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) if ch_type in (ChannelType.group, ChannelType.private): - channel = factory(me=self.user, data=data, state=self._connection) + # the factory will be a DMChannel or GroupChannel here + channel = factory(me=self.user, data=data, state=self._connection) # type: ignore else: - guild_id = int(data['guild_id']) + # the factory can't be a DMChannel or GroupChannel here + guild_id = int(data['guild_id']) # type: ignore guild = self.get_guild(guild_id) or Object(id=guild_id) - channel = factory(guild=guild, state=self._connection, data=data) + # GuildChannels expect a Guild, we may be passing an Object + channel = factory(guild=guild, state=self._connection, data=data) # type: ignore return channel |