diff options
Diffstat (limited to 'discord/iterators.py')
| -rw-r--r-- | discord/iterators.py | 126 |
1 files changed, 52 insertions, 74 deletions
diff --git a/discord/iterators.py b/discord/iterators.py index e6f5234b..4e57faa7 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -33,8 +33,6 @@ from .utils import time_snowflake, maybe_coroutine from .object import Object from .audit_logs import AuditLogEntry -PY35 = sys.version_info >= (3, 5) - class _AsyncIterator: __slots__ = () @@ -52,15 +50,14 @@ class _AsyncIterator: return self.find(predicate) - @asyncio.coroutine - def find(self, predicate): + async def find(self, predicate): while True: try: - elem = yield from self.next() + elem = await self.next() except NoMoreItems: return None - ret = yield from maybe_coroutine(predicate, elem) + ret = await maybe_coroutine(predicate, elem) if ret: return elem @@ -70,30 +67,26 @@ class _AsyncIterator: def filter(self, predicate): return _FilteredAsyncIterator(self, predicate) - @asyncio.coroutine - def flatten(self): + async def flatten(self): ret = [] while True: try: - item = yield from self.next() + item = await self.next() except NoMoreItems: return ret else: ret.append(item) - if PY35: - @asyncio.coroutine - def __aiter__(self): - return self + async def __aiter__(self): + return self - @asyncio.coroutine - def __anext__(self): - try: - msg = yield from self.next() - except NoMoreItems: - raise StopAsyncIteration() - else: - return msg + async def __anext__(self): + try: + msg = await self.next() + except NoMoreItems: + raise StopAsyncIteration() + else: + return msg def _identity(x): return x @@ -103,11 +96,10 @@ class _MappedAsyncIterator(_AsyncIterator): self.iterator = iterator self.func = func - @asyncio.coroutine - def next(self): + async def next(self): # this raises NoMoreItems and will propagate appropriately - item = yield from self.iterator.next() - return (yield from maybe_coroutine(self.func, item)) + item = await self.iterator.next() + return (await maybe_coroutine(self.func, item)) class _FilteredAsyncIterator(_AsyncIterator): def __init__(self, iterator, predicate): @@ -118,14 +110,13 @@ class _FilteredAsyncIterator(_AsyncIterator): self.predicate = predicate - @asyncio.coroutine - def next(self): + async def next(self): getter = self.iterator.next pred = self.predicate while True: # propagate NoMoreItems similar to _MappedAsyncIterator - item = yield from getter() - ret = yield from maybe_coroutine(pred, item) + item = await getter() + ret = await maybe_coroutine(pred, item) if ret: return item @@ -142,18 +133,16 @@ class ReactionIterator(_AsyncIterator): self.channel_id = message.channel.id self.users = asyncio.Queue(loop=state.loop) - @asyncio.coroutine - def next(self): + async def next(self): if self.users.empty(): - yield from self.fill_users() + await self.fill_users() try: return self.users.get_nowait() except asyncio.QueueEmpty: raise NoMoreItems() - @asyncio.coroutine - def fill_users(self): + async def fill_users(self): # this is a hack because >circular imports< from .user import User @@ -161,7 +150,7 @@ class ReactionIterator(_AsyncIterator): retrieve = self.limit if self.limit <= 100 else 100 after = self.after.id if self.after else None - data = yield from self.getter(self.message.id, self.channel_id, self.emoji, retrieve, after=after) + data = await self.getter(self.message.id, self.channel_id, self.emoji, retrieve, after=after) if data: self.limit -= retrieve @@ -169,15 +158,15 @@ class ReactionIterator(_AsyncIterator): if self.guild is None: for element in reversed(data): - yield from self.users.put(User(state=self.state, data=element)) + await self.users.put(User(state=self.state, data=element)) else: for element in reversed(data): member_id = int(element['id']) member = self.guild.get_member(member_id) if member is not None: - yield from self.users.put(member) + await self.users.put(member) else: - yield from self.users.put(User(state=self.state, data=element)) + await self.users.put(User(state=self.state, data=element)) class HistoryIterator(_AsyncIterator): """Iterator for receiving a channel's message history. @@ -270,10 +259,9 @@ class HistoryIterator(_AsyncIterator): else: self._retrieve_messages = self._retrieve_messages_before_strategy - @asyncio.coroutine - def next(self): + async def next(self): if self.messages.empty(): - yield from self.fill_messages() + await self.fill_messages() try: return self.messages.get_nowait() @@ -292,15 +280,14 @@ class HistoryIterator(_AsyncIterator): self.retrieve = r return r > 0 - @asyncio.coroutine - def flatten(self): + async def flatten(self): # this is similar to fill_messages except it uses a list instead # of a queue to place the messages in. result = [] - channel = yield from self.messageable._get_channel() + channel = await self.messageable._get_channel() self.channel = channel while self._get_retrieve(): - data = yield from self._retrieve_messages(self.retrieve) + data = await self._retrieve_messages(self.retrieve) if len(data) < 100: self.limit = 0 # terminate the infinite loop @@ -313,15 +300,14 @@ class HistoryIterator(_AsyncIterator): result.append(self.state.create_message(channel=channel, data=element)) return result - @asyncio.coroutine - def fill_messages(self): + async def fill_messages(self): if not hasattr(self, 'channel'): # do the required set up - channel = yield from self.messageable._get_channel() + channel = await self.messageable._get_channel() self.channel = channel if self._get_retrieve(): - data = yield from self._retrieve_messages(self.retrieve) + data = await self._retrieve_messages(self.retrieve) if self.limit is None and len(data) < 100: self.limit = 0 # terminate the infinite loop @@ -332,41 +318,37 @@ class HistoryIterator(_AsyncIterator): channel = self.channel for element in data: - yield from self.messages.put(self.state.create_message(channel=channel, data=element)) + await self.messages.put(self.state.create_message(channel=channel, data=element)) - @asyncio.coroutine - def _retrieve_messages(self, retrieve): + async def _retrieve_messages(self, retrieve): """Retrieve messages and update next parameters.""" pass - @asyncio.coroutine - def _retrieve_messages_before_strategy(self, retrieve): + async def _retrieve_messages_before_strategy(self, retrieve): """Retrieve messages using before parameter.""" before = self.before.id if self.before else None - data = yield from self.logs_from(self.channel.id, retrieve, before=before) + data = await self.logs_from(self.channel.id, retrieve, before=before) if len(data): if self.limit is not None: self.limit -= retrieve self.before = Object(id=int(data[-1]['id'])) return data - @asyncio.coroutine - def _retrieve_messages_after_strategy(self, retrieve): + async def _retrieve_messages_after_strategy(self, retrieve): """Retrieve messages using after parameter.""" after = self.after.id if self.after else None - data = yield from self.logs_from(self.channel.id, retrieve, after=after) + data = await self.logs_from(self.channel.id, retrieve, after=after) if len(data): if self.limit is not None: self.limit -= retrieve self.after = Object(id=int(data[0]['id'])) return data - @asyncio.coroutine - def _retrieve_messages_around_strategy(self, retrieve): + async def _retrieve_messages_around_strategy(self, retrieve): """Retrieve messages using around parameter.""" if self.around: around = self.around.id if self.around else None - data = yield from self.logs_from(self.channel.id, retrieve, around=around) + data = await self.logs_from(self.channel.id, retrieve, around=around) self.around = None return data return [] @@ -411,10 +393,9 @@ class AuditLogIterator(_AsyncIterator): else: self._strategy = self._before_strategy - @asyncio.coroutine - def _before_strategy(self, retrieve): + async def _before_strategy(self, retrieve): before = self.before.id if self.before else None - data = yield from self.request(self.guild.id, limit=retrieve, user_id=self.user_id, + data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before) entries = data.get('audit_log_entries', []) @@ -424,10 +405,9 @@ class AuditLogIterator(_AsyncIterator): self.before = Object(id=int(entries[-1]['id'])) return data.get('users', []), entries - @asyncio.coroutine - def _after_strategy(self, retrieve): + async def _after_strategy(self, retrieve): after = self.after.id if self.after else None - data = yield from self.request(self.guild.id, limit=retrieve, user_id=self.user_id, + data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after) entries = data.get('audit_log_entries', []) if len(data) and entries: @@ -436,10 +416,9 @@ class AuditLogIterator(_AsyncIterator): self.after = Object(id=int(entries[0]['id'])) return data.get('users', []), entries - @asyncio.coroutine - def next(self): + async def next(self): if self.entries.empty(): - yield from self._fill() + await self._fill() try: return self.entries.get_nowait() @@ -458,12 +437,11 @@ class AuditLogIterator(_AsyncIterator): self.retrieve = r return r > 0 - @asyncio.coroutine - def _fill(self): + async def _fill(self): from .user import User if self._get_retrieve(): - users, data = yield from self._strategy(self.retrieve) + users, data = await self._strategy(self.retrieve) if self.limit is None and len(data) < 100: self.limit = 0 # terminate the infinite loop @@ -481,4 +459,4 @@ class AuditLogIterator(_AsyncIterator): if element['action_type'] is None: continue - yield from self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) + await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) |