aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/abc.py16
-rw-r--r--discord/channel.py3
-rw-r--r--discord/ext/commands/context.py1
-rw-r--r--discord/iterators.py16
4 files changed, 23 insertions, 13 deletions
diff --git a/discord/abc.py b/discord/abc.py
index 8ef15db3..89229279 100644
--- a/discord/abc.py
+++ b/discord/abc.py
@@ -467,6 +467,7 @@ class GuildChannel:
class Messageable(metaclass=abc.ABCMeta):
__slots__ = ()
+ @asyncio.coroutine
@abc.abstractmethod
def _get_channel(self):
raise NotImplementedError
@@ -534,7 +535,7 @@ class Messageable(metaclass=abc.ABCMeta):
The message that was sent.
"""
- channel = self._get_channel()
+ channel = yield from self._get_channel()
guild_id = self._get_guild_id()
state = self._state
content = str(content) if content else None
@@ -576,7 +577,7 @@ class Messageable(metaclass=abc.ABCMeta):
*Typing* indicator will go away after 10 seconds, or after a message is sent.
"""
- channel = self._get_channel()
+ channel = yield from self._get_channel()
yield from self._state.http.send_typing(channel.id)
def typing(self):
@@ -596,7 +597,8 @@ class Messageable(metaclass=abc.ABCMeta):
await channel.send_message('done!')
"""
- return Typing(self._get_channel())
+ channel = yield from self._get_channel()
+ return Typing(channel)
@asyncio.coroutine
def get_message(self, id):
@@ -626,7 +628,7 @@ class Messageable(metaclass=abc.ABCMeta):
Retrieving the message failed.
"""
- channel = self._get_channel()
+ channel = yield from self._get_channel()
data = yield from self._state.http.get_message(channel.id, id)
return state.create_message(channel=channel, data=data)
@@ -660,7 +662,7 @@ class Messageable(metaclass=abc.ABCMeta):
raise ClientException('Can only delete messages in the range of [2, 100]')
message_ids = [m.id for m in messages]
- channel = self._get_channel()
+ channel = yield from self._get_channel()
guild_id = self._get_guild_id()
yield from self._state.http.delete_messages(channel.id, message_ids, guild_id)
@@ -677,7 +679,7 @@ class Messageable(metaclass=abc.ABCMeta):
Retrieving the pinned messages failed.
"""
- channel = self._get_channel()
+ channel = yield from self._get_channel()
state = self._state
data = yield from state.http.pins_from(channel.id)
return [state.create_message(channel=channel, data=m) for m in data]
@@ -745,7 +747,7 @@ class Messageable(metaclass=abc.ABCMeta):
if message.author == client.user:
counter += 1
"""
- return LogsFromIterator(self._get_channel(), limit=limit, before=before, after=after, around=around, reverse=reverse)
+ return LogsFromIterator(self, limit=limit, before=before, after=after, around=around, reverse=reverse)
@asyncio.coroutine
def purge(self, *, limit=100, check=None, before=None, after=None, around=None):
diff --git a/discord/channel.py b/discord/channel.py
index 89a6051e..8efae546 100644
--- a/discord/channel.py
+++ b/discord/channel.py
@@ -88,6 +88,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
self.position = data['position']
self._fill_overwrites(data)
+ @asyncio.coroutine
def _get_channel(self):
return self
@@ -262,6 +263,7 @@ class DMChannel(discord.abc.Messageable, Hashable):
self.me = me
self.id = int(data['id'])
+ @asyncio.coroutine
def _get_channel(self):
return self
@@ -360,6 +362,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
else:
self.owner = utils.find(lambda u: u.id == owner_id, self.recipients)
+ @asyncio.coroutine
def _get_channel(self):
return self
diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py
index ff78c562..59f09117 100644
--- a/discord/ext/commands/context.py
+++ b/discord/ext/commands/context.py
@@ -117,6 +117,7 @@ class Context(discord.abc.Messageable):
ret = yield from command.callback(*arguments, **kwargs)
return ret
+ @asyncio.coroutine
def _get_channel(self):
return self.channel
diff --git a/discord/iterators.py b/discord/iterators.py
index 86b5e9de..3ac75593 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -70,7 +70,7 @@ class LogsFromIterator:
will be out of order.
"""
- def __init__(self, channel, limit,
+ def __init__(self, messageable, limit,
before=None, after=None, around=None, reverse=None):
if isinstance(before, datetime.datetime):
@@ -80,9 +80,7 @@ class LogsFromIterator:
if isinstance(around, datetime.datetime):
around = Object(id=time_snowflake(around))
- self.channel = channel
- self.ctx = channel._state
- self.logs_from = channel._state.http.logs_from
+ self.messageable = messageable
self.limit = limit
self.before = before
self.after = after
@@ -135,6 +133,13 @@ class LogsFromIterator:
@asyncio.coroutine
def fill_messages(self):
+ if not hasattr(self, 'channel'):
+ # do the required set up
+ channel = yield from self.messageable._get_channel()
+ self.channel = channel
+ self.state = channel._state
+ self.logs_from = channel._state.http.logs_from
+
if self.limit > 0:
retrieve = self.limit if self.limit <= 100 else 100
data = yield from self._retrieve_messages(retrieve)
@@ -144,9 +149,8 @@ class LogsFromIterator:
data = filter(self._filter, data)
channel = self.channel
- state = self.ctx
for element in data:
- yield from self.messages.put(state.create_message(channel=channel, data=element))
+ yield from self.messages.put(self.state.create_message(channel=channel, data=element))
@asyncio.coroutine
def _retrieve_messages(self, retrieve):