aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
authorNadir Chowdhury <[email protected]>2021-04-07 04:26:31 +0100
committerGitHub <[email protected]>2021-04-06 23:26:31 -0400
commit9f0c701a7a9a49d439d58e77ed065e01b0ca612b (patch)
tree43289180d36470a67f3946ecd24648b659a14f18 /discord
parent[commands] Use typing.get_type_hints to resolve ForwardRefs (diff)
downloaddiscord.py-9f0c701a7a9a49d439d58e77ed065e01b0ca612b.tar.xz
discord.py-9f0c701a7a9a49d439d58e77ed065e01b0ca612b.zip
use `typing.AsyncIterator` for iterators
Diffstat (limited to 'discord')
-rw-r--r--discord/iterators.py78
1 files changed, 48 insertions, 30 deletions
diff --git a/discord/iterators.py b/discord/iterators.py
index d67f3006..0bf47460 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -22,20 +22,43 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+from __future__ import annotations
+
import asyncio
import datetime
+from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine
from .errors import NoMoreItems
from .utils import time_snowflake, maybe_coroutine
from .object import Object
from .audit_logs import AuditLogEntry
+__all__ = (
+ 'ReactionIterator',
+ 'HistoryIterator',
+ 'AuditLogIterator',
+ 'GuildIterator',
+ 'MemberIterator',
+)
+
+if TYPE_CHECKING:
+ from .member import Member
+ from .user import User
+ from .message import Message
+ from .audit_logs import AuditLogEntry
+ from .guild import Guild
+
+T = TypeVar('T')
+OT = TypeVar('OT')
+_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]]
+_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]]
+
OLDEST_OBJECT = Object(id=0)
-class _AsyncIterator:
+class _AsyncIterator(AsyncIterator[T]):
__slots__ = ()
- def get(self, **attrs):
+ def get(self, **attrs: Any) -> Optional[T]:
def predicate(elem):
for attr, val in attrs.items():
nested = attr.split('__')
@@ -49,7 +72,7 @@ class _AsyncIterator:
return self.find(predicate)
- async def find(self, predicate):
+ async def find(self, predicate: _Predicate[T]) -> Optional[T]:
while True:
try:
elem = await self.next()
@@ -60,40 +83,35 @@ class _AsyncIterator:
if ret:
return elem
- def chunk(self, max_size):
+ def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
if max_size <= 0:
raise ValueError('async iterator chunk sizes must be greater than 0.')
return _ChunkedAsyncIterator(self, max_size)
- def map(self, func):
+ def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
return _MappedAsyncIterator(self, func)
- def filter(self, predicate):
+ def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]:
return _FilteredAsyncIterator(self, predicate)
- async def flatten(self):
+ async def flatten(self) -> List[T]:
return [element async for element in self]
- def __aiter__(self):
- return self
-
- async def __anext__(self):
+ async def __anext__(self) -> T:
try:
- msg = await self.next()
+ return await self.next()
except NoMoreItems:
raise StopAsyncIteration()
- else:
- return msg
def _identity(x):
return x
-class _ChunkedAsyncIterator(_AsyncIterator):
+class _ChunkedAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, max_size):
self.iterator = iterator
self.max_size = max_size
- async def next(self):
+ async def next(self) -> T:
ret = []
n = 0
while n < self.max_size:
@@ -108,17 +126,17 @@ class _ChunkedAsyncIterator(_AsyncIterator):
n += 1
return ret
-class _MappedAsyncIterator(_AsyncIterator):
+class _MappedAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, func):
self.iterator = iterator
self.func = func
- async def next(self):
+ async def next(self) -> T:
# this raises NoMoreItems and will propagate appropriately
item = await self.iterator.next()
return await maybe_coroutine(self.func, item)
-class _FilteredAsyncIterator(_AsyncIterator):
+class _FilteredAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, predicate):
self.iterator = iterator
@@ -127,7 +145,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
self.predicate = predicate
- async def next(self):
+ async def next(self) -> T:
getter = self.iterator.next
pred = self.predicate
while True:
@@ -137,7 +155,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
if ret:
return item
-class ReactionIterator(_AsyncIterator):
+class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
def __init__(self, message, emoji, limit=100, after=None):
self.message = message
self.limit = limit
@@ -150,7 +168,7 @@ class ReactionIterator(_AsyncIterator):
self.channel_id = message.channel.id
self.users = asyncio.Queue()
- async def next(self):
+ async def next(self) -> T:
if self.users.empty():
await self.fill_users()
@@ -185,7 +203,7 @@ class ReactionIterator(_AsyncIterator):
else:
await self.users.put(User(state=self.state, data=element))
-class HistoryIterator(_AsyncIterator):
+class HistoryIterator(_AsyncIterator['Message']):
"""Iterator for receiving a channel's message history.
The messages endpoint has two behaviours we care about here:
@@ -271,7 +289,7 @@ class HistoryIterator(_AsyncIterator):
if (self.after and self.after != OLDEST_OBJECT):
self._filter = lambda m: int(m['id']) > self.after.id
- async def next(self):
+ async def next(self) -> T:
if self.messages.empty():
await self.fill_messages()
@@ -342,7 +360,7 @@ class HistoryIterator(_AsyncIterator):
return data
return []
-class AuditLogIterator(_AsyncIterator):
+class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
@@ -404,7 +422,7 @@ class AuditLogIterator(_AsyncIterator):
self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries
- async def next(self):
+ async def next(self) -> T:
if self.entries.empty():
await self._fill()
@@ -447,7 +465,7 @@ class AuditLogIterator(_AsyncIterator):
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
-class GuildIterator(_AsyncIterator):
+class GuildIterator(_AsyncIterator['Guild']):
"""Iterator for receiving the client's guilds.
The guilds endpoint has the same two behaviours as described
@@ -501,7 +519,7 @@ class GuildIterator(_AsyncIterator):
else:
self._retrieve_guilds = self._retrieve_guilds_before_strategy
- async def next(self):
+ async def next(self) -> T:
if self.guilds.empty():
await self.fill_guilds()
@@ -559,7 +577,7 @@ class GuildIterator(_AsyncIterator):
self.after = Object(id=int(data[0]['id']))
return data
-class MemberIterator(_AsyncIterator):
+class MemberIterator(_AsyncIterator['Member']):
def __init__(self, guild, limit=1000, after=None):
if isinstance(after, datetime.datetime):
@@ -573,7 +591,7 @@ class MemberIterator(_AsyncIterator):
self.get_members = self.state.http.get_members
self.members = asyncio.Queue()
- async def next(self):
+ async def next(self) -> T:
if self.members.empty():
await self.fill_members()