diff options
Diffstat (limited to 'discord/ext/commands/converter.py')
| -rw-r--r-- | discord/ext/commands/converter.py | 177 |
1 files changed, 176 insertions, 1 deletions
diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index e94a08ef..b12b9804 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -26,7 +26,21 @@ from __future__ import annotations import re import inspect -from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable +from typing import ( + Any, + Dict, + Iterable, + Literal, + Optional, + TYPE_CHECKING, + List, + Protocol, + Type, + TypeVar, + Tuple, + Union, + runtime_checkable, +) import discord from .errors import * @@ -58,6 +72,7 @@ __all__ = ( 'StoreChannelConverter', 'clean_content', 'Greedy', + 'run_converters', ) @@ -867,3 +882,163 @@ class Greedy(List[T]): raise TypeError(f'Greedy[{converter!r}] is invalid.') return cls(converter=converter) + + +def _convert_to_bool(argument: str) -> bool: + lowered = argument.lower() + if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): + return True + elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'): + return False + else: + raise BadBoolArgument(lowered) + + +def get_converter(param: inspect.Parameter) -> Any: + converter = param.annotation + if converter is param.empty: + if param.default is not param.empty: + converter = str if param.default is None else type(param.default) + else: + converter = str + return converter + + +CONVERTER_MAPPING: Dict[Type[Any], Any] = { + discord.Object: ObjectConverter, + discord.Member: MemberConverter, + discord.User: UserConverter, + discord.Message: MessageConverter, + discord.PartialMessage: PartialMessageConverter, + discord.TextChannel: TextChannelConverter, + discord.Invite: InviteConverter, + discord.Guild: GuildConverter, + discord.Role: RoleConverter, + discord.Game: GameConverter, + discord.Colour: ColourConverter, + discord.VoiceChannel: VoiceChannelConverter, + discord.StageChannel: StageChannelConverter, + discord.Emoji: EmojiConverter, + discord.PartialEmoji: PartialEmojiConverter, + discord.CategoryChannel: CategoryChannelConverter, + discord.StoreChannel: StoreChannelConverter, +} + + +async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): + if converter is bool: + return _convert_to_bool(argument) + + try: + module = converter.__module__ + except AttributeError: + pass + else: + if module is not None and (module.startswith('discord.') and not module.endswith('converter')): + converter = CONVERTER_MAPPING.get(converter, converter) + + try: + if inspect.isclass(converter) and issubclass(converter, Converter): + if inspect.ismethod(converter.convert): + return await converter.convert(ctx, argument) + else: + return await converter().convert(ctx, argument) + elif isinstance(converter, Converter): + return await converter.convert(ctx, argument) + except CommandError: + raise + except Exception as exc: + raise ConversionError(converter, exc) from exc + + try: + return converter(argument) + except CommandError: + raise + except Exception as exc: + try: + name = converter.__name__ + except AttributeError: + name = converter.__class__.__name__ + + raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc + + +async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter): + """|coro| + + Runs converters for a given converter, argument, and parameter. + + This function does the same work that the library does under the hood. + + .. versionadded:: 2.0 + + Parameters + ------------ + ctx: :class:`Context` + The invocation context to run the converters under. + converter: Any + The converter to run, this corresponds to the annotation in the function. + argument: :class:`str` + The argument to convert to. + param: :class:`inspect.Parameter` + The parameter being converted. This is mainly for error reporting. + + Raises + ------- + CommandError + The converter failed to convert. + + Returns + -------- + Any + The resulting conversion. + """ + origin = getattr(converter, '__origin__', None) + + if origin is Union: + errors = [] + _NoneType = type(None) + union_args = converter.__args__ + for conv in union_args: + # if we got to this part in the code, then the previous conversions have failed + # so we should just undo the view, return the default, and allow parsing to continue + # with the other parameters + if conv is _NoneType and param.kind != param.VAR_POSITIONAL: + ctx.view.undo() + return None if param.default is param.empty else param.default + + try: + value = await run_converters(ctx, conv, argument, param) + except CommandError as exc: + errors.append(exc) + else: + return value + + # if we're here, then we failed all the converters + raise BadUnionArgument(param, union_args, errors) + + if origin is Literal: + errors = [] + conversions = {} + literal_args = converter.__args__ + for literal in literal_args: + literal_type = type(literal) + try: + value = conversions[literal_type] + except KeyError: + try: + value = await _actual_conversion(ctx, literal_type, argument, param) + except CommandError as exc: + errors.append(exc) + conversions[literal_type] = object() + continue + else: + conversions[literal_type] = value + + if value == literal: + return value + + # if we're here, then we failed to match all the literals + raise BadLiteralArgument(param, literal_args, errors) + + return await _actual_conversion(ctx, converter, argument, param) |