aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2020-01-21 03:30:56 -0500
committerRapptz <[email protected]>2020-01-21 03:30:56 -0500
commit1a7b838d2adff881f825c8a16d14dab2207b311f (patch)
tree1370ebabdabe6c1ed1270be57dac79219fa98a38
parent[commands] Add max_concurrency decorator (diff)
downloaddiscord.py-1a7b838d2adff881f825c8a16d14dab2207b311f.tar.xz
discord.py-1a7b838d2adff881f825c8a16d14dab2207b311f.zip
[commands] Refactor BucketType to not repeat in other places in code
-rw-r--r--discord/ext/commands/cooldowns.py63
1 files changed, 24 insertions, 39 deletions
diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py
index fe763fb6..9efb5104 100644
--- a/discord/ext/commands/cooldowns.py
+++ b/discord/ext/commands/cooldowns.py
@@ -48,6 +48,25 @@ class BucketType(Enum):
category = 5
role = 6
+ def get_key(self, msg):
+ if self is BucketType.user:
+ return msg.author.id
+ elif self is BucketType.guild:
+ return (msg.guild or msg.author).id
+ elif self is BucketType.channel:
+ return msg.channel.id
+ elif self is BucketType.member:
+ return ((msg.guild and msg.guild.id), msg.author.id)
+ elif self is BucketType.category:
+ return (msg.channel.category or msg.channel).id
+ elif self is BucketType.role:
+ # we return the channel id of a private-channel as there are only roles in guilds
+ # and that yields the same result as for a guild with only the @everyone role
+ # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
+ # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
+ return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
+
+
class Cooldown:
__slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last')
@@ -123,23 +142,7 @@ class CooldownMapping:
return cls(Cooldown(rate, per, type))
def _bucket_key(self, msg):
- bucket_type = self._cooldown.type
- if bucket_type is BucketType.user:
- return msg.author.id
- elif bucket_type is BucketType.guild:
- return (msg.guild or msg.author).id
- elif bucket_type is BucketType.channel:
- return msg.channel.id
- elif bucket_type is BucketType.member:
- return ((msg.guild and msg.guild.id), msg.author.id)
- elif bucket_type is BucketType.category:
- return (msg.channel.category or msg.channel).id
- elif bucket_type is BucketType.role:
- # we return the channel id of a private-channel as there are only roles in guilds
- # and that yields the same result as for a guild with only the @everyone role
- # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
- # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
- return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
+ return self._cooldown.type.get_key(msg)
def _verify_cache_integrity(self, current=None):
# we want to delete all cache objects that haven't been used
@@ -245,29 +248,11 @@ class MaxConcurrency:
def __repr__(self):
return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
- def get_bucket(self, message):
- bucket_type = self.per
- if bucket_type is BucketType.default:
- return 'global'
- elif bucket_type is BucketType.user:
- return message.author.id
- elif bucket_type is BucketType.guild:
- return (message.guild or message.author).id
- elif bucket_type is BucketType.channel:
- return message.channel.id
- elif bucket_type is BucketType.member:
- return ((message.guild and message.guild.id), message.author.id)
- elif bucket_type is BucketType.category:
- return (message.channel.category or message.channel).id
- elif bucket_type is BucketType.role:
- # we return the channel id of a private-channel as there are only roles in guilds
- # and that yields the same result as for a guild with only the @everyone role
- # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
- # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
- return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id
+ def get_key(self, message):
+ return self.per.get_key(message)
async def acquire(self, message):
- key = self.get_bucket(message)
+ key = self.get_key(message)
try:
sem = self._mapping[key]
@@ -281,7 +266,7 @@ class MaxConcurrency:
async def release(self, message):
# Technically there's no reason for this function to be async
# But it might be more useful in the future
- key = self.get_bucket(message)
+ key = self.get_key(message)
try:
sem = self._mapping[key]