aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/client.py12
-rw-r--r--discord/iterators.py5
-rw-r--r--discord/message.py4
-rw-r--r--discord/reaction.py18
-rw-r--r--discord/state.py45
5 files changed, 49 insertions, 35 deletions
diff --git a/discord/client.py b/discord/client.py
index ff247959..b1b87ca9 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -953,7 +953,7 @@ class Client:
data = yield from self.http.send_message(channel_id, content, guild_id=guild_id, tts=tts)
channel = self.get_channel(data.get('channel_id'))
- message = Message(channel=channel, **data)
+ message = self.connection._create_message(channel=channel, **data)
return message
@asyncio.coroutine
@@ -1035,7 +1035,7 @@ class Client:
data = yield from self.http.send_file(channel_id, buffer, guild_id=guild_id,
filename=filename, content=content, tts=tts)
channel = self.get_channel(data.get('channel_id'))
- message = Message(channel=channel, **data)
+ message = self.connection._create_message(channel=channel, **data)
return message
@asyncio.coroutine
@@ -1234,7 +1234,7 @@ class Client:
content = str(new_content)
guild_id = channel.server.id if not getattr(channel, 'is_private', True) else None
data = yield from self.http.edit_message(message.id, channel.id, content, guild_id=guild_id)
- return Message(channel=channel, **data)
+ return self.connection._create_message(channel=channel, **data)
@asyncio.coroutine
def get_message(self, channel, id):
@@ -1267,7 +1267,7 @@ class Client:
"""
data = yield from self.http.get_message(channel.id, id)
- return Message(channel=channel, **data)
+ return self.connection._create_message(channel=channel, **data)
@asyncio.coroutine
def pin_message(self, message):
@@ -1337,7 +1337,7 @@ class Client:
"""
data = yield from self.http.pins_from(channel.id)
- return [Message(channel=channel, **m) for m in data]
+ return [self.connection._create_message(channel=channel, **m) for m in data]
def _logs_from(self, channel, limit=100, before=None, after=None, around=None):
"""|coro|
@@ -1418,7 +1418,7 @@ class Client:
def generator(data):
for message in data:
- yield Message(channel=channel, **message)
+ yield self.connection._create_message(channel=channel, **message)
result = []
while limit > 0:
diff --git a/discord/iterators.py b/discord/iterators.py
index fbf1a72c..2ea51436 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -72,6 +72,7 @@ class LogsFromIterator:
def __init__(self, client, channel, limit,
before=None, after=None, around=None, reverse=False):
self.client = client
+ self.connection = client.connection
self.channel = channel
self.limit = limit
self.before = before
@@ -125,7 +126,9 @@ class LogsFromIterator:
if self._filter:
data = filter(self._filter, data)
for element in data:
- yield from self.messages.put(Message(channel=self.channel, **element))
+ yield from self.messages.put(
+ self.connection._create_message(
+ channel=self.channel, **element))
@asyncio.coroutine
def _retrieve_messages(self, retrieve):
diff --git a/discord/message.py b/discord/message.py
index d2bdf87e..e6e2fdd1 100644
--- a/discord/message.py
+++ b/discord/message.py
@@ -115,6 +115,9 @@ class Message:
'_system_content', 'reactions' ]
def __init__(self, **kwargs):
+ self.reactions = kwargs.pop('reactions')
+ for reaction in self.reactions:
+ reaction.message = self
self._update(**kwargs)
def _update(self, **data):
@@ -138,7 +141,6 @@ class Message:
self._handle_upgrades(data.get('channel_id'))
self._handle_mentions(data.get('mentions', []), data.get('mention_roles', []))
self._handle_call(data.get('call'))
- self.reactions = [Reaction(message=self, **reaction) for reaction in data.get('reactions', [])]
# clear the cached properties
cached = filter(lambda attr: attr[0] == '_', self.__slots__)
diff --git a/discord/reaction.py b/discord/reaction.py
index ec30fa22..7232a7b4 100644
--- a/discord/reaction.py
+++ b/discord/reaction.py
@@ -62,19 +62,11 @@ class Reaction:
__slots__ = ['message', 'count', 'emoji', 'me', 'custom_emoji']
def __init__(self, **kwargs):
- self.message = kwargs.pop('message')
- self._from_data(kwargs)
-
- def _from_data(self, reaction):
- self.count = reaction.get('count', 1)
- self.me = reaction.get('me')
- emoji = reaction['emoji']
- if emoji['id']:
- self.custom_emoji = True
- self.emoji = Emoji(server=None, id=emoji['id'], name=emoji['name'])
- else:
- self.custom_emoji = False
- self.emoji = emoji['name']
+ self.message = kwargs.get('message')
+ self.emoji = kwargs['emoji']
+ self.count = kwargs.get('count', 1)
+ self.me = kwargs.get('me')
+ self.custom_emoji = isinstance(self.emoji, Emoji)
def __eq__(self, other):
return isinstance(other, self.__class__) and other.emoji == self.emoji
diff --git a/discord/state.py b/discord/state.py
index 4d3855fc..00b0e06f 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -219,7 +219,7 @@ class ConnectionState:
def parse_message_create(self, data):
channel = self.get_channel(data.get('channel_id'))
- message = Message(channel=channel, **data)
+ message = self._create_message(channel=channel, **data)
self.dispatch('message', message)
self.messages.append(message)
@@ -255,17 +255,14 @@ class ConnectionState:
def parse_message_reaction_add(self, data):
message = self._get_message(data['message_id'])
if message is not None:
- if data['emoji']['id']:
- reaction_emoji = Emoji(server=None, **data['emoji'])
- else:
- reaction_emoji = data['emoji']['name']
- reaction = utils.get(
- message.reactions, emoji=reaction_emoji)
+ emoji = self._get_reaction_emoji(**data.pop('emoji'))
+ reaction = utils.get(message.reactions, emoji=emoji)
is_me = data['user_id'] == self.user.id
if not reaction:
- reaction = Reaction(message=message, me=is_me, **data)
+ reaction = Reaction(
+ message=message, emoji=emoji, me=is_me, **data)
message.reactions.append(reaction)
else:
reaction.count += 1
@@ -280,12 +277,8 @@ class ConnectionState:
def parse_message_reaction_remove(self, data):
message = self._get_message(data['message_id'])
if message is not None:
- if data['emoji']['id']:
- reaction_emoji = Emoji(server=None, **data['emoji'])
- else:
- reaction_emoji = data['emoji']['name']
- reaction = utils.get(
- message.reactions, emoji=reaction_emoji)
+ emoji = self._get_reaction_emoji(**data['emoji'])
+ reaction = utils.get(message.reactions, emoji=emoji)
# if reaction isn't in the list, we crash. This means discord
# sent bad data, or we stored improperly
@@ -680,6 +673,30 @@ class ConnectionState:
else:
return channel.server.get_member(id)
+ def _create_message(self, **message):
+ """Helper mostly for injecting reactions."""
+ reactions = [
+ self._create_reaction(**r) for r in message.pop('reactions', [])
+ ]
+ return Message(channel=message.pop('channel'),
+ reactions=reactions, **message)
+
+ def _create_reaction(self, **reaction):
+ emoji = self._get_reaction_emoji(**reaction.pop('emoji'))
+ return Reaction(emoji=emoji, **reaction)
+
+ def _get_reaction_emoji(self, **data):
+ id = data['id']
+
+ if id is None:
+ return data['name']
+
+ for server in self.servers:
+ for emoji in server.emojis:
+ if emoji.id == id:
+ return emoji
+ return Emoji(server=None, **data)
+
def get_channel(self, id):
if id is None:
return None