diff options
| author | Rapptz <[email protected]> | 2020-01-21 03:30:56 -0500 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2020-01-21 03:30:56 -0500 |
| commit | 1a7b838d2adff881f825c8a16d14dab2207b311f (patch) | |
| tree | 1370ebabdabe6c1ed1270be57dac79219fa98a38 | |
| parent | [commands] Add max_concurrency decorator (diff) | |
| download | discord.py-1a7b838d2adff881f825c8a16d14dab2207b311f.tar.xz discord.py-1a7b838d2adff881f825c8a16d14dab2207b311f.zip | |
[commands] Refactor BucketType to not repeat in other places in code
| -rw-r--r-- | discord/ext/commands/cooldowns.py | 63 |
1 files changed, 24 insertions, 39 deletions
diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index fe763fb6..9efb5104 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -48,6 +48,25 @@ class BucketType(Enum): category = 5 role = 6 + def get_key(self, msg): + if self is BucketType.user: + return msg.author.id + elif self is BucketType.guild: + return (msg.guild or msg.author).id + elif self is BucketType.channel: + return msg.channel.id + elif self is BucketType.member: + return ((msg.guild and msg.guild.id), msg.author.id) + elif self is BucketType.category: + return (msg.channel.category or msg.channel).id + elif self is BucketType.role: + # we return the channel id of a private-channel as there are only roles in guilds + # and that yields the same result as for a guild with only the @everyone role + # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are + # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do + return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id + + class Cooldown: __slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last') @@ -123,23 +142,7 @@ class CooldownMapping: return cls(Cooldown(rate, per, type)) def _bucket_key(self, msg): - bucket_type = self._cooldown.type - if bucket_type is BucketType.user: - return msg.author.id - elif bucket_type is BucketType.guild: - return (msg.guild or msg.author).id - elif bucket_type is BucketType.channel: - return msg.channel.id - elif bucket_type is BucketType.member: - return ((msg.guild and msg.guild.id), msg.author.id) - elif bucket_type is BucketType.category: - return (msg.channel.category or msg.channel).id - elif bucket_type is BucketType.role: - # we return the channel id of a private-channel as there are only roles in guilds - # and that yields the same result as for a guild with only the @everyone role - # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are - # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do - return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id + return self._cooldown.type.get_key(msg) def _verify_cache_integrity(self, current=None): # we want to delete all cache objects that haven't been used @@ -245,29 +248,11 @@ class MaxConcurrency: def __repr__(self): return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self) - def get_bucket(self, message): - bucket_type = self.per - if bucket_type is BucketType.default: - return 'global' - elif bucket_type is BucketType.user: - return message.author.id - elif bucket_type is BucketType.guild: - return (message.guild or message.author).id - elif bucket_type is BucketType.channel: - return message.channel.id - elif bucket_type is BucketType.member: - return ((message.guild and message.guild.id), message.author.id) - elif bucket_type is BucketType.category: - return (message.channel.category or message.channel).id - elif bucket_type is BucketType.role: - # we return the channel id of a private-channel as there are only roles in guilds - # and that yields the same result as for a guild with only the @everyone role - # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are - # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do - return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id + def get_key(self, message): + return self.per.get_key(message) async def acquire(self, message): - key = self.get_bucket(message) + key = self.get_key(message) try: sem = self._mapping[key] @@ -281,7 +266,7 @@ class MaxConcurrency: async def release(self, message): # Technically there's no reason for this function to be async # But it might be more useful in the future - key = self.get_bucket(message) + key = self.get_key(message) try: sem = self._mapping[key] |