aboutsummaryrefslogtreecommitdiff
path: root/discord/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'discord/utils.py')
-rw-r--r--discord/utils.py105
1 files changed, 103 insertions, 2 deletions
diff --git a/discord/utils.py b/discord/utils.py
index 88948da5..293103cd 100644
--- a/discord/utils.py
+++ b/discord/utils.py
@@ -31,13 +31,16 @@ from typing import (
AsyncIterator,
Callable,
Dict,
+ ForwardRef,
Generic,
Iterable,
Iterator,
List,
+ Literal,
Optional,
Protocol,
Sequence,
+ Tuple,
Type,
TypeVar,
Union,
@@ -53,6 +56,8 @@ from inspect import isawaitable as _isawaitable, signature as _signature
from operator import attrgetter
import json
import re
+import sys
+import types
import warnings
from .errors import InvalidArgument
@@ -99,6 +104,7 @@ if TYPE_CHECKING:
class _RequestLike(Protocol):
headers: Dict[str, Any]
+
else:
cached_property = _cached_property
@@ -741,6 +747,7 @@ def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
if ret:
yield ret
+
async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]:
ret = []
n = 0
@@ -767,9 +774,9 @@ def as_chunks(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T
def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
"""A helper function that collects an iterator into chunks of a given size.
-
+
.. versionadded:: 2.0
-
+
Parameters
----------
iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`]
@@ -793,3 +800,97 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
if isinstance(iterator, AsyncIterator):
return _achunk(iterator, max_size)
return _chunk(iterator, max_size)
+
+
+PY_310 = sys.version_info >= (3, 10)
+
+
+def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
+ params = []
+ literal_cls = type(Literal[0])
+ for p in parameters:
+ if isinstance(p, literal_cls):
+ params.extend(p.__args__)
+ else:
+ params.append(p)
+ return tuple(params)
+
+
+def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
+ none_cls = type(None)
+ return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
+
+
+def evaluate_annotation(
+ tp: Any,
+ globals: Dict[str, Any],
+ locals: Dict[str, Any],
+ cache: Dict[str, Any],
+ *,
+ implicit_str: bool = True,
+):
+ if isinstance(tp, ForwardRef):
+ tp = tp.__forward_arg__
+ # ForwardRefs always evaluate their internals
+ implicit_str = True
+
+ if implicit_str and isinstance(tp, str):
+ if tp in cache:
+ return cache[tp]
+ evaluated = eval(tp, globals, locals)
+ cache[tp] = evaluated
+ return evaluate_annotation(evaluated, globals, locals, cache)
+
+ if hasattr(tp, '__args__'):
+ implicit_str = True
+ is_literal = False
+ args = tp.__args__
+ if not hasattr(tp, '__origin__'):
+ if PY_310 and tp.__class__ is types.Union:
+ converted = Union[args] # type: ignore
+ return evaluate_annotation(converted, globals, locals, cache)
+
+ return tp
+ if tp.__origin__ is Union:
+ try:
+ if args.index(type(None)) != len(args) - 1:
+ args = normalise_optional_params(tp.__args__)
+ except ValueError:
+ pass
+ if tp.__origin__ is Literal:
+ if not PY_310:
+ args = flatten_literal_params(tp.__args__)
+ implicit_str = False
+ is_literal = True
+
+ evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
+
+ if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
+ raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
+
+ if evaluated_args == args:
+ return tp
+
+ try:
+ return tp.copy_with(evaluated_args)
+ except AttributeError:
+ return tp.__origin__[evaluated_args]
+
+ return tp
+
+
+def resolve_annotation(
+ annotation: Any,
+ globalns: Dict[str, Any],
+ localns: Optional[Dict[str, Any]],
+ cache: Optional[Dict[str, Any]],
+) -> Any:
+ if annotation is None:
+ return type(None)
+ if isinstance(annotation, str):
+ annotation = ForwardRef(annotation)
+
+ locals = globalns if localns is None else localns
+ if cache is None:
+ cache = {}
+ return evaluate_annotation(annotation, globalns, locals, cache)