aboutsummaryrefslogtreecommitdiff
path: root/discord/flags.py
diff options
context:
space:
mode:
authorNadir Chowdhury <[email protected]>2021-04-07 12:55:55 +0100
committerGitHub <[email protected]>2021-04-07 07:55:55 -0400
commit83fe98c20d2bedbc10fdd10d278e171916887ba9 (patch)
tree77c0644868503ae5c28dd0aeb8336f60dc830a8f /discord/flags.py
parent[docs] add note for possible Embed.type strings (diff)
downloaddiscord.py-83fe98c20d2bedbc10fdd10d278e171916887ba9.tar.xz
discord.py-83fe98c20d2bedbc10fdd10d278e171916887ba9.zip
Add typing for flags
Diffstat (limited to 'discord/flags.py')
-rw-r--r--discord/flags.py74
1 files changed, 48 insertions, 26 deletions
diff --git a/discord/flags.py b/discord/flags.py
index d7ceb907..af75132d 100644
--- a/discord/flags.py
+++ b/discord/flags.py
@@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+from __future__ import annotations
+
+from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, overload
+
from .enums import UserFlags
__all__ = (
@@ -32,17 +36,28 @@ __all__ = (
'MemberCacheFlags',
)
-class flag_value:
- def __init__(self, func):
+FV = TypeVar('FV', bound='flag_value')
+BF = TypeVar('BF', bound='BaseFlags')
+
+class flag_value(Generic[BF]):
+ def __init__(self, func: Callable[[Any], int]):
self.flag = func(None)
self.__doc__ = func.__doc__
- def __get__(self, instance, owner):
+ @overload
+ def __get__(self: FV, instance: None, owner: Type[BF]) -> FV:
+ ...
+
+ @overload
+ def __get__(self, instance: BF, owner: Type[BF]) -> bool:
+ ...
+
+ def __get__(self, instance: Optional[BF], owner: Type[BF]) -> Any:
if instance is None:
return self
return instance._has_flag(self.flag)
- def __set__(self, instance, value):
+ def __set__(self, instance: BF, value: bool) -> None:
instance._set_flag(self.flag, value)
def __repr__(self):
@@ -51,8 +66,8 @@ class flag_value:
class alias_flag_value(flag_value):
pass
-def fill_with_flags(*, inverted=False):
- def decorator(cls):
+def fill_with_flags(*, inverted: bool = False):
+ def decorator(cls: Type[BF]):
cls.VALID_FLAGS = {
name: value.flag
for name, value in cls.__dict__.items()
@@ -70,9 +85,14 @@ def fill_with_flags(*, inverted=False):
# n.b. flags must inherit from this and use the decorator above
class BaseFlags:
+ VALID_FLAGS: ClassVar[Dict[str, int]]
+ DEFAULT_VALUE: ClassVar[int]
+
+ value: int
+
__slots__ = ('value',)
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
@@ -85,19 +105,19 @@ class BaseFlags:
self.value = value
return self
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.value == other.value
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(self.value)
- def __repr__(self):
+ def __repr__(self) -> str:
return f'<{self.__class__.__name__} value={self.value}>'
- def __iter__(self):
+ def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name, value in self.__class__.__dict__.items():
if isinstance(value, alias_flag_value):
continue
@@ -105,10 +125,10 @@ class BaseFlags:
if isinstance(value, flag_value):
yield (name, self._has_flag(value.flag))
- def _has_flag(self, o):
+ def _has_flag(self, o: int) -> bool:
return (self.value & o) == o
- def _set_flag(self, o, toggle):
+ def _set_flag(self, o: int, toggle: bool) -> None:
if toggle is True:
self.value |= o
elif toggle is False:
@@ -150,6 +170,7 @@ class SystemChannelFlags(BaseFlags):
representing the currently available flags. You should query
flags via the properties rather than using this raw value.
"""
+
__slots__ = ()
# For some reason the flags for system channels are "inverted"
@@ -157,10 +178,10 @@ class SystemChannelFlags(BaseFlags):
# Since this is counter-intuitive from an API perspective and annoying
# these will be inverted automatically
- def _has_flag(self, o):
+ def _has_flag(self, o: int) -> bool:
return (self.value & o) != o
- def _set_flag(self, o, toggle):
+ def _set_flag(self, o: int, toggle: bool) -> None:
if toggle is True:
self.value &= ~o
elif toggle is False:
@@ -210,6 +231,7 @@ class MessageFlags(BaseFlags):
representing the currently available flags. You should query
flags via the properties rather than using this raw value.
"""
+
__slots__ = ()
@flag_value
@@ -346,7 +368,7 @@ class PublicUserFlags(BaseFlags):
"""
return UserFlags.verified_bot_developer.value
- def all(self):
+ def all(self) -> List[UserFlags]:
"""List[:class:`UserFlags`]: Returns all public flags the user has."""
return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)]
@@ -393,7 +415,7 @@ class Intents(BaseFlags):
__slots__ = ()
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
@@ -401,7 +423,7 @@ class Intents(BaseFlags):
setattr(self, key, value)
@classmethod
- def all(cls):
+ def all(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything enabled."""
bits = max(cls.VALID_FLAGS.values()).bit_length()
value = (1 << bits) - 1
@@ -410,14 +432,14 @@ class Intents(BaseFlags):
return self
@classmethod
- def none(cls):
+ def none(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything disabled."""
self = cls.__new__(cls)
self.value = self.DEFAULT_VALUE
return self
@classmethod
- def default(cls):
+ def default(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything enabled
except :attr:`presences` and :attr:`members`.
"""
@@ -825,7 +847,7 @@ class MemberCacheFlags(BaseFlags):
__slots__ = ()
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: bool):
bits = max(self.VALID_FLAGS.values()).bit_length()
self.value = (1 << bits) - 1
for key, value in kwargs.items():
@@ -834,7 +856,7 @@ class MemberCacheFlags(BaseFlags):
setattr(self, key, value)
@classmethod
- def all(cls):
+ def all(cls: Type[MemberCacheFlags]) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with everything enabled."""
bits = max(cls.VALID_FLAGS.values()).bit_length()
value = (1 << bits) - 1
@@ -843,7 +865,7 @@ class MemberCacheFlags(BaseFlags):
return self
@classmethod
- def none(cls):
+ def none(cls: Type[MemberCacheFlags]) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with everything disabled."""
self = cls.__new__(cls)
self.value = self.DEFAULT_VALUE
@@ -886,7 +908,7 @@ class MemberCacheFlags(BaseFlags):
return 4
@classmethod
- def from_intents(cls, intents):
+ def from_intents(cls: Type[MemberCacheFlags], intents: Intents) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` based on
the currently selected :class:`Intents`.
@@ -914,7 +936,7 @@ class MemberCacheFlags(BaseFlags):
return self
- def _verify_intents(self, intents):
+ def _verify_intents(self, intents: Intents):
if self.online and not intents.presences:
raise ValueError('MemberCacheFlags.online requires Intents.presences enabled')