diff options
| author | Rapptz <[email protected]> | 2021-04-19 10:25:08 -0400 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2021-04-19 10:25:08 -0400 |
| commit | ddb71e2aedf081c3d261a992c30b345f3e38baf5 (patch) | |
| tree | bd622f28198058756629f4def419f26a557822e0 /discord/ext/commands/flags.py | |
| parent | Remove lingering User.avatar documentation (diff) | |
| download | discord.py-ddb71e2aedf081c3d261a992c30b345f3e38baf5.tar.xz discord.py-ddb71e2aedf081c3d261a992c30b345f3e38baf5.zip | |
[commands] Initial support for FlagConverter
The name is currently pending and there's no command.signature hook
for it yet since this requires bikeshedding.
Diffstat (limited to 'discord/ext/commands/flags.py')
| -rw-r--r-- | discord/ext/commands/flags.py | 530 |
1 files changed, 530 insertions, 0 deletions
diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py new file mode 100644 index 00000000..dc632721 --- /dev/null +++ b/discord/ext/commands/flags.py @@ -0,0 +1,530 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from .errors import ( + BadFlagArgument, + CommandError, + MissingFlagArgument, + TooManyFlags, + MissingRequiredFlag, +) + +from .core import resolve_annotation +from .view import StringView +from .converter import run_converters + +from discord.utils import maybe_coroutine +from dataclasses import dataclass +from typing import ( + Dict, + Optional, + Pattern, + Set, + TYPE_CHECKING, + Tuple, + List, + Any, + Type, + TypeVar, + Union, +) + +import inspect +import sys +import re + +__all__ = ( + 'Flag', + 'flag', + 'FlagConverter', +) + + +if TYPE_CHECKING: + from .context import Context + + +class _MissingSentinel: + def __repr__(self): + return 'MISSING' + + +MISSING: Any = _MissingSentinel() + + +@dataclass +class Flag: + """Represents a flag parameter for :class:`FlagConverter`. + + The :func:`~discord.ext.commands.flag` function helps + create these flag objects, but it is not necessary to + do so. These cannot be constructed manually. + + Attributes + ------------ + name: :class:`str` + The name of the flag. + attribute: :class:`str` + The attribute in the class that corresponds to this flag. + default: Any + The default value of the flag, if available. + annotation: Any + The underlying evaluated annotation of the flag. + max_args: :class:`int` + The maximum number of arguments the flag can accept. + A negative value indicates an unlimited amount of arguments. + override: :class:`bool` + Whether multiple given values overrides the previous value. + """ + + name: str = MISSING + attribute: str = MISSING + annotation: Any = MISSING + default: Any = MISSING + max_args: int = MISSING + override: bool = MISSING + cast_to_dict: bool = False + + @property + def required(self) -> bool: + """:class:`bool`: Whether the flag is required. + + A required flag has no default value. + """ + return self.default is MISSING + + +def flag( + *, + name: str = MISSING, + default: Any = MISSING, + max_args: int = MISSING, + override: bool = MISSING, +) -> Any: + """Override default functionality and parameters of the underlying :class:`FlagConverter` + class attributes. + + Parameters + ------------ + name: :class:`str` + The flag name. If not given, defaults to the attribute name. + default: Any + The default parameter. This could be either a value or a callable that takes + :class:`Context` as its sole parameter. If not given then it defaults to + the default value given to the attribute. + max_args: :class:`int` + The maximum number of arguments the flag can accept. + A negative value indicates an unlimited amount of arguments. + The default value depends on the annotation given. + override: :class:`bool` + Whether multiple given values overrides the previous value. The default + value depends on the annotation given. + """ + return Flag(name=name, default=default, max_args=max_args, override=override) + + +def validate_flag_name(name: str, forbidden: Set[str]): + if not name: + raise ValueError('flag names should not be empty') + + for ch in name: + if ch.isspace(): + raise ValueError(f'flag name {name!r} cannot have spaces') + if ch == '\\': + raise ValueError(f'flag name {name!r} cannot have backslashes') + if ch in forbidden: + raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them') + + +def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: + annotations = namespace.get('__annotations__', {}) + flags: Dict[str, Flag] = {} + cache: Dict[str, Any] = {} + for name, annotation in annotations.items(): + flag = namespace.pop(name, MISSING) + if isinstance(flag, Flag): + flag.annotation = annotation + else: + flag = Flag(name=name, annotation=annotation, default=flag) + + flag.attribute = name + if flag.name is MISSING: + flag.name = name + + annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) + + # Add sensible defaults based off of the type annotation + # <type> -> (max_args=1) + # List[str] -> (max_args=-1) + # Tuple[int, ...] -> (max_args=1) + # Dict[K, V] -> (max_args=-1, override=True) + # Optional[str] -> (default=None, max_args=1) + + try: + origin = annotation.__origin__ + except AttributeError: + # A regular type hint + if flag.max_args is MISSING: + flag.max_args = 1 + else: + if origin is Union and annotation.__args__[-1] is type(None): + # typing.Optional + if flag.max_args is MISSING: + flag.max_args = 1 + if flag.default is MISSING: + flag.default = None + elif origin is tuple: + # typing.Tuple + # tuple parsing is e.g. `flag: peter 20` + # for Tuple[str, int] would give you flag: ('peter', 20) + if flag.max_args is MISSING: + flag.max_args = 1 + elif origin is list: + # typing.List + if flag.max_args is MISSING: + flag.max_args = -1 + elif origin is dict: + # typing.Dict[K, V] + # Equivalent to: + # typing.List[typing.Tuple[K, V]] + flag.cast_to_dict = True + if flag.max_args is MISSING: + flag.max_args = -1 + if flag.override is MISSING: + flag.override = True + else: + raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag') + + if flag.override is MISSING: + flag.override = False + + flags[flag.name] = flag + + return flags + + +class FlagsMeta(type): + if TYPE_CHECKING: + __commands_is_flag__: bool + __commands_flags__: Dict[str, Flag] + __commands_flag_regex__: Pattern[str] + __commands_flag_case_insensitive__: bool + __commands_flag_delimiter__: str + __commands_flag_prefix__: str + + def __new__( + cls: Type[type], + name: str, + bases: Tuple[type, ...], + attrs: Dict[str, Any], + *, + case_insensitive: bool = False, + delimiter: str = ':', + prefix: str = '', + ): + attrs['__commands_is_flag__'] = True + attrs['__commands_flag_case_insensitive__'] = case_insensitive + attrs['__commands_flag_delimiter__'] = delimiter + attrs['__commands_flag_prefix__'] = prefix + + if not prefix and not delimiter: + raise TypeError('Must have either a delimiter or a prefix set') + + try: + global_ns = sys.modules[attrs['__module__']].__dict__ + except KeyError: + global_ns = {} + + frame = inspect.currentframe() + try: + if frame is None: + local_ns = {} + else: + if frame.f_back is None: + local_ns = frame.f_locals + else: + local_ns = frame.f_back.f_locals + finally: + del frame + + flags: Dict[str, Flag] = {} + for base in reversed(bases): + if base.__dict__.get('__commands_is_flag__', False): + flags.update(base.__dict__['__commands_flags__']) + + flags.update(get_flags(attrs, global_ns, local_ns)) + forbidden = set(delimiter).union(prefix) + for flag_name in flags: + validate_flag_name(flag_name, forbidden) + + regex_flags = 0 + if case_insensitive: + flags = {key.casefold(): value for key, value in flags.items()} + regex_flags = re.IGNORECASE + + keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True) + joined = '|'.join(keys) + pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags) + attrs['__commands_flag_regex__'] = pattern + attrs['__commands_flags__'] = flags + + return type.__new__(cls, name, bases, attrs) + + +async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: + view = StringView(argument) + results = [] + param: inspect.Parameter = ctx.current_parameter # type: ignore + while not view.eof: + view.skip_ws() + if view.eof: + break + + word = view.get_quoted_word() + if word is None: + break + + try: + converted = await run_converters(ctx, converter, word, param) + except CommandError: + raise + except Exception as e: + raise BadFlagArgument(flag) from e + else: + results.append(converted) + + return tuple(results) + + +async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: + view = StringView(argument) + results = [] + param: inspect.Parameter = ctx.current_parameter # type: ignore + for converter in converters: + view.skip_ws() + if view.eof: + break + + word = view.get_quoted_word() + if word is None: + break + + try: + converted = await run_converters(ctx, converter, word, param) + except CommandError: + raise + except Exception as e: + raise BadFlagArgument(flag) from e + else: + results.append(converted) + + if len(results) != len(converters): + raise BadFlagArgument(flag) + + return tuple(results) + + +async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any: + param: inspect.Parameter = ctx.current_parameter # type: ignore + annotation = annotation or flag.annotation + try: + origin = annotation.__origin__ + except AttributeError: + pass + else: + if origin is tuple: + if annotation.__args__[-1] is Ellipsis: + return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0]) + else: + return await tuple_convert_flag(ctx, argument, flag, annotation.__args__) + elif origin is list or origin is Union and annotation.__args__[-1] is type(None): + # typing.List[x] or typing.Optional[x] + annotation = annotation.__args__[0] + return await convert_flag(ctx, argument, flag, annotation) + elif origin is dict: + # typing.Dict[K, V] -> typing.Tuple[K, V] + return await tuple_convert_flag(ctx, argument, flag, annotation.__args__) + + try: + return await run_converters(ctx, annotation, argument, param) + except CommandError: + raise + except Exception as e: + raise BadFlagArgument(flag) from e + + +F = TypeVar('F', bound='FlagConverter') + + +class FlagConverter(metaclass=FlagsMeta): + """A converter that allows for a user-friendly flag syntax. + + The flags are defined using :pep:`526` type annotations similar + to the :mod:`dataclasses` Python module. For more information on + how this converter works, check the appropriate + :ref:`documentation <ext_commands_flag_converter>`. + + .. versionadded:: 2.0 + + Parameters + ----------- + case_insensitive: :class:`bool` + A class parameter to toggle case insensitivity of the flag parsing. + If ``True`` then flags are parsed in a case insensitive manner. + Defaults to ``False``. + prefix: :class:`str` + The prefix that all flags must be prefixed with. By default + there is no prefix. + delimiter: :class:`str` + The delimiter that separates a flag's argument from the flag's name. + By default this is ``:``. + """ + + @classmethod + def get_flags(cls) -> Dict[str, Flag]: + """Dict[:class:`str`, :class:`Flag`]: A mapping of flag name to flag object this converter has.""" + return cls.__commands_flags__.copy() + + def __repr__(self) -> str: + pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()]) + return f'<{self.__class__.__name__} {pairs}>' + + @classmethod + def parse_flags(cls, argument: str) -> Dict[str, List[str]]: + result: Dict[str, List[str]] = {} + flags = cls.get_flags() + last_position = 0 + last_flag: Optional[Flag] = None + + case_insensitive = cls.__commands_flag_case_insensitive__ + for match in cls.__commands_flag_regex__.finditer(argument): + begin, end = match.span(0) + key = match.group('flag') + if case_insensitive: + key = key.casefold() + + flag = flags.get(key) + if last_position and last_flag is not None: + value = argument[last_position : begin - 1].lstrip() + if not value: + raise MissingFlagArgument(last_flag) + + try: + values = result[last_flag.name] + except KeyError: + result[last_flag.name] = [value] + else: + values.append(value) + + last_position = end + last_flag = flag + + # Add the remaining string to the last available flag + if last_position and last_flag is not None: + value = argument[last_position:].strip() + if not value: + raise MissingFlagArgument(last_flag) + + try: + values = result[last_flag.name] + except KeyError: + result[last_flag.name] = [value] + else: + values.append(value) + + # Verification of values will come at a later stage + return result + + @classmethod + async def convert(cls: Type[F], ctx: Context, argument: str) -> F: + """|coro| + + The method that actually converters an argument to the flag mapping. + + Parameters + ---------- + cls: Type[:class:`FlagConverter`] + The flag converter class. + ctx: :class:`Context` + The invocation context. + argument: :class:`str` + The argument to convert from. + + Raises + -------- + FlagError + A flag related parsing error. + CommandError + A command related error. + + Returns + -------- + :class:`FlagConverter` + The flag converter instance with all flags parsed. + """ + arguments = cls.parse_flags(argument) + flags = cls.get_flags() + + self: F = cls.__new__(cls) + for name, flag in flags.items(): + try: + values = arguments[name] + except KeyError: + if flag.required: + raise MissingRequiredFlag(flag) + else: + if callable(flag.default): + default = await maybe_coroutine(flag.default, ctx) + setattr(self, flag.attribute, default) + else: + setattr(self, flag.attribute, flag.default) + continue + + if flag.max_args > 0 and len(values) > flag.max_args: + if flag.override: + values = values[-flag.max_args :] + else: + raise TooManyFlags(flag, values) + + # Special case: + if flag.max_args == 1: + value = await convert_flag(ctx, values[0], flag) + setattr(self, flag.attribute, value) + continue + + # Another special case, tuple parsing. + # Tuple parsing is basically converting arguments within the flag + # So, given flag: hello 20 as the input and Tuple[str, int] as the type hint + # We would receive ('hello', 20) as the resulting value + # This uses the same whitespace and quoting rules as regular parameters. + values = [await convert_flag(ctx, value, flag) for value in values] + + if flag.cast_to_dict: + values = dict(values) # type: ignore + + setattr(self, flag.attribute, values) + + return self |