diff options
Diffstat (limited to 'discord/state.py')
| -rw-r--r-- | discord/state.py | 115 |
1 files changed, 49 insertions, 66 deletions
diff --git a/discord/state.py b/discord/state.py index c9a330b3..9df9f365 100644 --- a/discord/state.py +++ b/discord/state.py @@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE. from .guild import Guild from .user import User from .game import Game -from .emoji import Emoji +from .emoji import Emoji, PartialEmoji from .reaction import Reaction from .message import Message from .channel import * @@ -47,10 +47,16 @@ class ListenerType(enum.Enum): chunk = 0 Listener = namedtuple('Listener', ('type', 'future', 'predicate')) -StateContext = namedtuple('StateContext', 'store_user http self_id') log = logging.getLogger(__name__) ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) +class StateContext: + __slots__ = ('store_user', 'http', 'self_id', 'store_emoji', 'reaction_emoji') + + def __init__(self, **kwargs): + for attr, value in kwargs.items(): + setattr(self, attr, value) + class ConnectionState: def __init__(self, *, dispatch, chunker, syncer, http, loop, **options): self.loop = loop @@ -60,7 +66,10 @@ class ConnectionState: self.syncer = syncer self.is_bot = None self._listeners = [] - self.ctx = StateContext(store_user=self.store_user, http=http, self_id=None) + self.ctx = StateContext(store_user=self.store_user, + store_emoji=self.store_emoji, + reaction_emoji=self._get_reaction_emoji, + http=http, self_id=None) self.clear() def clear(self): @@ -69,6 +78,7 @@ class ConnectionState: self.session_id = None self._calls = {} self._users = {} + self._emojis = {} self._guilds = {} self._voice_clients = {} self._private_channels = {} @@ -128,6 +138,14 @@ class ConnectionState: self._users[user_id] = user = User(state=self.ctx, data=data) return user + def store_emoji(self, guild, data): + emoji_id = int(data['id']) + try: + return self._emojis[emoji_id] + except KeyError: + self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self.ctx, data=data) + return emoji + @property def guilds(self): return self._guilds.values() @@ -274,26 +292,11 @@ class ConnectionState: self.dispatch('message_edit', older_message, message) def parse_message_reaction_add(self, data): - message = self._get_message(data['message_id']) + message = self._get_message(int(data['message_id'])) if message is not None: - emoji = self._get_reaction_emoji(**data.pop('emoji')) - reaction = utils.get(message.reactions, emoji=emoji) - - is_me = data['user_id'] == self.user.id - - if not reaction: - reaction = Reaction( - message=message, emoji=emoji, me=is_me, **data) - message.reactions.append(reaction) - else: - reaction.count += 1 - if is_me: - reaction.me = True - - channel = self.get_channel(data['channel_id']) - member = self._get_member(channel, data['user_id']) - - self.dispatch('reaction_add', reaction, member) + reaction = message._add_reaction(data) + user = self._get_reaction_user(message.channel, int(data['user_id'])) + self.dispatch('reaction_add', reaction, user) def parse_message_reaction_remove_all(self, data): message = self._get_message(data['message_id']) @@ -303,26 +306,15 @@ class ConnectionState: self.dispatch('reaction_clear', message, old_reactions) def parse_message_reaction_remove(self, data): - message = self._get_message(data['message_id']) + message = self._get_message(int(data['message_id'])) if message is not None: - emoji = self._get_reaction_emoji(**data['emoji']) - reaction = utils.get(message.reactions, emoji=emoji) - - # Eventual consistency means we can get out of order or duplicate removes. - if not reaction: - log.warning("Unexpected reaction remove {}".format(data)) - return - - reaction.count -= 1 - if data['user_id'] == self.user.id: - reaction.me = False - if reaction.count == 0: - message.reactions.remove(reaction) - - channel = self.get_channel(data['channel_id']) - member = self._get_member(channel, data['user_id']) - - self.dispatch('reaction_remove', reaction, member) + try: + reaction = message._remove_reaction(data) + except (AttributeError, ValueError) as e: # eventual consistency lol + pass + else: + user = self._get_reaction_user(message.channel, int(data['user_id'])) + self.dispatch('reaction_remove', reaction, user) def parse_presence_update(self, data): guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) @@ -462,7 +454,7 @@ class ConnectionState: def parse_guild_emojis_update(self, data): guild = self._get_guild(int(data['guild_id'])) before_emojis = guild.emojis - guild.emojis = [Emoji(guild=guild, data=e, state=self.ctx) for e in data.get('emojis', [])] + guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) self.dispatch('guild_emojis_update', before_emojis, guild.emojis) def _get_create_guild(self, data): @@ -675,35 +667,26 @@ class ConnectionState: if call is not None: self.dispatch('call_remove', call) - def _get_member(self, channel, id): - if channel.is_private: - return utils.get(channel.recipients, id=id) + def _get_reaction_user(self, channel, user_id): + if isinstance(channel, DMChannel) and user_id == channel.recipient.id: + return channel.recipient + elif isinstance(channel, TextChannel): + return channel.guild.get_member(user_id) + elif isinstance(channel, GroupChannel): + return utils.find(lambda m: m.id == user_id, channel.recipients) else: - return channel.server.get_member(id) - - def _create_message(self, **message): - """Helper mostly for injecting reactions.""" - reactions = [ - self._create_reaction(**r) for r in message.pop('reactions', []) - ] - return Message(channel=message.pop('channel'), - reactions=reactions, **message) - - def _create_reaction(self, **reaction): - emoji = self._get_reaction_emoji(**reaction.pop('emoji')) - return Reaction(emoji=emoji, **reaction) + return None - def _get_reaction_emoji(self, **data): - id = data['id'] + def _get_reaction_emoji(self, data): + emoji_id = utils._get_as_snowflake(data, 'id') - if not id: + if not emoji_id: return data['name'] - for server in self.servers: - for emoji in server.emojis: - if emoji.id == id: - return emoji - return Emoji(server=None, **data) + try: + return self._emojis[emoji_id] + except KeyError: + return PartialEmoji(id=emoji_id, name=data['name']) def get_channel(self, id): if id is None: |