aboutsummaryrefslogtreecommitdiff
path: root/discord/ext/commands/flags.py
diff options
context:
space:
mode:
authorRapptz <[email protected]>2021-04-19 10:25:08 -0400
committerRapptz <[email protected]>2021-04-19 10:25:08 -0400
commitddb71e2aedf081c3d261a992c30b345f3e38baf5 (patch)
treebd622f28198058756629f4def419f26a557822e0 /discord/ext/commands/flags.py
parentRemove lingering User.avatar documentation (diff)
downloaddiscord.py-ddb71e2aedf081c3d261a992c30b345f3e38baf5.tar.xz
discord.py-ddb71e2aedf081c3d261a992c30b345f3e38baf5.zip
[commands] Initial support for FlagConverter
The name is currently pending and there's no command.signature hook for it yet since this requires bikeshedding.
Diffstat (limited to 'discord/ext/commands/flags.py')
-rw-r--r--discord/ext/commands/flags.py530
1 files changed, 530 insertions, 0 deletions
diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py
new file mode 100644
index 00000000..dc632721
--- /dev/null
+++ b/discord/ext/commands/flags.py
@@ -0,0 +1,530 @@
+"""
+The MIT License (MIT)
+
+Copyright (c) 2015-present Rapptz
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+
+from __future__ import annotations
+
+from .errors import (
+ BadFlagArgument,
+ CommandError,
+ MissingFlagArgument,
+ TooManyFlags,
+ MissingRequiredFlag,
+)
+
+from .core import resolve_annotation
+from .view import StringView
+from .converter import run_converters
+
+from discord.utils import maybe_coroutine
+from dataclasses import dataclass
+from typing import (
+ Dict,
+ Optional,
+ Pattern,
+ Set,
+ TYPE_CHECKING,
+ Tuple,
+ List,
+ Any,
+ Type,
+ TypeVar,
+ Union,
+)
+
+import inspect
+import sys
+import re
+
+__all__ = (
+ 'Flag',
+ 'flag',
+ 'FlagConverter',
+)
+
+
+if TYPE_CHECKING:
+ from .context import Context
+
+
+class _MissingSentinel:
+ def __repr__(self):
+ return 'MISSING'
+
+
+MISSING: Any = _MissingSentinel()
+
+
+@dataclass
+class Flag:
+ """Represents a flag parameter for :class:`FlagConverter`.
+
+ The :func:`~discord.ext.commands.flag` function helps
+ create these flag objects, but it is not necessary to
+ do so. These cannot be constructed manually.
+
+ Attributes
+ ------------
+ name: :class:`str`
+ The name of the flag.
+ attribute: :class:`str`
+ The attribute in the class that corresponds to this flag.
+ default: Any
+ The default value of the flag, if available.
+ annotation: Any
+ The underlying evaluated annotation of the flag.
+ max_args: :class:`int`
+ The maximum number of arguments the flag can accept.
+ A negative value indicates an unlimited amount of arguments.
+ override: :class:`bool`
+ Whether multiple given values overrides the previous value.
+ """
+
+ name: str = MISSING
+ attribute: str = MISSING
+ annotation: Any = MISSING
+ default: Any = MISSING
+ max_args: int = MISSING
+ override: bool = MISSING
+ cast_to_dict: bool = False
+
+ @property
+ def required(self) -> bool:
+ """:class:`bool`: Whether the flag is required.
+
+ A required flag has no default value.
+ """
+ return self.default is MISSING
+
+
+def flag(
+ *,
+ name: str = MISSING,
+ default: Any = MISSING,
+ max_args: int = MISSING,
+ override: bool = MISSING,
+) -> Any:
+ """Override default functionality and parameters of the underlying :class:`FlagConverter`
+ class attributes.
+
+ Parameters
+ ------------
+ name: :class:`str`
+ The flag name. If not given, defaults to the attribute name.
+ default: Any
+ The default parameter. This could be either a value or a callable that takes
+ :class:`Context` as its sole parameter. If not given then it defaults to
+ the default value given to the attribute.
+ max_args: :class:`int`
+ The maximum number of arguments the flag can accept.
+ A negative value indicates an unlimited amount of arguments.
+ The default value depends on the annotation given.
+ override: :class:`bool`
+ Whether multiple given values overrides the previous value. The default
+ value depends on the annotation given.
+ """
+ return Flag(name=name, default=default, max_args=max_args, override=override)
+
+
+def validate_flag_name(name: str, forbidden: Set[str]):
+ if not name:
+ raise ValueError('flag names should not be empty')
+
+ for ch in name:
+ if ch.isspace():
+ raise ValueError(f'flag name {name!r} cannot have spaces')
+ if ch == '\\':
+ raise ValueError(f'flag name {name!r} cannot have backslashes')
+ if ch in forbidden:
+ raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them')
+
+
+def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
+ annotations = namespace.get('__annotations__', {})
+ flags: Dict[str, Flag] = {}
+ cache: Dict[str, Any] = {}
+ for name, annotation in annotations.items():
+ flag = namespace.pop(name, MISSING)
+ if isinstance(flag, Flag):
+ flag.annotation = annotation
+ else:
+ flag = Flag(name=name, annotation=annotation, default=flag)
+
+ flag.attribute = name
+ if flag.name is MISSING:
+ flag.name = name
+
+ annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
+
+ # Add sensible defaults based off of the type annotation
+ # <type> -> (max_args=1)
+ # List[str] -> (max_args=-1)
+ # Tuple[int, ...] -> (max_args=1)
+ # Dict[K, V] -> (max_args=-1, override=True)
+ # Optional[str] -> (default=None, max_args=1)
+
+ try:
+ origin = annotation.__origin__
+ except AttributeError:
+ # A regular type hint
+ if flag.max_args is MISSING:
+ flag.max_args = 1
+ else:
+ if origin is Union and annotation.__args__[-1] is type(None):
+ # typing.Optional
+ if flag.max_args is MISSING:
+ flag.max_args = 1
+ if flag.default is MISSING:
+ flag.default = None
+ elif origin is tuple:
+ # typing.Tuple
+ # tuple parsing is e.g. `flag: peter 20`
+ # for Tuple[str, int] would give you flag: ('peter', 20)
+ if flag.max_args is MISSING:
+ flag.max_args = 1
+ elif origin is list:
+ # typing.List
+ if flag.max_args is MISSING:
+ flag.max_args = -1
+ elif origin is dict:
+ # typing.Dict[K, V]
+ # Equivalent to:
+ # typing.List[typing.Tuple[K, V]]
+ flag.cast_to_dict = True
+ if flag.max_args is MISSING:
+ flag.max_args = -1
+ if flag.override is MISSING:
+ flag.override = True
+ else:
+ raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag')
+
+ if flag.override is MISSING:
+ flag.override = False
+
+ flags[flag.name] = flag
+
+ return flags
+
+
+class FlagsMeta(type):
+ if TYPE_CHECKING:
+ __commands_is_flag__: bool
+ __commands_flags__: Dict[str, Flag]
+ __commands_flag_regex__: Pattern[str]
+ __commands_flag_case_insensitive__: bool
+ __commands_flag_delimiter__: str
+ __commands_flag_prefix__: str
+
+ def __new__(
+ cls: Type[type],
+ name: str,
+ bases: Tuple[type, ...],
+ attrs: Dict[str, Any],
+ *,
+ case_insensitive: bool = False,
+ delimiter: str = ':',
+ prefix: str = '',
+ ):
+ attrs['__commands_is_flag__'] = True
+ attrs['__commands_flag_case_insensitive__'] = case_insensitive
+ attrs['__commands_flag_delimiter__'] = delimiter
+ attrs['__commands_flag_prefix__'] = prefix
+
+ if not prefix and not delimiter:
+ raise TypeError('Must have either a delimiter or a prefix set')
+
+ try:
+ global_ns = sys.modules[attrs['__module__']].__dict__
+ except KeyError:
+ global_ns = {}
+
+ frame = inspect.currentframe()
+ try:
+ if frame is None:
+ local_ns = {}
+ else:
+ if frame.f_back is None:
+ local_ns = frame.f_locals
+ else:
+ local_ns = frame.f_back.f_locals
+ finally:
+ del frame
+
+ flags: Dict[str, Flag] = {}
+ for base in reversed(bases):
+ if base.__dict__.get('__commands_is_flag__', False):
+ flags.update(base.__dict__['__commands_flags__'])
+
+ flags.update(get_flags(attrs, global_ns, local_ns))
+ forbidden = set(delimiter).union(prefix)
+ for flag_name in flags:
+ validate_flag_name(flag_name, forbidden)
+
+ regex_flags = 0
+ if case_insensitive:
+ flags = {key.casefold(): value for key, value in flags.items()}
+ regex_flags = re.IGNORECASE
+
+ keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True)
+ joined = '|'.join(keys)
+ pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
+ attrs['__commands_flag_regex__'] = pattern
+ attrs['__commands_flags__'] = flags
+
+ return type.__new__(cls, name, bases, attrs)
+
+
+async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
+ view = StringView(argument)
+ results = []
+ param: inspect.Parameter = ctx.current_parameter # type: ignore
+ while not view.eof:
+ view.skip_ws()
+ if view.eof:
+ break
+
+ word = view.get_quoted_word()
+ if word is None:
+ break
+
+ try:
+ converted = await run_converters(ctx, converter, word, param)
+ except CommandError:
+ raise
+ except Exception as e:
+ raise BadFlagArgument(flag) from e
+ else:
+ results.append(converted)
+
+ return tuple(results)
+
+
+async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
+ view = StringView(argument)
+ results = []
+ param: inspect.Parameter = ctx.current_parameter # type: ignore
+ for converter in converters:
+ view.skip_ws()
+ if view.eof:
+ break
+
+ word = view.get_quoted_word()
+ if word is None:
+ break
+
+ try:
+ converted = await run_converters(ctx, converter, word, param)
+ except CommandError:
+ raise
+ except Exception as e:
+ raise BadFlagArgument(flag) from e
+ else:
+ results.append(converted)
+
+ if len(results) != len(converters):
+ raise BadFlagArgument(flag)
+
+ return tuple(results)
+
+
+async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any:
+ param: inspect.Parameter = ctx.current_parameter # type: ignore
+ annotation = annotation or flag.annotation
+ try:
+ origin = annotation.__origin__
+ except AttributeError:
+ pass
+ else:
+ if origin is tuple:
+ if annotation.__args__[-1] is Ellipsis:
+ return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0])
+ else:
+ return await tuple_convert_flag(ctx, argument, flag, annotation.__args__)
+ elif origin is list or origin is Union and annotation.__args__[-1] is type(None):
+ # typing.List[x] or typing.Optional[x]
+ annotation = annotation.__args__[0]
+ return await convert_flag(ctx, argument, flag, annotation)
+ elif origin is dict:
+ # typing.Dict[K, V] -> typing.Tuple[K, V]
+ return await tuple_convert_flag(ctx, argument, flag, annotation.__args__)
+
+ try:
+ return await run_converters(ctx, annotation, argument, param)
+ except CommandError:
+ raise
+ except Exception as e:
+ raise BadFlagArgument(flag) from e
+
+
+F = TypeVar('F', bound='FlagConverter')
+
+
+class FlagConverter(metaclass=FlagsMeta):
+ """A converter that allows for a user-friendly flag syntax.
+
+ The flags are defined using :pep:`526` type annotations similar
+ to the :mod:`dataclasses` Python module. For more information on
+ how this converter works, check the appropriate
+ :ref:`documentation <ext_commands_flag_converter>`.
+
+ .. versionadded:: 2.0
+
+ Parameters
+ -----------
+ case_insensitive: :class:`bool`
+ A class parameter to toggle case insensitivity of the flag parsing.
+ If ``True`` then flags are parsed in a case insensitive manner.
+ Defaults to ``False``.
+ prefix: :class:`str`
+ The prefix that all flags must be prefixed with. By default
+ there is no prefix.
+ delimiter: :class:`str`
+ The delimiter that separates a flag's argument from the flag's name.
+ By default this is ``:``.
+ """
+
+ @classmethod
+ def get_flags(cls) -> Dict[str, Flag]:
+ """Dict[:class:`str`, :class:`Flag`]: A mapping of flag name to flag object this converter has."""
+ return cls.__commands_flags__.copy()
+
+ def __repr__(self) -> str:
+ pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
+ return f'<{self.__class__.__name__} {pairs}>'
+
+ @classmethod
+ def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
+ result: Dict[str, List[str]] = {}
+ flags = cls.get_flags()
+ last_position = 0
+ last_flag: Optional[Flag] = None
+
+ case_insensitive = cls.__commands_flag_case_insensitive__
+ for match in cls.__commands_flag_regex__.finditer(argument):
+ begin, end = match.span(0)
+ key = match.group('flag')
+ if case_insensitive:
+ key = key.casefold()
+
+ flag = flags.get(key)
+ if last_position and last_flag is not None:
+ value = argument[last_position : begin - 1].lstrip()
+ if not value:
+ raise MissingFlagArgument(last_flag)
+
+ try:
+ values = result[last_flag.name]
+ except KeyError:
+ result[last_flag.name] = [value]
+ else:
+ values.append(value)
+
+ last_position = end
+ last_flag = flag
+
+ # Add the remaining string to the last available flag
+ if last_position and last_flag is not None:
+ value = argument[last_position:].strip()
+ if not value:
+ raise MissingFlagArgument(last_flag)
+
+ try:
+ values = result[last_flag.name]
+ except KeyError:
+ result[last_flag.name] = [value]
+ else:
+ values.append(value)
+
+ # Verification of values will come at a later stage
+ return result
+
+ @classmethod
+ async def convert(cls: Type[F], ctx: Context, argument: str) -> F:
+ """|coro|
+
+ The method that actually converters an argument to the flag mapping.
+
+ Parameters
+ ----------
+ cls: Type[:class:`FlagConverter`]
+ The flag converter class.
+ ctx: :class:`Context`
+ The invocation context.
+ argument: :class:`str`
+ The argument to convert from.
+
+ Raises
+ --------
+ FlagError
+ A flag related parsing error.
+ CommandError
+ A command related error.
+
+ Returns
+ --------
+ :class:`FlagConverter`
+ The flag converter instance with all flags parsed.
+ """
+ arguments = cls.parse_flags(argument)
+ flags = cls.get_flags()
+
+ self: F = cls.__new__(cls)
+ for name, flag in flags.items():
+ try:
+ values = arguments[name]
+ except KeyError:
+ if flag.required:
+ raise MissingRequiredFlag(flag)
+ else:
+ if callable(flag.default):
+ default = await maybe_coroutine(flag.default, ctx)
+ setattr(self, flag.attribute, default)
+ else:
+ setattr(self, flag.attribute, flag.default)
+ continue
+
+ if flag.max_args > 0 and len(values) > flag.max_args:
+ if flag.override:
+ values = values[-flag.max_args :]
+ else:
+ raise TooManyFlags(flag, values)
+
+ # Special case:
+ if flag.max_args == 1:
+ value = await convert_flag(ctx, values[0], flag)
+ setattr(self, flag.attribute, value)
+ continue
+
+ # Another special case, tuple parsing.
+ # Tuple parsing is basically converting arguments within the flag
+ # So, given flag: hello 20 as the input and Tuple[str, int] as the type hint
+ # We would receive ('hello', 20) as the resulting value
+ # This uses the same whitespace and quoting rules as regular parameters.
+ values = [await convert_flag(ctx, value, flag) for value in values]
+
+ if flag.cast_to_dict:
+ values = dict(values) # type: ignore
+
+ setattr(self, flag.attribute, values)
+
+ return self