aboutsummaryrefslogtreecommitdiff
path: root/discord/ext
diff options
context:
space:
mode:
authorRapptz <[email protected]>2021-04-27 05:47:26 -0400
committerRapptz <[email protected]>2021-04-27 05:48:27 -0400
commit9f3551926ad5176c0cbf23a1a127452a2749a135 (patch)
tree0e6cbe78da169ee23bbf646679f10daa2f988057 /discord/ext
parent[commands] Disallow float/complex in Literal but allow None (diff)
downloaddiscord.py-9f3551926ad5176c0cbf23a1a127452a2749a135.tar.xz
discord.py-9f3551926ad5176c0cbf23a1a127452a2749a135.zip
Split annotation resolution to discord.utils
Diffstat (limited to 'discord/ext')
-rw-r--r--discord/ext/commands/core.py99
-rw-r--r--discord/ext/commands/flags.py2
2 files changed, 3 insertions, 98 deletions
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index fcf58add..cb986e3b 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -25,11 +25,7 @@ DEALINGS IN THE SOFTWARE.
from typing import (
Any,
Dict,
- ForwardRef,
- Iterable,
Literal,
- Optional,
- Tuple,
Union,
)
import asyncio
@@ -37,7 +33,6 @@ import functools
import inspect
import datetime
import types
-import sys
import discord
@@ -74,102 +69,12 @@ __all__ = (
'bot_has_guild_permissions'
)
-PY_310 = sys.version_info >= (3, 10)
-
-def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
- params = []
- literal_cls = type(Literal[0])
- for p in parameters:
- if isinstance(p, literal_cls):
- params.extend(p.__args__)
- else:
- params.append(p)
- return tuple(params)
-
-def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
- none_cls = type(None)
- return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
-
-def _evaluate_annotation(
- tp: Any,
- globals: Dict[str, Any],
- locals: Dict[str, Any],
- cache: Dict[str, Any],
- *,
- implicit_str: bool = True,
-):
- if isinstance(tp, ForwardRef):
- tp = tp.__forward_arg__
- # ForwardRefs always evaluate their internals
- implicit_str = True
-
- if implicit_str and isinstance(tp, str):
- if tp in cache:
- return cache[tp]
- evaluated = eval(tp, globals, locals)
- cache[tp] = evaluated
- return _evaluate_annotation(evaluated, globals, locals, cache)
-
- if hasattr(tp, '__args__'):
- implicit_str = True
- is_literal = False
- args = tp.__args__
- if not hasattr(tp, '__origin__'):
- if PY_310 and tp.__class__ is types.Union:
- converted = Union[args] # type: ignore
- return _evaluate_annotation(converted, globals, locals, cache)
-
- return tp
- if tp.__origin__ is Union:
- try:
- if args.index(type(None)) != len(args) - 1:
- args = normalise_optional_params(tp.__args__)
- except ValueError:
- pass
- if tp.__origin__ is Literal:
- if not PY_310:
- args = flatten_literal_params(tp.__args__)
- implicit_str = False
- is_literal = True
-
- evaluated_args = tuple(
- _evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
- )
-
- if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
- raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
-
- if evaluated_args == args:
- return tp
-
- try:
- return tp.copy_with(evaluated_args)
- except AttributeError:
- return tp.__origin__[evaluated_args]
-
- return tp
-
-def resolve_annotation(
- annotation: Any,
- globalns: Dict[str, Any],
- localns: Optional[Dict[str, Any]],
- cache: Optional[Dict[str, Any]],
-) -> Any:
- if annotation is None:
- return type(None)
- if isinstance(annotation, str):
- annotation = ForwardRef(annotation)
-
- locals = globalns if localns is None else localns
- if cache is None:
- cache = {}
- return _evaluate_annotation(annotation, globalns, locals, cache)
-
def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.Parameter]:
globalns = function.__globals__
signature = inspect.signature(function)
params = {}
cache: Dict[str, Any] = {}
+ eval_annotation = discord.utils.evaluate_annotation
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is parameter.empty:
@@ -179,7 +84,7 @@ def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.
params[name] = parameter.replace(annotation=type(None))
continue
- annotation = _evaluate_annotation(annotation, globalns, globalns, cache)
+ annotation = eval_annotation(annotation, globalns, globalns, cache)
if annotation is Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py
index e58c9ce5..3aa9a65f 100644
--- a/discord/ext/commands/flags.py
+++ b/discord/ext/commands/flags.py
@@ -32,7 +32,7 @@ from .errors import (
MissingRequiredFlag,
)
-from .core import resolve_annotation
+from discord.utils import resolve_annotation
from .view import StringView
from .converter import run_converters