aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/client.py24
-rw-r--r--discord/iterators.py138
2 files changed, 138 insertions, 24 deletions
diff --git a/discord/client.py b/discord/client.py
index 688e76e4..8f7a6416 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -1213,15 +1213,9 @@ class Client:
}
if before:
- if isinstance(before, datetime.datetime):
- params['before'] = utils.time_snowflake(before, high=False)
- else:
- params['before'] = before.id
+ params['before'] = before.id
if after:
- if isinstance(after, datetime.datetime):
- params['after'] = utils.time_snowflake(after, high=True)
- else:
- params['after'] = after.id
+ params['after'] = after.id
response = yield from self.session.get(url, params=params, headers=self.headers)
log.debug(request_logging_format.format(method='GET', response=response))
@@ -1230,11 +1224,21 @@ class Client:
return messages
if PY35:
- def logs_from(self, channel, limit=100, *, before=None, after=None):
- return LogsFromIterator(self, channel, limit, before, after)
+ def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False):
+ if isinstance(before, datetime.datetime):
+ before = Object(utils.time_snowflake(before, high=False))
+ if isinstance(after, datetime.datetime):
+ after = Object(utils.time_snowflake(after, high=True))
+
+ return LogsFromIterator.create(self, channel, limit, before=before, after=after, reverse=reverse)
else:
@asyncio.coroutine
def logs_from(self, channel, limit=100, *, before=None, after=None):
+ if isinstance(before, datetime.datetime):
+ before = Object(utils.time_snowflake(before, high=False))
+ if isinstance(after, datetime.datetime):
+ after = Object(utils.time_snowflake(after, high=True))
+
def generator(data):
for message in data:
yield Message(channel=channel, **message)
diff --git a/discord/iterators.py b/discord/iterators.py
index f74227d8..83483b3b 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -33,26 +33,51 @@ from .object import Object
PY35 = sys.version_info >= (3, 5)
class LogsFromIterator:
- def __init__(self, client, channel, limit, before, after):
+ @staticmethod
+ def create(client, channel, limit, *, before=None, after=None, reverse=False):
+ """Create a proper iterator depending on parameters.
+
+ The messages endpoint has two behaviors:
+ If `before` is specified, it returns the `limit` newest messages before `before`, sorted with newest first.
+ - Fill strategy - update 'before' to oldest message
+ If `after` is specified, it returns the `limit` oldest messages after `after`, sorted with newest first.
+ - Fill strategy - update 'after' to newest message
+ - If messages are not reversed, they will be out of order (99-0, 199-100, so on)
+
+ A note that if both before and after are specified, before is ignored by the messages endpoint.
+
+ Parameters
+ -----------
+ client : class:`Client`
+ channel : class:`Channel`
+ Channel from which to request logs
+ limit : int
+ Maximum number of messages to retrieve
+ before : :class:`Message` or id-like
+ Message before which all messages must be.
+ after : :class:`Message` or id-like
+ Message after which all messages must be.
+ reverse : bool
+ If set to true, return messages in oldest->newest order. Recommended when using with "after" queries,
+ otherwise messages will be out of order. Defaults to False for backwards compatability.
+ """
+ if before and after:
+ if reverse:
+ return LogsFromBeforeAfterReversedIterator(client, channel, limit, before, after)
+ else:
+ return LogsFromBeforeAfterIterator(client, channel, limit, before, after)
+ elif after:
+ return LogsFromAfterIterator(client, channel, limit, after, reverse=reverse)
+ else:
+ return LogsFromBeforeIterator(client, channel, limit, before)
+
+ def __init__(self, client, channel, limit):
self.client = client
self.channel = channel
self.limit = limit
- self.before = before
- self.after = after
self.messages = asyncio.Queue()
@asyncio.coroutine
- def fill_messages(self):
- if self.limit > 0:
- retrieve = self.limit if self.limit <= 100 else 100
- data = yield from self.client._logs_from(self.channel, retrieve, self.before, self.after)
- if len(data):
- self.limit -= retrieve
- self.before = Object(id=data[-1]['id'])
- for element in data:
- yield from self.messages.put(Message(channel=self.channel, **element))
-
- @asyncio.coroutine
def iterate(self):
if self.messages.empty():
yield from self.fill_messages()
@@ -73,3 +98,88 @@ class LogsFromIterator:
# if we're still empty at this point...
# we didn't get any new messages so stop looping
raise StopAsyncIteration()
+
+class LogsFromBeforeIterator(LogsFromIterator):
+ def __init__(self, client, channel, limit, before):
+ super().__init__(client, channel, limit)
+ self.before = before
+
+ @asyncio.coroutine
+ def fill_messages(self):
+ if self.limit > 0:
+ retrieve = self.limit if self.limit <= 100 else 100
+
+ data = yield from self.client._logs_from(self.channel, retrieve, before=self.before)
+ if len(data):
+ self.limit -= retrieve
+ self.before = Object(id=data[-1]['id'])
+ for element in data:
+ yield from self.messages.put(Message(channel=self.channel, **element))
+
+class LogsFromAfterIterator(LogsFromIterator):
+ """Iterator for retrieving "after" style responses.
+
+ Recommended to use with reverse=True - this will return messages oldest to newest.
+ With reverse=False, you'll recieve messages 99-0, 199-100, etc."""
+ def __init__(self, client, channel, limit, after, *, reverse=False):
+ super().__init__(client, channel, limit)
+ self.after = after
+ self.reverse = reverse
+
+ @asyncio.coroutine
+ def fill_messages(self):
+ if self.limit > 0:
+ retrieve = self.limit if self.limit <= 100 else 100
+
+ data = yield from self.client._logs_from(self.channel, retrieve, after=self.after)
+ if len(data):
+ self.limit -= retrieve
+ self.after = Object(id=data[0]['id'])
+ for element in (data if not self.reverse else reversed(data)):
+ yield from self.messages.put(Message(channel=self.channel, **element))
+
+class LogsFromBeforeAfterIterator(LogsFromIterator):
+ """Newest -> Oldest."""
+ def __init__(self, client, channel, limit, before, after):
+ super().__init__(client, channel, limit)
+ self.before = before
+ self.after = after
+
+ @asyncio.coroutine
+ def fill_messages(self):
+ if self.limit > 0:
+ retrieve = self.limit if self.limit <= 100 else 100
+
+ data = yield from self.client._logs_from(self.channel, retrieve, before=self.before)
+ if len(data):
+ self.limit -= retrieve
+ self.before = Object(id=data[-1]['id'])
+ # Only filter if the oldest message is not after our endpoint
+ if int(data[-1]['id']) <= int(self.after.id):
+ data = filter(lambda d: int(d['id']) > int(self.after.id), data)
+ for element in data:
+ yield from self.messages.put(Message(channel=self.channel, **element))
+
+class LogsFromBeforeAfterReversedIterator(LogsFromIterator):
+ """Oldest -> Newest."""
+ def __init__(self, client, channel, limit, before, after):
+ super().__init__(client, channel, limit)
+ self.before = before
+ self.after = after
+
+ @asyncio.coroutine
+ def fill_messages(self):
+ if self.limit > 0:
+ retrieve = self.limit if self.limit <= 100 else 100
+
+ data = yield from self.client._logs_from(self.channel, retrieve, after=self.after)
+ if len(data):
+ self.limit -= retrieve
+ self.after = Object(id=data[0]['id'])
+ # Only filter if the newest is not before our endpoint
+ if int(data[0]['id']) >= int(self.before.id):
+ data = filter(lambda d: int(d['id']) < int(self.before.id), reversed(data))
+ else:
+ data = reversed(data)
+ for element in data:
+ yield from self.messages.put(Message(channel=self.channel, **element))