aboutsummaryrefslogtreecommitdiff
path: root/discord/ext/commands/cooldowns.py
diff options
context:
space:
mode:
authorDan Hess <[email protected]>2021-04-09 23:30:01 -0800
committerGitHub <[email protected]>2021-04-10 03:30:01 -0400
commitf2d5ab6f8051d33a26bbf11abbe688c0821c4a0e (patch)
treed25a818501646dda90e603cfd5580f756c66a68b /discord/ext/commands/cooldowns.py
parentFix all warnings with Sphinx (diff)
downloaddiscord.py-f2d5ab6f8051d33a26bbf11abbe688c0821c4a0e.tar.xz
discord.py-f2d5ab6f8051d33a26bbf11abbe688c0821c4a0e.zip
[commands] Provide a dynamic cooldown system
Diffstat (limited to 'discord/ext/commands/cooldowns.py')
-rw-r--r--discord/ext/commands/cooldowns.py42
1 files changed, 30 insertions, 12 deletions
diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py
index fc438c9f..cb0f75cf 100644
--- a/discord/ext/commands/cooldowns.py
+++ b/discord/ext/commands/cooldowns.py
@@ -34,6 +34,7 @@ __all__ = (
'BucketType',
'Cooldown',
'CooldownMapping',
+ 'DynamicCooldownMapping',
'MaxConcurrency',
)
@@ -69,19 +70,15 @@ class BucketType(Enum):
class Cooldown:
- __slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last')
+ __slots__ = ('rate', 'per', '_window', '_tokens', '_last')
- def __init__(self, rate, per, type):
+ def __init__(self, rate, per):
self.rate = int(rate)
self.per = float(per)
- self.type = type
self._window = 0.0
self._tokens = self.rate
self._last = 0.0
- if not callable(self.type):
- raise TypeError('Cooldown type must be a BucketType or callable')
-
def get_tokens(self, current=None):
if not current:
current = time.time()
@@ -128,15 +125,19 @@ class Cooldown:
self._last = 0.0
def copy(self):
- return Cooldown(self.rate, self.per, self.type)
+ return Cooldown(self.rate, self.per)
def __repr__(self):
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
class CooldownMapping:
- def __init__(self, original):
+ def __init__(self, original, type):
+ if not callable(type):
+ raise TypeError('Cooldown type must be a BucketType or callable')
+
self._cache = {}
self._cooldown = original
+ self._type = type
def copy(self):
ret = CooldownMapping(self._cooldown)
@@ -152,7 +153,7 @@ class CooldownMapping:
return cls(Cooldown(rate, per, type))
def _bucket_key(self, msg):
- return self._cooldown.type(msg)
+ return self._type(msg)
def _verify_cache_integrity(self, current=None):
# we want to delete all cache objects that haven't been used
@@ -163,15 +164,19 @@ class CooldownMapping:
for k in dead_keys:
del self._cache[k]
+ def create_bucket(self, message):
+ return self._cooldown.copy()
+
def get_bucket(self, message, current=None):
- if self._cooldown.type is BucketType.default:
+ if self._type is BucketType.default:
return self._cooldown
self._verify_cache_integrity(current)
key = self._bucket_key(message)
if key not in self._cache:
- bucket = self._cooldown.copy()
- self._cache[key] = bucket
+ bucket = self.create_bucket(message)
+ if bucket is not None:
+ self._cache[key] = bucket
else:
bucket = self._cache[key]
@@ -181,6 +186,19 @@ class CooldownMapping:
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
+class DynamicCooldownMapping(CooldownMapping):
+
+ def __init__(self, factory, type):
+ super().__init__(None, type)
+ self._factory = factory
+
+ @property
+ def valid(self):
+ return True
+
+ def create_bucket(self, message):
+ return self._factory(message)
+
class _Semaphore:
"""This class is a version of a semaphore.