aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh <[email protected]>2021-04-11 14:38:17 +1000
committerGitHub <[email protected]>2021-04-11 00:38:17 -0400
commit7f91ae8b676908fec8b7a76aef6f86993871fb05 (patch)
tree3de935e35f3935ea441ad54f77b228336c006079
parent[commands] Fix repr for Greedy (diff)
downloaddiscord.py-7f91ae8b676908fec8b7a76aef6f86993871fb05.tar.xz
discord.py-7f91ae8b676908fec8b7a76aef6f86993871fb05.zip
[commands] use __args__ and __origin__ where applicable
-rw-r--r--discord/ext/commands/core.py41
-rw-r--r--discord/ext/commands/errors.py3
2 files changed, 27 insertions, 17 deletions
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index ab8c52c1..215c89b1 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -30,8 +30,6 @@ from typing import (
Literal,
Tuple,
Union,
- get_args as get_typing_args,
- get_origin as get_typing_origin,
)
import asyncio
import functools
@@ -86,6 +84,10 @@ def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params.append(p)
return tuple(params)
+def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
+ none_cls = type(None)
+ return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
+
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
@@ -102,6 +104,12 @@ def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any]
if hasattr(tp, '__args__'):
implicit_str = True
args = tp.__args__
+ if tp.__origin__ is Union:
+ try:
+ if args.index(type(None)) != len(args) - 1:
+ args = normalise_optional_params(tp.__args__)
+ except ValueError:
+ pass
if tp.__origin__ is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
@@ -547,12 +555,13 @@ class Command(_BaseCommand):
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def do_conversion(self, ctx, converter, argument, param):
- origin = get_typing_origin(converter)
+ origin = getattr(converter, '__origin__', None)
if origin is Union:
errors = []
_NoneType = type(None)
- for conv in get_typing_args(converter):
+ union_args = converter.__args__
+ for conv in union_args:
# if we got to this part in the code, then the previous conversions have failed
# so we should just undo the view, return the default, and allow parsing to continue
# with the other parameters
@@ -568,12 +577,13 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed all the converters
- raise BadUnionArgument(param, get_typing_args(converter), errors)
+ raise BadUnionArgument(param, union_args, errors)
if origin is Literal:
errors = []
conversions = {}
- for literal in converter.__args__:
+ literal_args = converter.__args__
+ for literal in literal_args:
literal_type = type(literal)
try:
value = conversions[literal_type]
@@ -591,7 +601,7 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed to match all the literals
- raise BadLiteralArgument(param, converter.__args__, errors)
+ raise BadLiteralArgument(param, literal_args, errors)
return await self._actual_conversion(ctx, converter, argument, param)
@@ -614,7 +624,7 @@ class Command(_BaseCommand):
# The greedy converter is simple -- it keeps going until it fails in which case,
# it undos the view ready for the next parameter to use instead
if isinstance(converter, converters.Greedy):
- if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
+ if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
return await self._transform_greedy_pos(ctx, param, required, converter.converter)
elif param.kind == param.VAR_POSITIONAL:
return await self._transform_greedy_var_pos(ctx, param, converter.converter)
@@ -782,7 +792,7 @@ class Command(_BaseCommand):
raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.')
for name, param in iterator:
- if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
+ if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
transformed = await self.transform(ctx, param)
args.append(transformed)
elif param.kind == param.KEYWORD_ONLY:
@@ -1074,7 +1084,7 @@ class Command(_BaseCommand):
return ''
def _is_typing_optional(self, annotation):
- return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None)
+ return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__
@property
def signature(self):
@@ -1094,13 +1104,14 @@ class Command(_BaseCommand):
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
annotation = param.annotation.converter if greedy else param.annotation
- origin = get_typing_origin(annotation)
+ origin = getattr(annotation, '__origin__', None)
if not greedy and origin is Union:
- union_args = get_typing_args(annotation)
- optional = union_args[-1] is type(None)
- if optional:
+ none_cls = type(None)
+ union_args = annotation.__args__
+ optional = union_args[-1] is none_cls
+ if len(union_args) == 2 and optional:
annotation = union_args[0]
- origin = get_typing_origin(annotation)
+ origin = getattr(annotation, '__origin__', None)
if origin is Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py
index b825057e..98154d10 100644
--- a/discord/ext/commands/errors.py
+++ b/discord/ext/commands/errors.py
@@ -23,7 +23,6 @@ DEALINGS IN THE SOFTWARE.
"""
from discord.errors import ClientException, DiscordException
-import typing
__all__ = (
@@ -646,7 +645,7 @@ class BadUnionArgument(UserInputError):
try:
return x.__name__
except AttributeError:
- if typing.get_origin(x) is not None:
+ if hasattr(x, '__origin__'):
return repr(x)
return x.__class__.__name__