diff options
| author | Rapptz <[email protected]> | 2021-04-27 05:47:26 -0400 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2021-04-27 05:48:27 -0400 |
| commit | 9f3551926ad5176c0cbf23a1a127452a2749a135 (patch) | |
| tree | 0e6cbe78da169ee23bbf646679f10daa2f988057 /discord/ext | |
| parent | [commands] Disallow float/complex in Literal but allow None (diff) | |
| download | discord.py-9f3551926ad5176c0cbf23a1a127452a2749a135.tar.xz discord.py-9f3551926ad5176c0cbf23a1a127452a2749a135.zip | |
Split annotation resolution to discord.utils
Diffstat (limited to 'discord/ext')
| -rw-r--r-- | discord/ext/commands/core.py | 99 | ||||
| -rw-r--r-- | discord/ext/commands/flags.py | 2 |
2 files changed, 3 insertions, 98 deletions
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index fcf58add..cb986e3b 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -25,11 +25,7 @@ DEALINGS IN THE SOFTWARE. from typing import ( Any, Dict, - ForwardRef, - Iterable, Literal, - Optional, - Tuple, Union, ) import asyncio @@ -37,7 +33,6 @@ import functools import inspect import datetime import types -import sys import discord @@ -74,102 +69,12 @@ __all__ = ( 'bot_has_guild_permissions' ) -PY_310 = sys.version_info >= (3, 10) - -def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: - params = [] - literal_cls = type(Literal[0]) - for p in parameters: - if isinstance(p, literal_cls): - params.extend(p.__args__) - else: - params.append(p) - return tuple(params) - -def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: - none_cls = type(None) - return tuple(p for p in parameters if p is not none_cls) + (none_cls,) - -def _evaluate_annotation( - tp: Any, - globals: Dict[str, Any], - locals: Dict[str, Any], - cache: Dict[str, Any], - *, - implicit_str: bool = True, -): - if isinstance(tp, ForwardRef): - tp = tp.__forward_arg__ - # ForwardRefs always evaluate their internals - implicit_str = True - - if implicit_str and isinstance(tp, str): - if tp in cache: - return cache[tp] - evaluated = eval(tp, globals, locals) - cache[tp] = evaluated - return _evaluate_annotation(evaluated, globals, locals, cache) - - if hasattr(tp, '__args__'): - implicit_str = True - is_literal = False - args = tp.__args__ - if not hasattr(tp, '__origin__'): - if PY_310 and tp.__class__ is types.Union: - converted = Union[args] # type: ignore - return _evaluate_annotation(converted, globals, locals, cache) - - return tp - if tp.__origin__ is Union: - try: - if args.index(type(None)) != len(args) - 1: - args = normalise_optional_params(tp.__args__) - except ValueError: - pass - if tp.__origin__ is Literal: - if not PY_310: - args = flatten_literal_params(tp.__args__) - implicit_str = False - is_literal = True - - evaluated_args = tuple( - _evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args - ) - - if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): - raise TypeError('Literal arguments must be of type str, int, bool, float or complex.') - - if evaluated_args == args: - return tp - - try: - return tp.copy_with(evaluated_args) - except AttributeError: - return tp.__origin__[evaluated_args] - - return tp - -def resolve_annotation( - annotation: Any, - globalns: Dict[str, Any], - localns: Optional[Dict[str, Any]], - cache: Optional[Dict[str, Any]], -) -> Any: - if annotation is None: - return type(None) - if isinstance(annotation, str): - annotation = ForwardRef(annotation) - - locals = globalns if localns is None else localns - if cache is None: - cache = {} - return _evaluate_annotation(annotation, globalns, locals, cache) - def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.Parameter]: globalns = function.__globals__ signature = inspect.signature(function) params = {} cache: Dict[str, Any] = {} + eval_annotation = discord.utils.evaluate_annotation for name, parameter in signature.parameters.items(): annotation = parameter.annotation if annotation is parameter.empty: @@ -179,7 +84,7 @@ def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect. params[name] = parameter.replace(annotation=type(None)) continue - annotation = _evaluate_annotation(annotation, globalns, globalns, cache) + annotation = eval_annotation(annotation, globalns, globalns, cache) if annotation is Greedy: raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index e58c9ce5..3aa9a65f 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -32,7 +32,7 @@ from .errors import ( MissingRequiredFlag, ) -from .core import resolve_annotation +from discord.utils import resolve_annotation from .view import StringView from .converter import run_converters |