aboutsummaryrefslogtreecommitdiff
path: root/discord/ext/commands/converter.py
diff options
context:
space:
mode:
Diffstat (limited to 'discord/ext/commands/converter.py')
-rw-r--r--discord/ext/commands/converter.py177
1 files changed, 176 insertions, 1 deletions
diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py
index e94a08ef..b12b9804 100644
--- a/discord/ext/commands/converter.py
+++ b/discord/ext/commands/converter.py
@@ -26,7 +26,21 @@ from __future__ import annotations
import re
import inspect
-from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ Literal,
+ Optional,
+ TYPE_CHECKING,
+ List,
+ Protocol,
+ Type,
+ TypeVar,
+ Tuple,
+ Union,
+ runtime_checkable,
+)
import discord
from .errors import *
@@ -58,6 +72,7 @@ __all__ = (
'StoreChannelConverter',
'clean_content',
'Greedy',
+ 'run_converters',
)
@@ -867,3 +882,163 @@ class Greedy(List[T]):
raise TypeError(f'Greedy[{converter!r}] is invalid.')
return cls(converter=converter)
+
+
+def _convert_to_bool(argument: str) -> bool:
+ lowered = argument.lower()
+ if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
+ return True
+ elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
+ return False
+ else:
+ raise BadBoolArgument(lowered)
+
+
+def get_converter(param: inspect.Parameter) -> Any:
+ converter = param.annotation
+ if converter is param.empty:
+ if param.default is not param.empty:
+ converter = str if param.default is None else type(param.default)
+ else:
+ converter = str
+ return converter
+
+
+CONVERTER_MAPPING: Dict[Type[Any], Any] = {
+ discord.Object: ObjectConverter,
+ discord.Member: MemberConverter,
+ discord.User: UserConverter,
+ discord.Message: MessageConverter,
+ discord.PartialMessage: PartialMessageConverter,
+ discord.TextChannel: TextChannelConverter,
+ discord.Invite: InviteConverter,
+ discord.Guild: GuildConverter,
+ discord.Role: RoleConverter,
+ discord.Game: GameConverter,
+ discord.Colour: ColourConverter,
+ discord.VoiceChannel: VoiceChannelConverter,
+ discord.StageChannel: StageChannelConverter,
+ discord.Emoji: EmojiConverter,
+ discord.PartialEmoji: PartialEmojiConverter,
+ discord.CategoryChannel: CategoryChannelConverter,
+ discord.StoreChannel: StoreChannelConverter,
+}
+
+
+async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
+ if converter is bool:
+ return _convert_to_bool(argument)
+
+ try:
+ module = converter.__module__
+ except AttributeError:
+ pass
+ else:
+ if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
+ converter = CONVERTER_MAPPING.get(converter, converter)
+
+ try:
+ if inspect.isclass(converter) and issubclass(converter, Converter):
+ if inspect.ismethod(converter.convert):
+ return await converter.convert(ctx, argument)
+ else:
+ return await converter().convert(ctx, argument)
+ elif isinstance(converter, Converter):
+ return await converter.convert(ctx, argument)
+ except CommandError:
+ raise
+ except Exception as exc:
+ raise ConversionError(converter, exc) from exc
+
+ try:
+ return converter(argument)
+ except CommandError:
+ raise
+ except Exception as exc:
+ try:
+ name = converter.__name__
+ except AttributeError:
+ name = converter.__class__.__name__
+
+ raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
+
+
+async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
+ """|coro|
+
+ Runs converters for a given converter, argument, and parameter.
+
+ This function does the same work that the library does under the hood.
+
+ .. versionadded:: 2.0
+
+ Parameters
+ ------------
+ ctx: :class:`Context`
+ The invocation context to run the converters under.
+ converter: Any
+ The converter to run, this corresponds to the annotation in the function.
+ argument: :class:`str`
+ The argument to convert to.
+ param: :class:`inspect.Parameter`
+ The parameter being converted. This is mainly for error reporting.
+
+ Raises
+ -------
+ CommandError
+ The converter failed to convert.
+
+ Returns
+ --------
+ Any
+ The resulting conversion.
+ """
+ origin = getattr(converter, '__origin__', None)
+
+ if origin is Union:
+ errors = []
+ _NoneType = type(None)
+ union_args = converter.__args__
+ for conv in union_args:
+ # if we got to this part in the code, then the previous conversions have failed
+ # so we should just undo the view, return the default, and allow parsing to continue
+ # with the other parameters
+ if conv is _NoneType and param.kind != param.VAR_POSITIONAL:
+ ctx.view.undo()
+ return None if param.default is param.empty else param.default
+
+ try:
+ value = await run_converters(ctx, conv, argument, param)
+ except CommandError as exc:
+ errors.append(exc)
+ else:
+ return value
+
+ # if we're here, then we failed all the converters
+ raise BadUnionArgument(param, union_args, errors)
+
+ if origin is Literal:
+ errors = []
+ conversions = {}
+ literal_args = converter.__args__
+ for literal in literal_args:
+ literal_type = type(literal)
+ try:
+ value = conversions[literal_type]
+ except KeyError:
+ try:
+ value = await _actual_conversion(ctx, literal_type, argument, param)
+ except CommandError as exc:
+ errors.append(exc)
+ conversions[literal_type] = object()
+ continue
+ else:
+ conversions[literal_type] = value
+
+ if value == literal:
+ return value
+
+ # if we're here, then we failed to match all the literals
+ raise BadLiteralArgument(param, literal_args, errors)
+
+ return await _actual_conversion(ctx, converter, argument, param)