aboutsummaryrefslogtreecommitdiff
path: root/discord/iterators.py
diff options
context:
space:
mode:
authorRapptz <[email protected]>2017-02-11 23:34:19 -0500
committerRapptz <[email protected]>2017-02-11 23:34:19 -0500
commit2abdbc70c2637da33f35af69abc8cd559c0b05f7 (patch)
tree19df7f0093a9c086a5b892c9470b11210267cb87 /discord/iterators.py
parentAdd Client.get_user_profile to get an arbitrary user's profile. (diff)
downloaddiscord.py-2abdbc70c2637da33f35af69abc8cd559c0b05f7.tar.xz
discord.py-2abdbc70c2637da33f35af69abc8cd559c0b05f7.zip
Implement utilities for AsyncIterator.
Closes #473.
Diffstat (limited to 'discord/iterators.py')
-rw-r--r--discord/iterators.py145
1 files changed, 103 insertions, 42 deletions
diff --git a/discord/iterators.py b/discord/iterators.py
index 5aa78089..31d72569 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -35,7 +35,108 @@ from .object import Object
PY35 = sys.version_info >= (3, 5)
-class ReactionIterator:
+def _probably_coroutine(f, e):
+ if asyncio.iscoroutinefunction(f):
+ return (yield from f(e))
+ else:
+ return f(e)
+
+class _AsyncIterator:
+ __slots__ = ()
+
+ def get(self, **attrs):
+ def predicate(elem):
+ for attr, val in attrs.items():
+ nested = attr.split('__')
+ obj = elem
+ for attribute in nested:
+ obj = getattr(obj, attribute)
+
+ if obj != val:
+ return False
+ return True
+
+ return self.find(predicate)
+
+ @asyncio.coroutine
+ def find(self, predicate):
+ while True:
+ try:
+ elem = yield from self.get()
+ except NoMoreItems:
+ return None
+
+ ret = yield from _probably_coroutine(predicate, elem)
+ if ret:
+ return elem
+
+ def map(self, func):
+ return _MappedAsyncIterator(self, func)
+
+ def filter(self, predicate):
+ return _FilteredAsyncIterator(self, predicate)
+
+ @asyncio.coroutine
+ def flatten(self):
+ ret = []
+ while True:
+ try:
+ item = yield from self.get()
+ except NoMoreItems:
+ return ret
+ else:
+ ret.append(item)
+
+ if PY35:
+ @asyncio.coroutine
+ def __aiter__(self):
+ return self
+
+ @asyncio.coroutine
+ def __anext__(self):
+ try:
+ msg = yield from self.get()
+ except NoMoreItems:
+ raise StopAsyncIteration()
+ else:
+ return msg
+
+def _identity(x):
+ return x
+
+class _MappedAsyncIterator(_AsyncIterator):
+ def __init__(self, iterator, func):
+ self.iterator = iterator
+ self.func = func
+
+ @asyncio.coroutine
+ def get(self):
+ # this raises NoMoreItems and will propagate appropriately
+ item = yield from self.iterator.get()
+ return (yield from _probably_coroutine(self.func, item))
+
+class _FilteredAsyncIterator(_AsyncIterator):
+ def __init__(self, iterator, predicate):
+ self.iterator = iterator
+
+ if predicate is None:
+ predicate = _identity
+
+ self.predicate = predicate
+
+ @asyncio.coroutine
+ def get(self):
+ getter = self.iterator.get
+ pred = self.predicate
+ while True:
+ # propagate NoMoreItems similar to _MappedAsyncIterator
+ item = yield from getter()
+ ret = yield from _probably_coroutine(pred, item)
+ if ret:
+ return item
+
+class ReactionIterator(_AsyncIterator):
def __init__(self, message, emoji, limit=100, after=None):
self.message = message
self.limit = limit
@@ -85,32 +186,7 @@ class ReactionIterator:
else:
yield from self.users.put(User(state=self.state, data=element))
- @asyncio.coroutine
- def flatten(self):
- ret = []
- while True:
- try:
- user = yield from self.get()
- except NoMoreItems:
- return ret
- else:
- ret.append(user)
-
- if PY35:
- @asyncio.coroutine
- def __aiter__(self):
- return self
-
- @asyncio.coroutine
- def __anext__(self):
- try:
- msg = yield from self.get()
- except NoMoreItems:
- raise StopAsyncIteration()
- else:
- return msg
-
-class HistoryIterator:
+class HistoryIterator(_AsyncIterator):
"""Iterator for receiving a channel's message history.
The messages endpoint has two behaviours we care about here:
@@ -281,18 +357,3 @@ class HistoryIterator:
self.around = None
return data
return []
-
- if PY35:
- @asyncio.coroutine
- def __aiter__(self):
- return self
-
- @asyncio.coroutine
- def __anext__(self):
- try:
- msg = yield from self.get()
- return msg
- except NoMoreItems:
- # if we're still empty at this point...
- # we didn't get any new messages so stop looping
- raise StopAsyncIteration()