aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
Diffstat (limited to 'discord')
-rw-r--r--discord/state.py411
-rw-r--r--discord/types/appinfo.py1
-rw-r--r--discord/types/voice.py4
3 files changed, 237 insertions, 179 deletions
diff --git a/discord/state.py b/discord/state.py
index b6f407d8..680d112f 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -30,9 +30,7 @@ import copy
import datetime
import itertools
import logging
-from typing import Dict, Optional, TYPE_CHECKING, Union
-import weakref
-import warnings
+from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque
import inspect
import os
@@ -56,26 +54,43 @@ from .object import Object
from .invite import Invite
from .integrations import _integration_factory
from .interactions import Interaction
-from .ui.view import ViewStore
+from .ui.view import ViewStore, View
from .stage_instance import StageInstance
from .threads import Thread, ThreadMember
from .sticker import GuildSticker
if TYPE_CHECKING:
+ from .abc import PrivateChannel
+ from .message import MessageableChannel
+ from .guild import GuildChannel, VocalGuildChannel
from .http import HTTPClient
+ from .voice_client import VoiceProtocol
+ from .client import Client
+ from .gateway import DiscordWebSocket
+
from .types.activity import Activity as ActivityPayload
+ from .types.channel import DMChannel as DMChannelPayload
+ from .types.user import User as UserPayload
+ from .types.emoji import Emoji as EmojiPayload
+ from .types.sticker import GuildSticker as GuildStickerPayload
+ from .types.guild import Guild as GuildPayload
+ from .types.message import Message as MessagePayload
+
+ T = TypeVar('T')
+ CS = TypeVar('CS', bound='ConnectionState')
+ Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable]
class ChunkRequest:
- def __init__(self, guild_id, loop, resolver, *, cache=True):
- self.guild_id = guild_id
- self.resolver = resolver
- self.loop = loop
- self.cache = cache
- self.nonce = os.urandom(16).hex()
- self.buffer = [] # List[Member]
- self.waiters = []
-
- def add_members(self, members):
+ def __init__(self, guild_id: int, loop: asyncio.AbstractEventLoop, resolver: Callable[[int], Any], *, cache: bool = True) -> None:
+ self.guild_id: int = guild_id
+ self.resolver: Callable[[int], Any] = resolver
+ self.loop: asyncio.AbstractEventLoop = loop
+ self.cache: bool = cache
+ self.nonce: str = os.urandom(16).hex()
+ self.buffer: List[Member] = []
+ self.waiters: List[asyncio.Future[List[Member]]] = []
+
+ def add_members(self, members: List[Member]) -> None:
self.buffer.extend(members)
if self.cache:
guild = self.resolver(self.guild_id)
@@ -87,7 +102,7 @@ class ChunkRequest:
if existing is None or existing.joined_at is None:
guild._add_member(member)
- async def wait(self):
+ async def wait(self) -> List[Member]:
future = self.loop.create_future()
self.waiters.append(future)
try:
@@ -95,35 +110,40 @@ class ChunkRequest:
finally:
self.waiters.remove(future)
- def get_future(self):
+ def get_future(self) -> asyncio.Future[List[Member]]:
future = self.loop.create_future()
self.waiters.append(future)
return future
- def done(self):
+ def done(self) -> None:
for future in self.waiters:
if not future.done():
future.set_result(self.buffer)
-log = logging.getLogger(__name__)
+log: logging.Logger = logging.getLogger(__name__)
-async def logging_coroutine(coroutine, *, info):
+async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> Optional[T]:
try:
await coroutine
except Exception:
log.exception('Exception occurred during %s', info)
class ConnectionState:
- def __init__(self, *, dispatch, handlers, hooks, http: HTTPClient, loop: asyncio.AbstractEventLoop, **options):
+ if TYPE_CHECKING:
+ _get_websocket: Callable[..., DiscordWebSocket]
+ _get_client: Callable[..., Client]
+ _parsers: Dict[str, Callable[[Dict[str, Any]], None]]
+
+ def __init__(self, *, dispatch: Callable, handlers: Dict[str, Callable], hooks: Dict[str, Callable], http: HTTPClient, loop: asyncio.AbstractEventLoop, **options: Any) -> None:
self.loop: asyncio.AbstractEventLoop = loop
self.http: HTTPClient = http
self.max_messages: Optional[int] = options.get('max_messages', 1000)
if self.max_messages is not None and self.max_messages <= 0:
self.max_messages = 1000
- self.dispatch = dispatch
- self.handlers = handlers
- self.hooks = hooks
+ self.dispatch: Callable = dispatch
+ self.handlers: Dict[str, Callable] = handlers
+ self.hooks: Dict[str, Callable] = hooks
self.shard_count: Optional[int] = None
self._ready_task: Optional[asyncio.Task] = None
self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id')
@@ -195,8 +215,8 @@ class ConnectionState:
self.clear()
- def clear(self):
- self.user = None
+ def clear(self) -> None:
+ self.user: Optional[ClientUser] = None
# Originally, this code used WeakValueDictionary to maintain references to the
# global user mapping.
@@ -210,19 +230,22 @@ class ConnectionState:
# using __del__. Testing this for memory leaks led to no discernable leaks,
# though more testing will have to be done.
self._users: Dict[int, User] = {}
- self._emojis = {}
- self._stickers = {}
- self._guilds = {}
- self._view_store = ViewStore(self)
- self._voice_clients = {}
+ self._emojis: Dict[int, Emoji] = {}
+ self._stickers: Dict[int, GuildSticker] = {}
+ self._guilds: Dict[int, Guild] = {}
+ self._view_store: ViewStore = ViewStore(self)
+ self._voice_clients: Dict[int, VoiceProtocol] = {}
# LRU of max size 128
- self._private_channels = OrderedDict()
+ self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict()
# extra dict to look up private channels by user id
- self._private_channels_by_user = {}
- self._messages = self.max_messages and deque(maxlen=self.max_messages)
+ self._private_channels_by_user: Dict[int, PrivateChannel] = {}
+ if self.max_messages is not None:
+ self._messages: Optional[Deque[Message]] = deque(maxlen=self.max_messages)
+ else:
+ self._messages: Optional[Deque[Message]] = None
- def process_chunk_requests(self, guild_id, nonce, members, complete):
+ def process_chunk_requests(self, guild_id: int, nonce: Optional[str], members: List[Member], complete: bool) -> None:
removed = []
for key, request in self._chunk_requests.items():
if request.guild_id == guild_id and request.nonce == nonce:
@@ -234,7 +257,7 @@ class ConnectionState:
for key in removed:
del self._chunk_requests[key]
- def call_handlers(self, key, *args, **kwargs):
+ def call_handlers(self, key: str, *args: Any, **kwargs: Any) -> None:
try:
func = self.handlers[key]
except KeyError:
@@ -242,7 +265,7 @@ class ConnectionState:
else:
func(*args, **kwargs)
- async def call_hooks(self, key, *args, **kwargs):
+ async def call_hooks(self, key: str, *args: Any, **kwargs: Any) -> None:
try:
coro = self.hooks[key]
except KeyError:
@@ -251,34 +274,35 @@ class ConnectionState:
await coro(*args, **kwargs)
@property
- def self_id(self):
+ def self_id(self) -> Optional[int]:
u = self.user
return u.id if u else None
@property
- def intents(self):
+ def intents(self) -> Intents:
ret = Intents.none()
ret.value = self._intents.value
return ret
@property
- def voice_clients(self):
+ def voice_clients(self) -> List[VoiceProtocol]:
return list(self._voice_clients.values())
- def _get_voice_client(self, guild_id):
- return self._voice_clients.get(guild_id)
+ def _get_voice_client(self, guild_id: Optional[int]) -> Optional[VoiceProtocol]:
+ # the keys of self._voice_clients are ints
+ return self._voice_clients.get(guild_id) # type: ignore
- def _add_voice_client(self, guild_id, voice):
+ def _add_voice_client(self, guild_id: int, voice: VoiceProtocol) -> None:
self._voice_clients[guild_id] = voice
- def _remove_voice_client(self, guild_id):
+ def _remove_voice_client(self, guild_id: int) -> None:
self._voice_clients.pop(guild_id, None)
- def _update_references(self, ws):
+ def _update_references(self, ws: DiscordWebSocket) -> None:
for vc in self.voice_clients:
- vc.main_ws = ws
+ vc.main_ws = ws # type: ignore
- def store_user(self, data):
+ def store_user(self, data: UserPayload) -> User:
user_id = int(data['id'])
try:
return self._users[user_id]
@@ -289,49 +313,52 @@ class ConnectionState:
user._stored = True
return user
- def deref_user(self, user_id):
+ def deref_user(self, user_id: int) -> None:
self._users.pop(user_id, None)
- def create_user(self, data):
+ def create_user(self, data: UserPayload) -> User:
return User(state=self, data=data)
- def deref_user_no_intents(self, user_id):
+ def deref_user_no_intents(self, user_id: int) -> None:
return
- def get_user(self, id):
- return self._users.get(id)
+ def get_user(self, id: Optional[int]) -> Optional[User]:
+ # the keys of self._users are ints
+ return self._users.get(id) # type: ignore
- def store_emoji(self, guild, data):
- emoji_id = int(data['id'])
+ def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji:
+ # the id will be present here
+ emoji_id = int(data['id']) # type: ignore
self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data)
return emoji
- def store_sticker(self, guild, data):
+ def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker:
sticker_id = int(data['id'])
self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data)
return sticker
- def store_view(self, view, message_id=None):
+ def store_view(self, view: View, message_id: Optional[int] = None) -> None:
self._view_store.add_view(view, message_id)
- def prevent_view_updates_for(self, message_id):
+ def prevent_view_updates_for(self, message_id: int) -> Optional[View]:
return self._view_store.remove_message_tracking(message_id)
@property
- def persistent_views(self):
+ def persistent_views(self) -> Sequence[View]:
return self._view_store.persistent_views
@property
- def guilds(self):
+ def guilds(self) -> List[Guild]:
return list(self._guilds.values())
- def _get_guild(self, guild_id):
- return self._guilds.get(guild_id)
+ def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]:
+ # the keys of self._guilds are ints
+ return self._guilds.get(guild_id) # type: ignore
- def _add_guild(self, guild):
+ def _add_guild(self, guild: Guild) -> None:
self._guilds[guild.id] = guild
- def _remove_guild(self, guild):
+ def _remove_guild(self, guild: Guild) -> None:
self._guilds.pop(guild.id, None)
for emoji in guild.emojis:
@@ -343,36 +370,40 @@ class ConnectionState:
del guild
@property
- def emojis(self):
+ def emojis(self) -> List[Emoji]:
return list(self._emojis.values())
@property
- def stickers(self):
+ def stickers(self) -> List[GuildSticker]:
return list(self._stickers.values())
- def get_emoji(self, emoji_id):
- return self._emojis.get(emoji_id)
+ def get_emoji(self, emoji_id: Optional[int]) -> Optional[Emoji]:
+ # the keys of self._emojis are ints
+ return self._emojis.get(emoji_id) # type: ignore
- def get_sticker(self, sticker_id):
- return self._stickers.get(sticker_id)
+ def get_sticker(self, sticker_id: Optional[int]) -> Optional[GuildSticker]:
+ # the keys of self._stickers are ints
+ return self._stickers.get(sticker_id) # type: ignore
@property
- def private_channels(self):
+ def private_channels(self) -> List[PrivateChannel]:
return list(self._private_channels.values())
- def _get_private_channel(self, channel_id):
+ def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]:
try:
- value = self._private_channels[channel_id]
+ # the keys of self._private_channels are ints
+ value = self._private_channels[channel_id] # type: ignore
except KeyError:
return None
else:
- self._private_channels.move_to_end(channel_id)
+ self._private_channels.move_to_end(channel_id) # type: ignore
return value
- def _get_private_channel_by_user(self, user_id):
- return self._private_channels_by_user.get(user_id)
+ def _get_private_channel_by_user(self, user_id: Optional[int]) -> Optional[PrivateChannel]:
+ # the keys of self._private_channels are ints
+ return self._private_channels_by_user.get(user_id) # type: ignore
- def _add_private_channel(self, channel):
+ def _add_private_channel(self, channel: PrivateChannel) -> None:
channel_id = channel.id
self._private_channels[channel_id] = channel
@@ -384,29 +415,32 @@ class ConnectionState:
if isinstance(channel, DMChannel) and channel.recipient:
self._private_channels_by_user[channel.recipient.id] = channel
- def add_dm_channel(self, data):
- channel = DMChannel(me=self.user, state=self, data=data)
+ def add_dm_channel(self, data: DMChannelPayload) -> DMChannel:
+ # self.user is *always* cached when this is called
+ channel = DMChannel(me=self.user, state=self, data=data) # type: ignore
self._add_private_channel(channel)
return channel
- def _remove_private_channel(self, channel):
+ def _remove_private_channel(self, channel: PrivateChannel) -> None:
self._private_channels.pop(channel.id, None)
if isinstance(channel, DMChannel):
- self._private_channels_by_user.pop(channel.recipient.id, None)
+ recipient = channel.recipient
+ if recipient is not None:
+ self._private_channels_by_user.pop(recipient.id, None)
- def _get_message(self, msg_id):
+ def _get_message(self, msg_id: Optional[int]) -> Optional[Message]:
return utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages else None
- def _add_guild_from_data(self, guild):
- guild = Guild(data=guild, state=self)
+ def _add_guild_from_data(self, data: GuildPayload) -> Guild:
+ guild = Guild(data=data, state=self)
self._add_guild(guild)
return guild
- def _guild_needs_chunking(self, guild):
+ def _guild_needs_chunking(self, guild: Guild) -> bool:
# If presences are enabled then we get back the old guild.large behaviour
return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large)
- def _get_guild_channel(self, data):
+ def _get_guild_channel(self, data: MessagePayload) -> Tuple[Union[Channel, Thread], Optional[Guild]]:
channel_id = int(data['channel_id'])
try:
guild = self._get_guild(int(data['guild_id']))
@@ -418,11 +452,11 @@ class ConnectionState:
return channel or PartialMessageable(state=self, id=channel_id), guild
- async def chunker(self, guild_id, query='', limit=0, presences=False, *, nonce=None):
+ async def chunker(self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None) -> None:
ws = self._get_websocket(guild_id) # This is ignored upstream
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
- async def query_members(self, guild, query, limit, user_ids, cache, presences):
+ async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool):
guild_id = guild.id
ws = self._get_websocket(guild_id)
if ws is None:
@@ -439,7 +473,7 @@ class ConnectionState:
log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id)
raise
- async def _delay_ready(self):
+ async def _delay_ready(self) -> None:
try:
states = []
while True:
@@ -485,13 +519,13 @@ class ConnectionState:
finally:
self._ready_task = None
- def parse_ready(self, data):
+ def parse_ready(self, data) -> None:
if self._ready_task is not None:
self._ready_task.cancel()
self._ready_state = asyncio.Queue()
self.clear()
- self.user = user = ClientUser(state=self, data=data['user'])
+ self.user = ClientUser(state=self, data=data['user'])
self.store_user(data['user'])
if self.application_id is None:
@@ -501,7 +535,8 @@ class ConnectionState:
pass
else:
self.application_id = utils._get_as_snowflake(application, 'id')
- self.application_flags = ApplicationFlags._from_value(application['flags'])
+ # flags will always be present here
+ self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore
for guild_data in data['guilds']:
self._add_guild_from_data(guild_data)
@@ -509,19 +544,21 @@ class ConnectionState:
self.dispatch('connect')
self._ready_task = asyncio.create_task(self._delay_ready())
- def parse_resumed(self, data):
+ def parse_resumed(self, data) -> None:
self.dispatch('resumed')
- def parse_message_create(self, data):
+ def parse_message_create(self, data) -> None:
channel, _ = self._get_guild_channel(data)
- message = Message(channel=channel, data=data, state=self)
+ # channel would be the correct type here
+ message = Message(channel=channel, data=data, state=self) # type: ignore
self.dispatch('message', message)
if self._messages is not None:
self._messages.append(message)
+ # we ensure that the channel is either a TextChannel or Thread
if channel and channel.__class__ in (TextChannel, Thread):
- channel.last_message_id = message.id
+ channel.last_message_id = message.id # type: ignore
- def parse_message_delete(self, data):
+ def parse_message_delete(self, data) -> None:
raw = RawMessageDeleteEvent(data)
found = self._get_message(raw.message_id)
raw.cached_message = found
@@ -530,7 +567,7 @@ class ConnectionState:
self.dispatch('message_delete', found)
self._messages.remove(found)
- def parse_message_delete_bulk(self, data):
+ def parse_message_delete_bulk(self, data) -> None:
raw = RawBulkMessageDeleteEvent(data)
if self._messages:
found_messages = [message for message in self._messages if message.id in raw.message_ids]
@@ -541,9 +578,10 @@ class ConnectionState:
if found_messages:
self.dispatch('bulk_message_delete', found_messages)
for msg in found_messages:
- self._messages.remove(msg)
+ # self._messages won't be None here
+ self._messages.remove(msg) # type: ignore
- def parse_message_update(self, data):
+ def parse_message_update(self, data) -> None:
raw = RawMessageUpdateEvent(data)
message = self._get_message(raw.message_id)
if message is not None:
@@ -561,7 +599,7 @@ class ConnectionState:
if 'components' in data and self._view_store.is_message_tracked(raw.message_id):
self._view_store.update_from_message(raw.message_id, data['components'])
- def parse_message_reaction_add(self, data):
+ def parse_message_reaction_add(self, data) -> None:
emoji = data['emoji']
emoji_id = utils._get_as_snowflake(emoji, 'id')
emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get('animated', False), name=emoji['name'])
@@ -570,7 +608,10 @@ class ConnectionState:
member_data = data.get('member')
if member_data:
guild = self._get_guild(raw.guild_id)
- raw.member = Member(data=member_data, guild=guild, state=self)
+ if guild is not None:
+ raw.member = Member(data=member_data, guild=guild, state=self)
+ else:
+ raw.member = None
else:
raw.member = None
self.dispatch('raw_reaction_add', raw)
@@ -585,7 +626,7 @@ class ConnectionState:
if user:
self.dispatch('reaction_add', reaction, user)
- def parse_message_reaction_remove_all(self, data):
+ def parse_message_reaction_remove_all(self, data) -> None:
raw = RawReactionClearEvent(data)
self.dispatch('raw_reaction_clear', raw)
@@ -595,7 +636,7 @@ class ConnectionState:
message.reactions.clear()
self.dispatch('reaction_clear', message, old_reactions)
- def parse_message_reaction_remove(self, data):
+ def parse_message_reaction_remove(self, data) -> None:
emoji = data['emoji']
emoji_id = utils._get_as_snowflake(emoji, 'id')
emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name'])
@@ -614,7 +655,7 @@ class ConnectionState:
if user:
self.dispatch('reaction_remove', reaction, user)
- def parse_message_reaction_remove_emoji(self, data):
+ def parse_message_reaction_remove_emoji(self, data) -> None:
emoji = data['emoji']
emoji_id = utils._get_as_snowflake(emoji, 'id')
emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name'])
@@ -631,7 +672,7 @@ class ConnectionState:
if reaction:
self.dispatch('reaction_clear_emoji', reaction)
- def parse_interaction_create(self, data):
+ def parse_interaction_create(self, data) -> None:
interaction = Interaction(data=data, state=self)
if data['type'] == 3: # interaction component
custom_id = interaction.data['custom_id'] # type: ignore
@@ -640,8 +681,9 @@ class ConnectionState:
self.dispatch('interaction', interaction)
- def parse_presence_update(self, data):
+ def parse_presence_update(self, data) -> None:
guild_id = utils._get_as_snowflake(data, 'guild_id')
+ # guild_id won't be None here
guild = self._get_guild(guild_id)
if guild is None:
log.debug('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id)
@@ -661,21 +703,23 @@ class ConnectionState:
self.dispatch('presence_update', old_member, member)
- def parse_user_update(self, data):
- self.user._update(data)
- ref = self._users.get(self.user.id)
+ def parse_user_update(self, data) -> None:
+ # self.user is *always* cached when this is called
+ user: ClientUser = self.user # type: ignore
+ user._update(data)
+ ref = self._users.get(user.id)
if ref:
ref._update(data)
- def parse_invite_create(self, data):
+ def parse_invite_create(self, data) -> None:
invite = Invite.from_gateway(state=self, data=data)
self.dispatch('invite_create', invite)
- def parse_invite_delete(self, data):
+ def parse_invite_delete(self, data) -> None:
invite = Invite.from_gateway(state=self, data=data)
self.dispatch('invite_delete', invite)
- def parse_channel_delete(self, data):
+ def parse_channel_delete(self, data) -> None:
guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id'))
channel_id = int(data['id'])
if guild is not None:
@@ -684,13 +728,14 @@ class ConnectionState:
guild._remove_channel(channel)
self.dispatch('guild_channel_delete', channel)
- def parse_channel_update(self, data):
+ def parse_channel_update(self, data) -> None:
channel_type = try_enum(ChannelType, data.get('type'))
channel_id = int(data['id'])
if channel_type is ChannelType.group:
channel = self._get_private_channel(channel_id)
old_channel = copy.copy(channel)
- channel._update_group(data)
+ # the channel is a GroupChannel
+ channel._update_group(data) # type: ignore
self.dispatch('private_channel_update', old_channel, channel)
return
@@ -707,7 +752,7 @@ class ConnectionState:
else:
log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id)
- def parse_channel_create(self, data):
+ def parse_channel_create(self, data) -> None:
factory, ch_type = _channel_factory(data['type'])
if factory is None:
log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type'])
@@ -716,14 +761,15 @@ class ConnectionState:
guild_id = utils._get_as_snowflake(data, 'guild_id')
guild = self._get_guild(guild_id)
if guild is not None:
- channel = factory(guild=guild, state=self, data=data)
- guild._add_channel(channel)
+ # the factory can't be a DMChannel or GroupChannel here
+ channel = factory(guild=guild, state=self, data=data) # type: ignore
+ guild._add_channel(channel) # type: ignore
self.dispatch('guild_channel_create', channel)
else:
log.debug('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id)
return
- def parse_channel_pins_update(self, data):
+ def parse_channel_pins_update(self, data) -> None:
channel_id = int(data['channel_id'])
try:
guild = self._get_guild(int(data['guild_id']))
@@ -744,7 +790,7 @@ class ConnectionState:
else:
self.dispatch('guild_channel_pins_update', channel, last_pin)
- def parse_thread_create(self, data):
+ def parse_thread_create(self, data) -> None:
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
if guild is None:
@@ -757,7 +803,7 @@ class ConnectionState:
if not has_thread:
self.dispatch('thread_join', thread)
- def parse_thread_update(self, data):
+ def parse_thread_update(self, data) -> None:
guild_id = int(data['guild_id'])
guild = self._get_guild(guild_id)
if guild is None:
@@ -775,7 +821,7 @@ class ConnectionState:
guild._add_thread(thread)
self.dispatch('thread_join', thread)
- def parse_thread_delete(self, data):
+ def parse_thread_delete(self, data) -> None:
guild_id = int(data['guild_id'])
guild = self._get_guild(guild_id)
if guild is None:
@@ -785,10 +831,10 @@ class ConnectionState:
thread_id = int(data['id'])
thread = guild.get_thread(thread_id)
if thread is not None:
- guild._remove_thread(thread)
+ guild._remove_thread(thread) # type: ignore
self.dispatch('thread_delete', thread)
- def parse_thread_list_sync(self, data):
+ def parse_thread_list_sync(self, data) -> None:
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
if guild is None:
@@ -827,7 +873,7 @@ class ConnectionState:
for thread in previous_threads.values():
self.dispatch('thread_remove', thread)
- def parse_thread_member_update(self, data):
+ def parse_thread_member_update(self, data) -> None:
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
if guild is None:
@@ -843,7 +889,7 @@ class ConnectionState:
member = ThreadMember(thread, data)
thread.me = member
- def parse_thread_members_update(self, data):
+ def parse_thread_members_update(self, data) -> None:
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
if guild is None:
@@ -875,7 +921,7 @@ class ConnectionState:
else:
self.dispatch('thread_remove', thread)
- def parse_guild_member_add(self, data):
+ def parse_guild_member_add(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is None:
log.debug('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
@@ -892,7 +938,7 @@ class ConnectionState:
self.dispatch('member_join', member)
- def parse_guild_member_remove(self, data):
+ def parse_guild_member_remove(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
try:
@@ -903,12 +949,12 @@ class ConnectionState:
user_id = int(data['user']['id'])
member = guild.get_member(user_id)
if member is not None:
- guild._remove_member(member)
+ guild._remove_member(member) # type: ignore
self.dispatch('member_remove', member)
else:
log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
- def parse_guild_member_update(self, data):
+ def parse_guild_member_update(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
user = data['user']
user_id = int(user['id'])
@@ -937,7 +983,7 @@ class ConnectionState:
guild._add_member(member)
log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id)
- def parse_guild_emojis_update(self, data):
+ def parse_guild_emojis_update(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is None:
log.debug('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
@@ -946,10 +992,11 @@ class ConnectionState:
before_emojis = guild.emojis
for emoji in before_emojis:
self._emojis.pop(emoji.id, None)
- guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis']))
+ # guild won't be None here
+ guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) #type: ignore
self.dispatch('guild_emojis_update', guild, before_emojis, guild.emojis)
- def parse_guild_stickers_update(self, data):
+ def parse_guild_stickers_update(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is None:
log.debug('GUILD_STICKERS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
@@ -958,7 +1005,8 @@ class ConnectionState:
before_stickers = guild.stickers
for emoji in before_stickers:
self._stickers.pop(emoji.id, None)
- guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers']))
+ # guild won't be None here
+ guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) # type: ignore
self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers)
def _get_create_guild(self, data):
@@ -999,7 +1047,7 @@ class ConnectionState:
else:
self.dispatch('guild_join', guild)
- def parse_guild_create(self, data):
+ def parse_guild_create(self, data) -> None:
unavailable = data.get('unavailable')
if unavailable is True:
# joined a guild with unavailable == True so..
@@ -1027,7 +1075,7 @@ class ConnectionState:
else:
self.dispatch('guild_join', guild)
- def parse_guild_update(self, data):
+ def parse_guild_update(self, data) -> None:
guild = self._get_guild(int(data['id']))
if guild is not None:
old_guild = copy.copy(guild)
@@ -1036,7 +1084,7 @@ class ConnectionState:
else:
log.debug('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id'])
- def parse_guild_delete(self, data):
+ def parse_guild_delete(self, data) -> None:
guild = self._get_guild(int(data['id']))
if guild is None:
log.debug('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id'])
@@ -1051,12 +1099,12 @@ class ConnectionState:
# do a cleanup of the messages cache
if self._messages is not None:
- self._messages = deque((msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages)
+ self._messages: Optional[Deque[Message]] = deque((msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages)
self._remove_guild(guild)
self.dispatch('guild_remove', guild)
- def parse_guild_ban_add(self, data):
+ def parse_guild_ban_add(self, data) -> None:
# we make the assumption that GUILD_BAN_ADD is done
# before GUILD_MEMBER_REMOVE is called
# hence we don't remove it from cache or do anything
@@ -1072,13 +1120,13 @@ class ConnectionState:
member = guild.get_member(user.id) or user
self.dispatch('member_ban', guild, member)
- def parse_guild_ban_remove(self, data):
+ def parse_guild_ban_remove(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None and 'user' in data:
user = self.store_user(data['user'])
self.dispatch('member_unban', guild, user)
- def parse_guild_role_create(self, data):
+ def parse_guild_role_create(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is None:
log.debug('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
@@ -1089,7 +1137,7 @@ class ConnectionState:
guild._add_role(role)
self.dispatch('guild_role_create', role)
- def parse_guild_role_delete(self, data):
+ def parse_guild_role_delete(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
role_id = int(data['role_id'])
@@ -1102,7 +1150,7 @@ class ConnectionState:
else:
log.debug('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
- def parse_guild_role_update(self, data):
+ def parse_guild_role_update(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
role_data = data['role']
@@ -1115,12 +1163,13 @@ class ConnectionState:
else:
log.debug('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
- def parse_guild_members_chunk(self, data):
+ def parse_guild_members_chunk(self, data) -> None:
guild_id = int(data['guild_id'])
guild = self._get_guild(guild_id)
presences = data.get('presences', [])
- members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])]
+ # the guild won't be None here
+ members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] # type: ignore
log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id)
if presences:
@@ -1129,19 +1178,20 @@ class ConnectionState:
user = presence['user']
member_id = user['id']
member = member_dict.get(member_id)
- member._presence_update(presence, user)
+ if member is not None:
+ member._presence_update(presence, user)
complete = data.get('chunk_index', 0) + 1 == data.get('chunk_count')
self.process_chunk_requests(guild_id, data.get('nonce'), members, complete)
- def parse_guild_integrations_update(self, data):
+ def parse_guild_integrations_update(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
self.dispatch('guild_integrations_update', guild)
else:
log.debug('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
- def parse_integration_create(self, data):
+ def parse_integration_create(self, data) -> None:
guild_id = int(data.pop('guild_id'))
guild = self._get_guild(guild_id)
if guild is not None:
@@ -1151,7 +1201,7 @@ class ConnectionState:
else:
log.debug('INTEGRATION_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id)
- def parse_integration_update(self, data):
+ def parse_integration_update(self, data) -> None:
guild_id = int(data.pop('guild_id'))
guild = self._get_guild(guild_id)
if guild is not None:
@@ -1161,7 +1211,7 @@ class ConnectionState:
else:
log.debug('INTEGRATION_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id)
- def parse_integration_delete(self, data):
+ def parse_integration_delete(self, data) -> None:
guild_id = int(data['guild_id'])
guild = self._get_guild(guild_id)
if guild is not None:
@@ -1170,7 +1220,7 @@ class ConnectionState:
else:
log.debug('INTEGRATION_DELETE referencing an unknown guild ID: %s. Discarding.', guild_id)
- def parse_webhooks_update(self, data):
+ def parse_webhooks_update(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is None:
log.debug('WEBHOOKS_UPDATE referencing an unknown guild ID: %s. Discarding', data['guild_id'])
@@ -1182,7 +1232,7 @@ class ConnectionState:
else:
log.debug('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id'])
- def parse_stage_instance_create(self, data):
+ def parse_stage_instance_create(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
stage_instance = StageInstance(guild=guild, state=self, data=data)
@@ -1191,7 +1241,7 @@ class ConnectionState:
else:
log.debug('STAGE_INSTANCE_CREATE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
- def parse_stage_instance_update(self, data):
+ def parse_stage_instance_update(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
stage_instance = guild._stage_instances.get(int(data['id']))
@@ -1204,7 +1254,7 @@ class ConnectionState:
else:
log.debug('STAGE_INSTANCE_UPDATE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
- def parse_stage_instance_delete(self, data):
+ def parse_stage_instance_delete(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
try:
@@ -1216,11 +1266,12 @@ class ConnectionState:
else:
log.debug('STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
- def parse_voice_state_update(self, data):
+ def parse_voice_state_update(self, data) -> None:
guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id'))
channel_id = utils._get_as_snowflake(data, 'channel_id')
flags = self.member_cache_flags
- self_id = self.user.id
+ # self.user is *always* cached when this is called
+ self_id = self.user.id # type: ignore
if guild is not None:
if int(data['user_id']) == self_id:
voice = self._get_voice_client(guild.id)
@@ -1228,12 +1279,13 @@ class ConnectionState:
coro = voice.on_voice_state_update(data)
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler'))
- member, before, after = guild._update_voice_state(data, channel_id)
+ member, before, after = guild._update_voice_state(data, channel_id) # type: ignore
if member is not None:
if flags.voice:
if channel_id is None and flags._voice_only and member.id != self_id:
- # Only remove from cache iff we only have the voice flag enabled
- guild._remove_member(member)
+ # Only remove from cache if we only have the voice flag enabled
+ # Member doesn't meet the Snowflake protocol currently
+ guild._remove_member(member) # type: ignore
elif channel_id is not None:
guild._add_member(member)
@@ -1241,7 +1293,7 @@ class ConnectionState:
else:
log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id'])
- def parse_voice_server_update(self, data):
+ def parse_voice_server_update(self, data) -> None:
try:
key_id = int(data['guild_id'])
except KeyError:
@@ -1252,15 +1304,18 @@ class ConnectionState:
coro = vc.on_voice_server_update(data)
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
- def parse_typing_start(self, data):
+ def parse_typing_start(self, data) -> None:
channel, guild = self._get_guild_channel(data)
if channel is not None:
member = None
user_id = utils._get_as_snowflake(data, 'user_id')
if isinstance(channel, DMChannel):
member = channel.recipient
+
elif isinstance(channel, (Thread, TextChannel)) and guild is not None:
- member = guild.get_member(user_id)
+ # user_id won't be None
+ member = guild.get_member(user_id) # type: ignore
+
if member is None:
member_data = data.get('member')
if member_data:
@@ -1273,12 +1328,12 @@ class ConnectionState:
timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc)
self.dispatch('typing', channel, member, timestamp)
- def _get_reaction_user(self, channel, user_id):
+ def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]:
if isinstance(channel, TextChannel):
return channel.guild.get_member(user_id)
return self.get_user(user_id)
- def get_reaction_emoji(self, data):
+ def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]:
emoji_id = utils._get_as_snowflake(data, 'id')
if not emoji_id:
@@ -1289,7 +1344,7 @@ class ConnectionState:
except KeyError:
return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name'])
- def _upgrade_partial_emoji(self, emoji):
+ def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]:
emoji_id = emoji.id
if not emoji_id:
return emoji.name
@@ -1298,7 +1353,7 @@ class ConnectionState:
except KeyError:
return emoji
- def get_channel(self, id):
+ def get_channel(self, id: Optional[int]) -> Optional[Union[Channel, Thread]]:
if id is None:
return None
@@ -1311,18 +1366,18 @@ class ConnectionState:
if channel is not None:
return channel
- def create_message(self, *, channel, data):
+ def create_message(self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload) -> Message:
return Message(state=self, channel=channel, data=data)
class AutoShardedConnectionState(ConnectionState):
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
- self._ready_task = None
- self.shard_ids = ()
- self.shards_launched = asyncio.Event()
+ self.shard_ids: Union[List[int], range] = []
+ self.shards_launched: asyncio.Event = asyncio.Event()
- def _update_message_references(self):
- for msg in self._messages:
+ def _update_message_references(self) -> None:
+ # self._messages won't be None when this is called
+ for msg in self._messages: # type: ignore
if not msg.guild:
continue
@@ -1330,13 +1385,14 @@ class AutoShardedConnectionState(ConnectionState):
if new_guild is not None and new_guild is not msg.guild:
channel_id = msg.channel.id
channel = new_guild._resolve_channel(channel_id) or Object(id=channel_id)
- msg._rebind_cached_references(new_guild, channel)
+ # channel will either be a TextChannel, Thread or Object
+ msg._rebind_cached_references(new_guild, channel) # type: ignore
- async def chunker(self, guild_id, query='', limit=0, presences=False, *, shard_id=None, nonce=None):
+ async def chunker(self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, shard_id: Optional[int] = None, nonce: Optional[str] = None) -> None:
ws = self._get_websocket(guild_id, shard_id=shard_id)
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
- async def _delay_ready(self):
+ async def _delay_ready(self) -> None:
await self.shards_launched.wait()
processed = []
max_concurrency = len(self.shard_ids) * 2
@@ -1403,12 +1459,13 @@ class AutoShardedConnectionState(ConnectionState):
self.call_handlers('ready')
self.dispatch('ready')
- def parse_ready(self, data):
+ def parse_ready(self, data) -> None:
if not hasattr(self, '_ready_state'):
self._ready_state = asyncio.Queue()
self.user = user = ClientUser(state=self, data=data['user'])
- self._users[user.id] = user
+ # self._users is a list of Users, we're setting a ClientUser
+ self._users[user.id] = user # type: ignore
if self.application_id is None:
try:
@@ -1431,6 +1488,6 @@ class AutoShardedConnectionState(ConnectionState):
if self._ready_task is None:
self._ready_task = asyncio.create_task(self._delay_ready())
- def parse_resumed(self, data):
+ def parse_resumed(self, data) -> None:
self.dispatch('resumed')
self.dispatch('shard_resumed', data['__shard_id__'])
diff --git a/discord/types/appinfo.py b/discord/types/appinfo.py
index d223837f..912d5ad5 100644
--- a/discord/types/appinfo.py
+++ b/discord/types/appinfo.py
@@ -61,6 +61,7 @@ class _PartialAppInfoOptional(TypedDict, total=False):
terms_of_service_url: str
privacy_policy_url: str
max_participants: int
+ flags: int
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
pass
diff --git a/discord/types/voice.py b/discord/types/voice.py
index b29288d4..82584025 100644
--- a/discord/types/voice.py
+++ b/discord/types/voice.py
@@ -24,14 +24,14 @@ DEALINGS IN THE SOFTWARE.
from typing import Optional, TypedDict, List, Literal
from .snowflake import Snowflake
-from .member import Member
+from .member import GatewayMember
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']
class _PartialVoiceStateOptional(TypedDict, total=False):
- member: Member
+ member: GatewayMember
self_stream: bool