diff options
| author | Rapptz <[email protected]> | 2021-04-27 05:47:26 -0400 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2021-04-27 05:48:27 -0400 |
| commit | 9f3551926ad5176c0cbf23a1a127452a2749a135 (patch) | |
| tree | 0e6cbe78da169ee23bbf646679f10daa2f988057 /discord/utils.py | |
| parent | [commands] Disallow float/complex in Literal but allow None (diff) | |
| download | discord.py-9f3551926ad5176c0cbf23a1a127452a2749a135.tar.xz discord.py-9f3551926ad5176c0cbf23a1a127452a2749a135.zip | |
Split annotation resolution to discord.utils
Diffstat (limited to 'discord/utils.py')
| -rw-r--r-- | discord/utils.py | 105 |
1 files changed, 103 insertions, 2 deletions
diff --git a/discord/utils.py b/discord/utils.py index 88948da5..293103cd 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -31,13 +31,16 @@ from typing import ( AsyncIterator, Callable, Dict, + ForwardRef, Generic, Iterable, Iterator, List, + Literal, Optional, Protocol, Sequence, + Tuple, Type, TypeVar, Union, @@ -53,6 +56,8 @@ from inspect import isawaitable as _isawaitable, signature as _signature from operator import attrgetter import json import re +import sys +import types import warnings from .errors import InvalidArgument @@ -99,6 +104,7 @@ if TYPE_CHECKING: class _RequestLike(Protocol): headers: Dict[str, Any] + else: cached_property = _cached_property @@ -741,6 +747,7 @@ def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]: if ret: yield ret + async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]: ret = [] n = 0 @@ -767,9 +774,9 @@ def as_chunks(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]: """A helper function that collects an iterator into chunks of a given size. - + .. versionadded:: 2.0 - + Parameters ---------- iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`] @@ -793,3 +800,97 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]: if isinstance(iterator, AsyncIterator): return _achunk(iterator, max_size) return _chunk(iterator, max_size) + + +PY_310 = sys.version_info >= (3, 10) + + +def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: + params = [] + literal_cls = type(Literal[0]) + for p in parameters: + if isinstance(p, literal_cls): + params.extend(p.__args__) + else: + params.append(p) + return tuple(params) + + +def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: + none_cls = type(None) + return tuple(p for p in parameters if p is not none_cls) + (none_cls,) + + +def evaluate_annotation( + tp: Any, + globals: Dict[str, Any], + locals: Dict[str, Any], + cache: Dict[str, Any], + *, + implicit_str: bool = True, +): + if isinstance(tp, ForwardRef): + tp = tp.__forward_arg__ + # ForwardRefs always evaluate their internals + implicit_str = True + + if implicit_str and isinstance(tp, str): + if tp in cache: + return cache[tp] + evaluated = eval(tp, globals, locals) + cache[tp] = evaluated + return evaluate_annotation(evaluated, globals, locals, cache) + + if hasattr(tp, '__args__'): + implicit_str = True + is_literal = False + args = tp.__args__ + if not hasattr(tp, '__origin__'): + if PY_310 and tp.__class__ is types.Union: + converted = Union[args] # type: ignore + return evaluate_annotation(converted, globals, locals, cache) + + return tp + if tp.__origin__ is Union: + try: + if args.index(type(None)) != len(args) - 1: + args = normalise_optional_params(tp.__args__) + except ValueError: + pass + if tp.__origin__ is Literal: + if not PY_310: + args = flatten_literal_params(tp.__args__) + implicit_str = False + is_literal = True + + evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args) + + if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): + raise TypeError('Literal arguments must be of type str, int, bool, float or complex.') + + if evaluated_args == args: + return tp + + try: + return tp.copy_with(evaluated_args) + except AttributeError: + return tp.__origin__[evaluated_args] + + return tp + + +def resolve_annotation( + annotation: Any, + globalns: Dict[str, Any], + localns: Optional[Dict[str, Any]], + cache: Optional[Dict[str, Any]], +) -> Any: + if annotation is None: + return type(None) + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + + locals = globalns if localns is None else localns + if cache is None: + cache = {} + return evaluate_annotation(annotation, globalns, locals, cache) |