aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/ext/commands/core.py161
1 files changed, 103 insertions, 58 deletions
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index 9bd8f4f8..ab8c52c1 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -22,10 +22,20 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+from typing import (
+ Any,
+ Dict,
+ ForwardRef,
+ Iterable,
+ Literal,
+ Tuple,
+ Union,
+ get_args as get_typing_args,
+ get_origin as get_typing_origin,
+)
import asyncio
import functools
import inspect
-import typing
import datetime
import sys
@@ -64,6 +74,83 @@ __all__ = (
'bot_has_guild_permissions'
)
+PY_310 = sys.version_info >= (3, 10)
+
+def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
+ params = []
+ literal_cls = type(Literal[0])
+ for p in parameters:
+ if isinstance(p, literal_cls):
+ params.extend(p.__args__)
+ else:
+ params.append(p)
+ return tuple(params)
+
+def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
+ if isinstance(tp, ForwardRef):
+ tp = tp.__forward_arg__
+ # ForwardRefs always evaluate their internals
+ implicit_str = True
+
+ if implicit_str and isinstance(tp, str):
+ if tp in cache:
+ return cache[tp]
+ evaluated = eval(tp, globals)
+ cache[tp] = evaluated
+ return _evaluate_annotation(evaluated, globals, cache)
+
+ if hasattr(tp, '__args__'):
+ implicit_str = True
+ args = tp.__args__
+ if tp.__origin__ is Literal:
+ if not PY_310:
+ args = flatten_literal_params(tp.__args__)
+ implicit_str = False
+
+ evaluated_args = tuple(
+ _evaluate_annotation(arg, globals, cache, implicit_str=implicit_str) for arg in args
+ )
+
+ if evaluated_args == args:
+ return tp
+
+ try:
+ return tp.copy_with(evaluated_args)
+ except AttributeError:
+ return tp.__origin__[evaluated_args]
+
+ return tp
+
+def resolve_annotation(annotation: Any, globalns: Dict[str, Any], cache: Dict[str, Any] = {}) -> Any:
+ if annotation is None:
+ return type(None)
+ if isinstance(annotation, str):
+ annotation = ForwardRef(annotation)
+ return _evaluate_annotation(annotation, globalns, cache)
+
+def get_signature_parameters(function) -> Dict[str, inspect.Parameter]:
+ globalns = function.__globals__
+ signature = inspect.signature(function)
+ params = {}
+ cache: Dict[str, Any] = {}
+ for name, parameter in signature.parameters.items():
+ annotation = parameter.annotation
+ if annotation is parameter.empty:
+ params[name] = parameter
+ continue
+ if annotation is None:
+ params[name] = parameter.replace(annotation=type(None))
+ continue
+
+ annotation = _evaluate_annotation(annotation, globalns, cache)
+ if annotation is converters.Greedy:
+ raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
+
+ params[name] = parameter.replace(annotation=annotation)
+
+ return params
+
+
def wrap_callback(coro):
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
@@ -300,40 +387,7 @@ class Command(_BaseCommand):
def callback(self, function):
self._callback = function
self.module = function.__module__
-
- signature = inspect.signature(function)
- self.params = signature.parameters.copy()
-
- # see: https://bugs.python.org/issue41341
- resolve = self._recursive_resolve if sys.version_info < (3, 9) else self._return_resolved
-
- try:
- type_hints = {k: resolve(v) for k, v in typing.get_type_hints(function).items()}
- except NameError as e:
- raise NameError(f'unresolved forward reference: {e.args[0]}') from None
-
- for key, value in self.params.items():
- # coalesce the forward references
- if key in type_hints:
- self.params[key] = value = value.replace(annotation=type_hints[key])
-
- # fail early for when someone passes an unparameterized Greedy type
- if value.annotation is converters.Greedy:
- raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
-
- def _return_resolved(self, type, **kwargs):
- return type
-
- def _recursive_resolve(self, type, *, globals=None):
- if not isinstance(type, typing.ForwardRef):
- return type
-
- resolved = eval(type.__forward_arg__, globals)
- args = typing.get_args(resolved)
- for index, arg in enumerate(args):
- inner_resolve_result = self._recursive_resolve(arg, globals=globals)
- resolved[index] = inner_resolve_result
- return resolved
+ self.params = get_signature_parameters(function)
def add_check(self, func):
"""Adds a check to the command.
@@ -493,12 +547,12 @@ class Command(_BaseCommand):
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def do_conversion(self, ctx, converter, argument, param):
- origin = typing.get_origin(converter)
+ origin = get_typing_origin(converter)
- if origin is typing.Union:
+ if origin is Union:
errors = []
_NoneType = type(None)
- for conv in typing.get_args(converter):
+ for conv in get_typing_args(converter):
# if we got to this part in the code, then the previous conversions have failed
# so we should just undo the view, return the default, and allow parsing to continue
# with the other parameters
@@ -514,13 +568,12 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed all the converters
- raise BadUnionArgument(param, typing.get_args(converter), errors)
+ raise BadUnionArgument(param, get_typing_args(converter), errors)
- if origin is typing.Literal:
+ if origin is Literal:
errors = []
conversions = {}
- literal_args = tuple(self._flattened_typing_literal_args(converter))
- for literal in literal_args:
+ for literal in converter.__args__:
literal_type = type(literal)
try:
value = conversions[literal_type]
@@ -538,7 +591,7 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed to match all the literals
- raise BadLiteralArgument(param, literal_args, errors)
+ raise BadLiteralArgument(param, converter.__args__, errors)
return await self._actual_conversion(ctx, converter, argument, param)
@@ -1021,14 +1074,7 @@ class Command(_BaseCommand):
return ''
def _is_typing_optional(self, annotation):
- return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None)
-
- def _flattened_typing_literal_args(self, annotation):
- for literal in typing.get_args(annotation):
- if typing.get_origin(literal) is typing.Literal:
- yield from self._flattened_typing_literal_args(literal)
- else:
- yield literal
+ return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None)
@property
def signature(self):
@@ -1048,17 +1094,16 @@ class Command(_BaseCommand):
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
annotation = param.annotation.converter if greedy else param.annotation
- origin = typing.get_origin(annotation)
- if not greedy and origin is typing.Union:
- union_args = typing.get_args(annotation)
+ origin = get_typing_origin(annotation)
+ if not greedy and origin is Union:
+ union_args = get_typing_args(annotation)
optional = union_args[-1] is type(None)
if optional:
annotation = union_args[0]
- origin = typing.get_origin(annotation)
+ origin = get_typing_origin(annotation)
- if origin is typing.Literal:
- name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
- for v in self._flattened_typing_literal_args(annotation))
+ if origin is Literal:
+ name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should
# do [name] since [name=None] or [name=] are not exactly useful for the user.