aboutsummaryrefslogtreecommitdiff
path: root/discord/flags.py
diff options
context:
space:
mode:
authorRapptz <[email protected]>2019-12-20 21:18:12 -0500
committerRapptz <[email protected]>2019-12-20 21:18:31 -0500
commitf7687e0a684023ab91e5bb97d1f1cadbaf8e0413 (patch)
tree88b7968ecf096c96642d299a844fb84f3920700f /discord/flags.py
parentImplement discord.MessageFlags (diff)
downloaddiscord.py-f7687e0a684023ab91e5bb97d1f1cadbaf8e0413.tar.xz
discord.py-f7687e0a684023ab91e5bb97d1f1cadbaf8e0413.zip
Clean up flag code significantly.
This also fixes the False setting bug.
Diffstat (limited to 'discord/flags.py')
-rw-r--r--discord/flags.py168
1 files changed, 74 insertions, 94 deletions
diff --git a/discord/flags.py b/discord/flags.py
index 1c3ce536..dc89c46d 100644
--- a/discord/flags.py
+++ b/discord/flags.py
@@ -29,7 +29,7 @@ __all__ = (
'MessageFlags',
)
-class _flag_descriptor:
+class flag_value:
def __init__(self, func):
self.flag = func(None)
self.__doc__ = func.__doc__
@@ -40,19 +40,70 @@ class _flag_descriptor:
def __set__(self, instance, value):
instance._set_flag(self.flag, value)
-def fill_with_flags(cls):
- cls.VALID_FLAGS = {
- name: value.flag
- for name, value in cls.__dict__.items()
- if isinstance(value, _flag_descriptor)
- }
+def fill_with_flags(*, inverted=False):
+ def decorator(cls):
+ cls.VALID_FLAGS = {
+ name: value.flag
+ for name, value in cls.__dict__.items()
+ if isinstance(value, flag_value)
+ }
+
+ if inverted:
+ max_bits = max(cls.VALID_FLAGS.values()).bit_length()
+ cls.DEFAULT_VALUE = -1 + (2 ** max_bits)
+ else:
+ cls.DEFAULT_VALUE = 0
+
+ return cls
+ return decorator
+
+# n.b. flags must inherit from this and use the decorator above
+class BaseFlags:
+ __slots__ = ('value',)
+
+ def __init__(self, **kwargs):
+ self.value = self.DEFAULT_VALUE
+ for key, value in kwargs.items():
+ if key not in self.VALID_FLAGS:
+ raise TypeError('%r is not a valid flag name.' % key)
+ setattr(self, key, value)
+
+ @classmethod
+ def _from_value(cls, value):
+ self = cls.__new__(cls)
+ self.value = value
+ return self
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and self.value == other.value
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return hash(self.value)
+
+ def __repr__(self):
+ return '<%s value=%s>' % (self.__class__.__name__, self.value)
- max_bits = max(cls.VALID_FLAGS.values()).bit_length()
- cls.ALL_OFF_VALUE = -1 + (2 ** max_bits)
- return cls
+ def __iter__(self):
+ for name, value in self.__class__.__dict__.items():
+ if isinstance(value, flag_value):
+ yield (name, self._has_flag(value.flag))
+
+ def _has_flag(self, o):
+ return (self.value & o) == o
+
+ def _set_flag(self, o, toggle):
+ if toggle is True:
+ self.value |= o
+ elif toggle is False:
+ self.value &= ~o
+ else:
+ raise TypeError('Value to set for %s must be a bool.' % self.__class__.__name__)
-@fill_with_flags
-class SystemChannelFlags:
+@fill_with_flags(inverted=True)
+class SystemChannelFlags(BaseFlags):
r"""Wraps up a Discord system channel flag value.
Similar to :class:`Permissions`\, the properties provided are two way.
@@ -85,37 +136,7 @@ class SystemChannelFlags:
representing the currently available flags. You should query
flags via the properties rather than using this raw value.
"""
- __slots__ = ('value',)
-
- def __init__(self, **kwargs):
- self.value = self.ALL_OFF_VALUE
- for key, value in kwargs.items():
- if key not in self.VALID_FLAGS:
- raise TypeError('%r is not a valid flag name.' % key)
- setattr(self, key, value)
-
- @classmethod
- def _from_value(cls, value):
- self = cls.__new__(cls)
- self.value = value
- return self
-
- def __eq__(self, other):
- return isinstance(other, SystemChannelFlags) and self.value == other.value
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __hash__(self):
- return hash(self.value)
-
- def __repr__(self):
- return '<SystemChannelFlags value=%s>' % self.value
-
- def __iter__(self):
- for name, value in self.__class__.__dict__.items():
- if isinstance(value, _flag_descriptor):
- yield (name, self._has_flag(value.flag))
+ __slots__ = ()
# For some reason the flags for system channels are "inverted"
# ergo, if they're set then it means "suppress" (off in the GUI toggle)
@@ -133,19 +154,19 @@ class SystemChannelFlags:
else:
raise TypeError('Value to set for SystemChannelFlags must be a bool.')
- @_flag_descriptor
+ @flag_value
def join_notifications(self):
""":class:`bool`: Returns ``True`` if the system channel is used for member join notifications."""
return 1
- @_flag_descriptor
+ @flag_value
def premium_subscriptions(self):
""":class:`bool`: Returns ``True`` if the system channel is used for Nitro boosting notifications."""
return 2
-@fill_with_flags
-class MessageFlags:
+@fill_with_flags()
+class MessageFlags(BaseFlags):
r"""Wraps up a Discord Message flag value.
See :class:`SystemChannelFlags`.
@@ -173,65 +194,24 @@ class MessageFlags:
representing the currently available flags. You should query
flags via the properties rather than using this raw value.
"""
- __slots__ = ('value',)
+ __slots__ = ()
- def __init__(self, **kwargs):
- self.value = 0
- for key, value in kwargs.items():
- if key not in self.VALID_FLAGS:
- raise TypeError('%r is not a valid flag name.' % key)
- setattr(self, key, value)
-
- @classmethod
- def _from_value(cls, value):
- self = cls.__new__(cls)
- self.value = value
- return self
-
- def __eq__(self, other):
- return isinstance(other, MessageFlags) and self.value == other.value
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __hash__(self):
- return hash(self.value)
-
- def __repr__(self):
- return '<MessageFlags value=%s>' % self.value
-
- def __iter__(self):
- for name, value in self.__class__.__dict__.items():
- if isinstance(value, _flag_descriptor):
- yield (name, self._has_flag(value.flag))
-
- def _has_flag(self, o):
- return (self.value & o) == o
-
- def _set_flag(self, o, toggle):
- if toggle is True:
- self.value |= o
- elif toggle is False:
- self.value &= o
- else:
- raise TypeError('Value to set for MessageFlags must be a bool.')
-
- @_flag_descriptor
+ @flag_value
def crossposted(self):
""":class:`bool`: Returns ``True`` if the message is the original crossposted message."""
return 1
- @_flag_descriptor
+ @flag_value
def is_crossposted(self):
""":class:`bool`: Returns ``True`` if the message was crossposted from another channel."""
return 2
- @_flag_descriptor
+ @flag_value
def suppress_embeds(self):
""":class:`bool`: Returns ``True`` if the message's embeds have been suppressed."""
return 4
-
- @_flag_descriptor
+
+ @flag_value
def source_message_deleted(self):
""":class:`bool`: Returns ``True`` if the source message for this crosspost has been deleted."""
return 8