aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
Diffstat (limited to 'discord')
-rw-r--r--discord/ext/commands/converter.py177
-rw-r--r--discord/ext/commands/core.py125
2 files changed, 186 insertions, 116 deletions
diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py
index e94a08ef..b12b9804 100644
--- a/discord/ext/commands/converter.py
+++ b/discord/ext/commands/converter.py
@@ -26,7 +26,21 @@ from __future__ import annotations
import re
import inspect
-from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ Literal,
+ Optional,
+ TYPE_CHECKING,
+ List,
+ Protocol,
+ Type,
+ TypeVar,
+ Tuple,
+ Union,
+ runtime_checkable,
+)
import discord
from .errors import *
@@ -58,6 +72,7 @@ __all__ = (
'StoreChannelConverter',
'clean_content',
'Greedy',
+ 'run_converters',
)
@@ -867,3 +882,163 @@ class Greedy(List[T]):
raise TypeError(f'Greedy[{converter!r}] is invalid.')
return cls(converter=converter)
+
+
+def _convert_to_bool(argument: str) -> bool:
+ lowered = argument.lower()
+ if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
+ return True
+ elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
+ return False
+ else:
+ raise BadBoolArgument(lowered)
+
+
+def get_converter(param: inspect.Parameter) -> Any:
+ converter = param.annotation
+ if converter is param.empty:
+ if param.default is not param.empty:
+ converter = str if param.default is None else type(param.default)
+ else:
+ converter = str
+ return converter
+
+
+CONVERTER_MAPPING: Dict[Type[Any], Any] = {
+ discord.Object: ObjectConverter,
+ discord.Member: MemberConverter,
+ discord.User: UserConverter,
+ discord.Message: MessageConverter,
+ discord.PartialMessage: PartialMessageConverter,
+ discord.TextChannel: TextChannelConverter,
+ discord.Invite: InviteConverter,
+ discord.Guild: GuildConverter,
+ discord.Role: RoleConverter,
+ discord.Game: GameConverter,
+ discord.Colour: ColourConverter,
+ discord.VoiceChannel: VoiceChannelConverter,
+ discord.StageChannel: StageChannelConverter,
+ discord.Emoji: EmojiConverter,
+ discord.PartialEmoji: PartialEmojiConverter,
+ discord.CategoryChannel: CategoryChannelConverter,
+ discord.StoreChannel: StoreChannelConverter,
+}
+
+
+async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
+ if converter is bool:
+ return _convert_to_bool(argument)
+
+ try:
+ module = converter.__module__
+ except AttributeError:
+ pass
+ else:
+ if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
+ converter = CONVERTER_MAPPING.get(converter, converter)
+
+ try:
+ if inspect.isclass(converter) and issubclass(converter, Converter):
+ if inspect.ismethod(converter.convert):
+ return await converter.convert(ctx, argument)
+ else:
+ return await converter().convert(ctx, argument)
+ elif isinstance(converter, Converter):
+ return await converter.convert(ctx, argument)
+ except CommandError:
+ raise
+ except Exception as exc:
+ raise ConversionError(converter, exc) from exc
+
+ try:
+ return converter(argument)
+ except CommandError:
+ raise
+ except Exception as exc:
+ try:
+ name = converter.__name__
+ except AttributeError:
+ name = converter.__class__.__name__
+
+ raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
+
+
+async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
+ """|coro|
+
+ Runs converters for a given converter, argument, and parameter.
+
+ This function does the same work that the library does under the hood.
+
+ .. versionadded:: 2.0
+
+ Parameters
+ ------------
+ ctx: :class:`Context`
+ The invocation context to run the converters under.
+ converter: Any
+ The converter to run, this corresponds to the annotation in the function.
+ argument: :class:`str`
+ The argument to convert to.
+ param: :class:`inspect.Parameter`
+ The parameter being converted. This is mainly for error reporting.
+
+ Raises
+ -------
+ CommandError
+ The converter failed to convert.
+
+ Returns
+ --------
+ Any
+ The resulting conversion.
+ """
+ origin = getattr(converter, '__origin__', None)
+
+ if origin is Union:
+ errors = []
+ _NoneType = type(None)
+ union_args = converter.__args__
+ for conv in union_args:
+ # if we got to this part in the code, then the previous conversions have failed
+ # so we should just undo the view, return the default, and allow parsing to continue
+ # with the other parameters
+ if conv is _NoneType and param.kind != param.VAR_POSITIONAL:
+ ctx.view.undo()
+ return None if param.default is param.empty else param.default
+
+ try:
+ value = await run_converters(ctx, conv, argument, param)
+ except CommandError as exc:
+ errors.append(exc)
+ else:
+ return value
+
+ # if we're here, then we failed all the converters
+ raise BadUnionArgument(param, union_args, errors)
+
+ if origin is Literal:
+ errors = []
+ conversions = {}
+ literal_args = converter.__args__
+ for literal in literal_args:
+ literal_type = type(literal)
+ try:
+ value = conversions[literal_type]
+ except KeyError:
+ try:
+ value = await _actual_conversion(ctx, literal_type, argument, param)
+ except CommandError as exc:
+ errors.append(exc)
+ conversions[literal_type] = object()
+ continue
+ else:
+ conversions[literal_type] = value
+
+ if value == literal:
+ return value
+
+ # if we're here, then we failed to match all the literals
+ raise BadLiteralArgument(param, literal_args, errors)
+
+ return await _actual_conversion(ctx, converter, argument, param)
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index cbdd4fcf..5809a156 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -43,7 +43,7 @@ import discord
from .errors import *
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
-from . import converter as converters
+from .converter import run_converters, get_converter, Greedy
from ._types import _BaseCommand
from .cog import Cog
@@ -175,7 +175,7 @@ def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.
continue
annotation = _evaluate_annotation(annotation, globalns, globalns, cache)
- if annotation is converters.Greedy:
+ if annotation is Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
params[name] = parameter.replace(annotation=annotation)
@@ -219,14 +219,6 @@ def hooked_wrapped_callback(command, ctx, coro):
return ret
return wrapped
-def _convert_to_bool(argument):
- lowered = argument.lower()
- if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
- return True
- elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
- return False
- else:
- raise BadBoolArgument(lowered)
class _CaseInsensitiveDict(dict):
def __contains__(self, k):
@@ -541,113 +533,16 @@ class Command(_BaseCommand):
finally:
ctx.bot.dispatch('command_error', ctx, error)
- async def _actual_conversion(self, ctx, converter, argument, param):
- if converter is bool:
- return _convert_to_bool(argument)
-
- try:
- module = converter.__module__
- except AttributeError:
- pass
- else:
- if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
- converter = getattr(converters, converter.__name__ + 'Converter', converter)
-
- try:
- if inspect.isclass(converter) and issubclass(converter, converters.Converter):
- if inspect.ismethod(converter.convert):
- return await converter.convert(ctx, argument)
- else:
- return await converter().convert(ctx, argument)
- elif isinstance(converter, converters.Converter):
- return await converter.convert(ctx, argument)
- except CommandError:
- raise
- except Exception as exc:
- raise ConversionError(converter, exc) from exc
-
- try:
- return converter(argument)
- except CommandError:
- raise
- except Exception as exc:
- try:
- name = converter.__name__
- except AttributeError:
- name = converter.__class__.__name__
-
- raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
-
- async def do_conversion(self, ctx, converter, argument, param):
- origin = getattr(converter, '__origin__', None)
-
- if origin is Union:
- errors = []
- _NoneType = type(None)
- union_args = converter.__args__
- for conv in union_args:
- # if we got to this part in the code, then the previous conversions have failed
- # so we should just undo the view, return the default, and allow parsing to continue
- # with the other parameters
- if conv is _NoneType and param.kind != param.VAR_POSITIONAL:
- ctx.view.undo()
- return None if param.default is param.empty else param.default
-
- try:
- value = await self.do_conversion(ctx, conv, argument, param)
- except CommandError as exc:
- errors.append(exc)
- else:
- return value
-
- # if we're here, then we failed all the converters
- raise BadUnionArgument(param, union_args, errors)
-
- if origin is Literal:
- errors = []
- conversions = {}
- literal_args = converter.__args__
- for literal in literal_args:
- literal_type = type(literal)
- try:
- value = conversions[literal_type]
- except KeyError:
- try:
- value = await self._actual_conversion(ctx, literal_type, argument, param)
- except CommandError as exc:
- errors.append(exc)
- conversions[literal_type] = object()
- continue
- else:
- conversions[literal_type] = value
-
- if value == literal:
- return value
-
- # if we're here, then we failed to match all the literals
- raise BadLiteralArgument(param, literal_args, errors)
-
- return await self._actual_conversion(ctx, converter, argument, param)
-
- def _get_converter(self, param):
- converter = param.annotation
- if converter is param.empty:
- if param.default is not param.empty:
- converter = str if param.default is None else type(param.default)
- else:
- converter = str
- return converter
-
async def transform(self, ctx, param):
required = param.default is param.empty
- converter = self._get_converter(param)
+ converter = get_converter(param)
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
view = ctx.view
view.skip_ws()
# The greedy converter is simple -- it keeps going until it fails in which case,
# it undos the view ready for the next parameter to use instead
- if isinstance(converter, converters.Greedy):
+ if isinstance(converter, Greedy):
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
return await self._transform_greedy_pos(ctx, param, required, converter.converter)
elif param.kind == param.VAR_POSITIONAL:
@@ -674,7 +569,7 @@ class Command(_BaseCommand):
argument = view.get_quoted_word()
view.previous = previous
- return await self.do_conversion(ctx, converter, argument, param)
+ return await run_converters(ctx, converter, argument, param)
async def _transform_greedy_pos(self, ctx, param, required, converter):
view = ctx.view
@@ -686,7 +581,7 @@ class Command(_BaseCommand):
view.skip_ws()
try:
argument = view.get_quoted_word()
- value = await self.do_conversion(ctx, converter, argument, param)
+ value = await run_converters(ctx, converter, argument, param)
except (CommandError, ArgumentParsingError):
view.index = previous
break
@@ -702,7 +597,7 @@ class Command(_BaseCommand):
previous = view.index
try:
argument = view.get_quoted_word()
- value = await self.do_conversion(ctx, converter, argument, param)
+ value = await run_converters(ctx, converter, argument, param)
except (CommandError, ArgumentParsingError):
view.index = previous
raise RuntimeError() from None # break loop
@@ -826,9 +721,9 @@ class Command(_BaseCommand):
elif param.kind == param.KEYWORD_ONLY:
# kwarg only param denotes "consume rest" semantics
if self.rest_is_raw:
- converter = self._get_converter(param)
+ converter = get_converter(param)
argument = view.read_rest()
- kwargs[name] = await self.do_conversion(ctx, converter, argument, param)
+ kwargs[name] = await run_converters(ctx, converter, argument, param)
else:
kwargs[name] = await self.transform(ctx, param)
break
@@ -1126,7 +1021,7 @@ class Command(_BaseCommand):
result = []
for name, param in params.items():
- greedy = isinstance(param.annotation, converters.Greedy)
+ greedy = isinstance(param.annotation, Greedy)
optional = False # postpone evaluation of if it's an optional argument
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the