diff options
| author | Rapptz <[email protected]> | 2017-02-11 23:34:19 -0500 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2017-02-11 23:34:19 -0500 |
| commit | 2abdbc70c2637da33f35af69abc8cd559c0b05f7 (patch) | |
| tree | 19df7f0093a9c086a5b892c9470b11210267cb87 /discord/iterators.py | |
| parent | Add Client.get_user_profile to get an arbitrary user's profile. (diff) | |
| download | discord.py-2abdbc70c2637da33f35af69abc8cd559c0b05f7.tar.xz discord.py-2abdbc70c2637da33f35af69abc8cd559c0b05f7.zip | |
Implement utilities for AsyncIterator.
Closes #473.
Diffstat (limited to 'discord/iterators.py')
| -rw-r--r-- | discord/iterators.py | 145 |
1 files changed, 103 insertions, 42 deletions
diff --git a/discord/iterators.py b/discord/iterators.py index 5aa78089..31d72569 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -35,7 +35,108 @@ from .object import Object PY35 = sys.version_info >= (3, 5) -class ReactionIterator: +def _probably_coroutine(f, e): + if asyncio.iscoroutinefunction(f): + return (yield from f(e)) + else: + return f(e) + +class _AsyncIterator: + __slots__ = () + + def get(self, **attrs): + def predicate(elem): + for attr, val in attrs.items(): + nested = attr.split('__') + obj = elem + for attribute in nested: + obj = getattr(obj, attribute) + + if obj != val: + return False + return True + + return self.find(predicate) + + @asyncio.coroutine + def find(self, predicate): + while True: + try: + elem = yield from self.get() + except NoMoreItems: + return None + + ret = yield from _probably_coroutine(predicate, elem) + if ret: + return elem + + def map(self, func): + return _MappedAsyncIterator(self, func) + + def filter(self, predicate): + return _FilteredAsyncIterator(self, predicate) + + @asyncio.coroutine + def flatten(self): + ret = [] + while True: + try: + item = yield from self.get() + except NoMoreItems: + return ret + else: + ret.append(item) + + if PY35: + @asyncio.coroutine + def __aiter__(self): + return self + + @asyncio.coroutine + def __anext__(self): + try: + msg = yield from self.get() + except NoMoreItems: + raise StopAsyncIteration() + else: + return msg + +def _identity(x): + return x + +class _MappedAsyncIterator(_AsyncIterator): + def __init__(self, iterator, func): + self.iterator = iterator + self.func = func + + @asyncio.coroutine + def get(self): + # this raises NoMoreItems and will propagate appropriately + item = yield from self.iterator.get() + return (yield from _probably_coroutine(self.func, item)) + +class _FilteredAsyncIterator(_AsyncIterator): + def __init__(self, iterator, predicate): + self.iterator = iterator + + if predicate is None: + predicate = _identity + + self.predicate = predicate + + @asyncio.coroutine + def get(self): + getter = self.iterator.get + pred = self.predicate + while True: + # propagate NoMoreItems similar to _MappedAsyncIterator + item = yield from getter() + ret = yield from _probably_coroutine(pred, item) + if ret: + return item + +class ReactionIterator(_AsyncIterator): def __init__(self, message, emoji, limit=100, after=None): self.message = message self.limit = limit @@ -85,32 +186,7 @@ class ReactionIterator: else: yield from self.users.put(User(state=self.state, data=element)) - @asyncio.coroutine - def flatten(self): - ret = [] - while True: - try: - user = yield from self.get() - except NoMoreItems: - return ret - else: - ret.append(user) - - if PY35: - @asyncio.coroutine - def __aiter__(self): - return self - - @asyncio.coroutine - def __anext__(self): - try: - msg = yield from self.get() - except NoMoreItems: - raise StopAsyncIteration() - else: - return msg - -class HistoryIterator: +class HistoryIterator(_AsyncIterator): """Iterator for receiving a channel's message history. The messages endpoint has two behaviours we care about here: @@ -281,18 +357,3 @@ class HistoryIterator: self.around = None return data return [] - - if PY35: - @asyncio.coroutine - def __aiter__(self): - return self - - @asyncio.coroutine - def __anext__(self): - try: - msg = yield from self.get() - return msg - except NoMoreItems: - # if we're still empty at this point... - # we didn't get any new messages so stop looping - raise StopAsyncIteration() |