aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/client.py80
-rw-r--r--discord/iterators.py71
2 files changed, 118 insertions, 33 deletions
diff --git a/discord/client.py b/discord/client.py
index 7e49aa89..98f62a44 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -41,6 +41,7 @@ from .permissions import Permissions
from . import utils
from .enums import ChannelType, ServerRegion
from .voice_client import VoiceClient
+from .iterators import LogsFromIterator
import asyncio
import aiohttp
@@ -53,6 +54,7 @@ import itertools
import zlib
from random import randint as random_integer
+PY35 = sys.version_info >= (3, 5)
log = logging.getLogger(__name__)
request_logging_format = '{method} {response.url} has returned {response.status}'
request_success_log = '{response.url} with {json} received {data}'
@@ -1115,24 +1117,6 @@ class Client:
@asyncio.coroutine
def _logs_from(self, channel, limit=100, before=None, after=None):
- url = '{}/{}/messages'.format(endpoints.CHANNELS, channel.id)
- params = {
- 'limit': limit
- }
-
- if before:
- params['before'] = before.id
- if after:
- params['after'] = after.id
-
- response = yield from aiohttp.get(url, params=params, headers=self.headers, loop=self.loop)
- log.debug(request_logging_format.format(method='GET', response=response))
- yield from utils._verify_successful_response(response)
- messages = yield from response.json()
- return messages
-
- @asyncio.coroutine
- def logs_from(self, channel, limit=100, *, before=None, after=None):
"""|coro|
This coroutine returns a generator that obtains logs from a specified channel.
@@ -1172,24 +1156,54 @@ class Client:
if message.content.startswith('!hello'):
if message.author == client.user:
yield from client.edit_message(message, 'goodbye')
+
+ Python 3.5 Usage ::
+
+ counter = 0
+ async for message in client.logs_from(channel, limit=500):
+ if message.author == client.user:
+ counter += 1
"""
+ url = '{}/{}/messages'.format(endpoints.CHANNELS, channel.id)
+ params = {
+ 'limit': limit
+ }
- def generator(data):
- for message in data:
- yield Message(channel=channel, **message)
-
- result = []
- while limit > 0:
- retrieve = limit if limit <= 100 else 100
- data = yield from self._logs_from(channel, retrieve, before, after)
- if len(data):
- limit -= retrieve
- result.extend(data)
- before = Object(id=data[-1]['id'])
- else:
- break
+ if before:
+ params['before'] = before.id
+ if after:
+ params['after'] = after.id
+
+ response = yield from aiohttp.get(url, params=params, headers=self.headers, loop=self.loop)
+ log.debug(request_logging_format.format(method='GET', response=response))
+ yield from utils._verify_successful_response(response)
+ messages = yield from response.json()
+ return messages
+
+ if PY35:
+ def logs_from(self, channel, limit=100, *, before=None, after=None):
+ return LogsFromIterator(self, channel, limit, before, after)
+ else:
+ @asyncio.coroutine
+ def logs_from(self, channel, limit=100, *, before=None, after=None):
+ def generator(data):
+ for message in data:
+ yield Message(channel=channel, **message)
+
+ result = []
+ while limit > 0:
+ retrieve = limit if limit <= 100 else 100
+ data = yield from self._logs_from(channel, retrieve, before, after)
+ if len(data):
+ limit -= retrieve
+ result.extend(data)
+ before = Object(id=data[-1]['id'])
+ else:
+ break
+
+ return generator(result)
- return generator(result)
+ logs_from.__doc__ = _logs_from.__doc__
# Member management
diff --git a/discord/iterators.py b/discord/iterators.py
new file mode 100644
index 00000000..b0cb0c77
--- /dev/null
+++ b/discord/iterators.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+
+"""
+The MIT License (MIT)
+
+Copyright (c) 2015-2016 Rapptz
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+
+import sys
+import asyncio
+import aiohttp
+from .message import Message
+from .object import Object
+
+PY35 = sys.version_info >= (3, 5)
+
+class LogsFromIterator:
+ def __init__(self, client, channel, limit, before, after):
+ self.client = client
+ self.channel = channel
+ self.limit = limit
+ self.before = before
+ self.after = after
+ self.messages = asyncio.LifoQueue()
+
+ @asyncio.coroutine
+ def fill_messages(self):
+ if self.limit > 0:
+ retrieve = self.limit if self.limit <= 100 else 100
+ data = yield from self.client._logs_from(self.channel, retrieve, self.before, self.after)
+ if len(data):
+ self.limit -= retrieve
+ self.before = Object(id=data[-1]['id'])
+ for element in data:
+ yield from self.messages.put(Message(channel=self.channel, **element))
+
+ if PY35:
+ @asyncio.coroutine
+ def __aiter__(self):
+ return self
+
+ @asyncio.coroutine
+ def __anext__(self):
+ if self.messages.empty():
+ yield from self.fill_messages()
+
+ try:
+ msg = self.messages.get_nowait()
+ return msg
+ except asyncio.QueueEmpty:
+ # if we're still empty at this point...
+ # we didn't get any new messages so stop looping
+ raise StopAsyncIteration()