aboutsummaryrefslogtreecommitdiff
path: root/discord/iterators.py
diff options
context:
space:
mode:
authorNadir Chowdhury <[email protected]>2021-04-08 01:28:12 +0100
committerGitHub <[email protected]>2021-04-07 20:28:12 -0400
commitf8bea3bb05fc9e3960d5ac4b2773e8dbe4f083c0 (patch)
tree8b31bda35b170a1ec2474d73ae51b04ff10e85cc /discord/iterators.py
parentAdd typings for channels and `PartialUser` (diff)
downloaddiscord.py-f8bea3bb05fc9e3960d5ac4b2773e8dbe4f083c0.tar.xz
discord.py-f8bea3bb05fc9e3960d5ac4b2773e8dbe4f083c0.zip
Fix inaccuracies with `AsyncIterator` typings
Diffstat (limited to 'discord/iterators.py')
-rw-r--r--discord/iterators.py32
1 files changed, 17 insertions, 15 deletions
diff --git a/discord/iterators.py b/discord/iterators.py
index 0bf47460..d717d83f 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -26,7 +26,7 @@ from __future__ import annotations
import asyncio
import datetime
-from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine
+from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator
from .errors import NoMoreItems
from .utils import time_snowflake, maybe_coroutine
@@ -50,16 +50,18 @@ if TYPE_CHECKING:
T = TypeVar('T')
OT = TypeVar('OT')
-_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]]
-_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]]
+_Func = Callable[[T], Union[OT, Awaitable[OT]]]
OLDEST_OBJECT = Object(id=0)
class _AsyncIterator(AsyncIterator[T]):
__slots__ = ()
- def get(self, **attrs: Any) -> Optional[T]:
- def predicate(elem):
+ async def next(self) -> T:
+ raise NotImplementedError
+
+ def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
+ def predicate(elem: T):
for attr, val in attrs.items():
nested = attr.split('__')
obj = elem
@@ -72,7 +74,7 @@ class _AsyncIterator(AsyncIterator[T]):
return self.find(predicate)
- async def find(self, predicate: _Predicate[T]) -> Optional[T]:
+ async def find(self, predicate: _Func[T, bool]) -> Optional[T]:
while True:
try:
elem = await self.next()
@@ -91,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]):
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
return _MappedAsyncIterator(self, func)
- def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]:
+ def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]:
return _FilteredAsyncIterator(self, predicate)
async def flatten(self) -> List[T]:
@@ -106,13 +108,13 @@ class _AsyncIterator(AsyncIterator[T]):
def _identity(x):
return x
-class _ChunkedAsyncIterator(_AsyncIterator[T]):
+class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
def __init__(self, iterator, max_size):
self.iterator = iterator
self.max_size = max_size
- async def next(self) -> T:
- ret = []
+ async def next(self) -> List[T]:
+ ret: List[T] = []
n = 0
while n < self.max_size:
try:
@@ -168,7 +170,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
self.channel_id = message.channel.id
self.users = asyncio.Queue()
- async def next(self) -> T:
+ async def next(self) -> Union[User, Member]:
if self.users.empty():
await self.fill_users()
@@ -289,7 +291,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if (self.after and self.after != OLDEST_OBJECT):
self._filter = lambda m: int(m['id']) > self.after.id
- async def next(self) -> T:
+ async def next(self) -> Message:
if self.messages.empty():
await self.fill_messages()
@@ -422,7 +424,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries
- async def next(self) -> T:
+ async def next(self) -> AuditLogEntry:
if self.entries.empty():
await self._fill()
@@ -519,7 +521,7 @@ class GuildIterator(_AsyncIterator['Guild']):
else:
self._retrieve_guilds = self._retrieve_guilds_before_strategy
- async def next(self) -> T:
+ async def next(self) -> Guild:
if self.guilds.empty():
await self.fill_guilds()
@@ -591,7 +593,7 @@ class MemberIterator(_AsyncIterator['Member']):
self.get_members = self.state.http.get_members
self.members = asyncio.Queue()
- async def next(self) -> T:
+ async def next(self) -> Member:
if self.members.empty():
await self.fill_members()