aboutsummaryrefslogtreecommitdiff
path: root/discord/iterators.py
diff options
context:
space:
mode:
Diffstat (limited to 'discord/iterators.py')
-rw-r--r--discord/iterators.py32
1 files changed, 30 insertions, 2 deletions
diff --git a/discord/iterators.py b/discord/iterators.py
index 5b3dd202..fbf1a72c 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -60,6 +60,9 @@ class LogsFromIterator:
Message before which all messages must be.
after : :class:`Message` or id-like
Message after which all messages must be.
+ around : :class:`Message` or id-like
+ Message around which all messages must be. Limit max 101. Note that if
+ limit is an even number, this will return at most limit+1 messages.
reverse : bool
If set to true, return messages in oldest->newest order. Recommended
when using with "after" queries with limit over 100, otherwise messages
@@ -67,17 +70,33 @@ class LogsFromIterator:
"""
def __init__(self, client, channel, limit,
- before=None, after=None, reverse=False):
+ before=None, after=None, around=None, reverse=False):
self.client = client
self.channel = channel
self.limit = limit
self.before = before
self.after = after
+ self.around = around
self.reverse = reverse
self._filter = None # message dict -> bool
self.messages = asyncio.Queue()
- if self.before and self.after:
+ if self.around:
+ if self.limit > 101:
+ raise ValueError("LogsFrom max limit 101 when specifying around parameter")
+ elif self.limit == 101:
+ self.limit = 100 # Thanks discord
+ elif self.limit == 1:
+ raise ValueError("Use get_message.")
+
+ self._retrieve_messages = self._retrieve_messages_around_strategy
+ if self.before and self.after:
+ self._filter = lambda m: int(self.after.id) < int(m['id']) < int(self.before.id)
+ elif self.before:
+ self._filter = lambda m: int(m['id']) < int(self.before.id)
+ elif self.after:
+ self._filter = lambda m: int(self.after.id) < int(m['id'])
+ elif self.before and self.after:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy
self._filter = lambda m: int(m['id']) < int(self.before.id)
@@ -131,6 +150,15 @@ class LogsFromIterator:
self.after = Object(id=data[0]['id'])
return data
+ @asyncio.coroutine
+ def _retrieve_messages_around_strategy(self, retrieve):
+ """Retrieve messages using around parameter."""
+ if self.around:
+ data = yield from self.client._logs_from(self.channel, retrieve, around=self.around)
+ self.around = None
+ return data
+ return []
+
if PY35:
@asyncio.coroutine
def __aiter__(self):