diff options
| author | Nadir Chowdhury <[email protected]> | 2021-04-08 01:28:12 +0100 |
|---|---|---|
| committer | GitHub <[email protected]> | 2021-04-07 20:28:12 -0400 |
| commit | f8bea3bb05fc9e3960d5ac4b2773e8dbe4f083c0 (patch) | |
| tree | 8b31bda35b170a1ec2474d73ae51b04ff10e85cc /discord/iterators.py | |
| parent | Add typings for channels and `PartialUser` (diff) | |
| download | discord.py-f8bea3bb05fc9e3960d5ac4b2773e8dbe4f083c0.tar.xz discord.py-f8bea3bb05fc9e3960d5ac4b2773e8dbe4f083c0.zip | |
Fix inaccuracies with `AsyncIterator` typings
Diffstat (limited to 'discord/iterators.py')
| -rw-r--r-- | discord/iterators.py | 32 |
1 files changed, 17 insertions, 15 deletions
diff --git a/discord/iterators.py b/discord/iterators.py index 0bf47460..d717d83f 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -26,7 +26,7 @@ from __future__ import annotations import asyncio import datetime -from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine +from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator from .errors import NoMoreItems from .utils import time_snowflake, maybe_coroutine @@ -50,16 +50,18 @@ if TYPE_CHECKING: T = TypeVar('T') OT = TypeVar('OT') -_Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]] -_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]] +_Func = Callable[[T], Union[OT, Awaitable[OT]]] OLDEST_OBJECT = Object(id=0) class _AsyncIterator(AsyncIterator[T]): __slots__ = () - def get(self, **attrs: Any) -> Optional[T]: - def predicate(elem): + async def next(self) -> T: + raise NotImplementedError + + def get(self, **attrs: Any) -> Awaitable[Optional[T]]: + def predicate(elem: T): for attr, val in attrs.items(): nested = attr.split('__') obj = elem @@ -72,7 +74,7 @@ class _AsyncIterator(AsyncIterator[T]): return self.find(predicate) - async def find(self, predicate: _Predicate[T]) -> Optional[T]: + async def find(self, predicate: _Func[T, bool]) -> Optional[T]: while True: try: elem = await self.next() @@ -91,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]): def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: return _MappedAsyncIterator(self, func) - def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]: + def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]: return _FilteredAsyncIterator(self, predicate) async def flatten(self) -> List[T]: @@ -106,13 +108,13 @@ class _AsyncIterator(AsyncIterator[T]): def _identity(x): return x -class _ChunkedAsyncIterator(_AsyncIterator[T]): +class _ChunkedAsyncIterator(_AsyncIterator[List[T]]): def __init__(self, iterator, max_size): self.iterator = iterator self.max_size = max_size - async def next(self) -> T: - ret = [] + async def next(self) -> List[T]: + ret: List[T] = [] n = 0 while n < self.max_size: try: @@ -168,7 +170,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): self.channel_id = message.channel.id self.users = asyncio.Queue() - async def next(self) -> T: + async def next(self) -> Union[User, Member]: if self.users.empty(): await self.fill_users() @@ -289,7 +291,7 @@ class HistoryIterator(_AsyncIterator['Message']): if (self.after and self.after != OLDEST_OBJECT): self._filter = lambda m: int(m['id']) > self.after.id - async def next(self) -> T: + async def next(self) -> Message: if self.messages.empty(): await self.fill_messages() @@ -422,7 +424,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): self.after = Object(id=int(entries[0]['id'])) return data.get('users', []), entries - async def next(self) -> T: + async def next(self) -> AuditLogEntry: if self.entries.empty(): await self._fill() @@ -519,7 +521,7 @@ class GuildIterator(_AsyncIterator['Guild']): else: self._retrieve_guilds = self._retrieve_guilds_before_strategy - async def next(self) -> T: + async def next(self) -> Guild: if self.guilds.empty(): await self.fill_guilds() @@ -591,7 +593,7 @@ class MemberIterator(_AsyncIterator['Member']): self.get_members = self.state.http.get_members self.members = asyncio.Queue() - async def next(self) -> T: + async def next(self) -> Member: if self.members.empty(): await self.fill_members() |