aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
Diffstat (limited to 'discord')
-rw-r--r--discord/ext/commands/core.py99
-rw-r--r--discord/ext/commands/flags.py2
-rw-r--r--discord/utils.py105
3 files changed, 106 insertions, 100 deletions
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index fcf58add..cb986e3b 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -25,11 +25,7 @@ DEALINGS IN THE SOFTWARE.
from typing import (
Any,
Dict,
- ForwardRef,
- Iterable,
Literal,
- Optional,
- Tuple,
Union,
)
import asyncio
@@ -37,7 +33,6 @@ import functools
import inspect
import datetime
import types
-import sys
import discord
@@ -74,102 +69,12 @@ __all__ = (
'bot_has_guild_permissions'
)
-PY_310 = sys.version_info >= (3, 10)
-
-def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
- params = []
- literal_cls = type(Literal[0])
- for p in parameters:
- if isinstance(p, literal_cls):
- params.extend(p.__args__)
- else:
- params.append(p)
- return tuple(params)
-
-def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
- none_cls = type(None)
- return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
-
-def _evaluate_annotation(
- tp: Any,
- globals: Dict[str, Any],
- locals: Dict[str, Any],
- cache: Dict[str, Any],
- *,
- implicit_str: bool = True,
-):
- if isinstance(tp, ForwardRef):
- tp = tp.__forward_arg__
- # ForwardRefs always evaluate their internals
- implicit_str = True
-
- if implicit_str and isinstance(tp, str):
- if tp in cache:
- return cache[tp]
- evaluated = eval(tp, globals, locals)
- cache[tp] = evaluated
- return _evaluate_annotation(evaluated, globals, locals, cache)
-
- if hasattr(tp, '__args__'):
- implicit_str = True
- is_literal = False
- args = tp.__args__
- if not hasattr(tp, '__origin__'):
- if PY_310 and tp.__class__ is types.Union:
- converted = Union[args] # type: ignore
- return _evaluate_annotation(converted, globals, locals, cache)
-
- return tp
- if tp.__origin__ is Union:
- try:
- if args.index(type(None)) != len(args) - 1:
- args = normalise_optional_params(tp.__args__)
- except ValueError:
- pass
- if tp.__origin__ is Literal:
- if not PY_310:
- args = flatten_literal_params(tp.__args__)
- implicit_str = False
- is_literal = True
-
- evaluated_args = tuple(
- _evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
- )
-
- if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
- raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
-
- if evaluated_args == args:
- return tp
-
- try:
- return tp.copy_with(evaluated_args)
- except AttributeError:
- return tp.__origin__[evaluated_args]
-
- return tp
-
-def resolve_annotation(
- annotation: Any,
- globalns: Dict[str, Any],
- localns: Optional[Dict[str, Any]],
- cache: Optional[Dict[str, Any]],
-) -> Any:
- if annotation is None:
- return type(None)
- if isinstance(annotation, str):
- annotation = ForwardRef(annotation)
-
- locals = globalns if localns is None else localns
- if cache is None:
- cache = {}
- return _evaluate_annotation(annotation, globalns, locals, cache)
-
def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.Parameter]:
globalns = function.__globals__
signature = inspect.signature(function)
params = {}
cache: Dict[str, Any] = {}
+ eval_annotation = discord.utils.evaluate_annotation
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is parameter.empty:
@@ -179,7 +84,7 @@ def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.
params[name] = parameter.replace(annotation=type(None))
continue
- annotation = _evaluate_annotation(annotation, globalns, globalns, cache)
+ annotation = eval_annotation(annotation, globalns, globalns, cache)
if annotation is Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py
index e58c9ce5..3aa9a65f 100644
--- a/discord/ext/commands/flags.py
+++ b/discord/ext/commands/flags.py
@@ -32,7 +32,7 @@ from .errors import (
MissingRequiredFlag,
)
-from .core import resolve_annotation
+from discord.utils import resolve_annotation
from .view import StringView
from .converter import run_converters
diff --git a/discord/utils.py b/discord/utils.py
index 88948da5..293103cd 100644
--- a/discord/utils.py
+++ b/discord/utils.py
@@ -31,13 +31,16 @@ from typing import (
AsyncIterator,
Callable,
Dict,
+ ForwardRef,
Generic,
Iterable,
Iterator,
List,
+ Literal,
Optional,
Protocol,
Sequence,
+ Tuple,
Type,
TypeVar,
Union,
@@ -53,6 +56,8 @@ from inspect import isawaitable as _isawaitable, signature as _signature
from operator import attrgetter
import json
import re
+import sys
+import types
import warnings
from .errors import InvalidArgument
@@ -99,6 +104,7 @@ if TYPE_CHECKING:
class _RequestLike(Protocol):
headers: Dict[str, Any]
+
else:
cached_property = _cached_property
@@ -741,6 +747,7 @@ def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
if ret:
yield ret
+
async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]:
ret = []
n = 0
@@ -767,9 +774,9 @@ def as_chunks(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T
def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
"""A helper function that collects an iterator into chunks of a given size.
-
+
.. versionadded:: 2.0
-
+
Parameters
----------
iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`]
@@ -793,3 +800,97 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
if isinstance(iterator, AsyncIterator):
return _achunk(iterator, max_size)
return _chunk(iterator, max_size)
+
+
+PY_310 = sys.version_info >= (3, 10)
+
+
+def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
+ params = []
+ literal_cls = type(Literal[0])
+ for p in parameters:
+ if isinstance(p, literal_cls):
+ params.extend(p.__args__)
+ else:
+ params.append(p)
+ return tuple(params)
+
+
+def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
+ none_cls = type(None)
+ return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
+
+
+def evaluate_annotation(
+ tp: Any,
+ globals: Dict[str, Any],
+ locals: Dict[str, Any],
+ cache: Dict[str, Any],
+ *,
+ implicit_str: bool = True,
+):
+ if isinstance(tp, ForwardRef):
+ tp = tp.__forward_arg__
+ # ForwardRefs always evaluate their internals
+ implicit_str = True
+
+ if implicit_str and isinstance(tp, str):
+ if tp in cache:
+ return cache[tp]
+ evaluated = eval(tp, globals, locals)
+ cache[tp] = evaluated
+ return evaluate_annotation(evaluated, globals, locals, cache)
+
+ if hasattr(tp, '__args__'):
+ implicit_str = True
+ is_literal = False
+ args = tp.__args__
+ if not hasattr(tp, '__origin__'):
+ if PY_310 and tp.__class__ is types.Union:
+ converted = Union[args] # type: ignore
+ return evaluate_annotation(converted, globals, locals, cache)
+
+ return tp
+ if tp.__origin__ is Union:
+ try:
+ if args.index(type(None)) != len(args) - 1:
+ args = normalise_optional_params(tp.__args__)
+ except ValueError:
+ pass
+ if tp.__origin__ is Literal:
+ if not PY_310:
+ args = flatten_literal_params(tp.__args__)
+ implicit_str = False
+ is_literal = True
+
+ evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
+
+ if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
+ raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
+
+ if evaluated_args == args:
+ return tp
+
+ try:
+ return tp.copy_with(evaluated_args)
+ except AttributeError:
+ return tp.__origin__[evaluated_args]
+
+ return tp
+
+
+def resolve_annotation(
+ annotation: Any,
+ globalns: Dict[str, Any],
+ localns: Optional[Dict[str, Any]],
+ cache: Optional[Dict[str, Any]],
+) -> Any:
+ if annotation is None:
+ return type(None)
+ if isinstance(annotation, str):
+ annotation = ForwardRef(annotation)
+
+ locals = globalns if localns is None else localns
+ if cache is None:
+ cache = {}
+ return evaluate_annotation(annotation, globalns, locals, cache)