diff options
| author | Dan Hess <[email protected]> | 2021-04-09 23:30:01 -0800 |
|---|---|---|
| committer | GitHub <[email protected]> | 2021-04-10 03:30:01 -0400 |
| commit | f2d5ab6f8051d33a26bbf11abbe688c0821c4a0e (patch) | |
| tree | d25a818501646dda90e603cfd5580f756c66a68b /discord/ext/commands/cooldowns.py | |
| parent | Fix all warnings with Sphinx (diff) | |
| download | discord.py-f2d5ab6f8051d33a26bbf11abbe688c0821c4a0e.tar.xz discord.py-f2d5ab6f8051d33a26bbf11abbe688c0821c4a0e.zip | |
[commands] Provide a dynamic cooldown system
Diffstat (limited to 'discord/ext/commands/cooldowns.py')
| -rw-r--r-- | discord/ext/commands/cooldowns.py | 42 |
1 files changed, 30 insertions, 12 deletions
diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index fc438c9f..cb0f75cf 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -34,6 +34,7 @@ __all__ = ( 'BucketType', 'Cooldown', 'CooldownMapping', + 'DynamicCooldownMapping', 'MaxConcurrency', ) @@ -69,19 +70,15 @@ class BucketType(Enum): class Cooldown: - __slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last') + __slots__ = ('rate', 'per', '_window', '_tokens', '_last') - def __init__(self, rate, per, type): + def __init__(self, rate, per): self.rate = int(rate) self.per = float(per) - self.type = type self._window = 0.0 self._tokens = self.rate self._last = 0.0 - if not callable(self.type): - raise TypeError('Cooldown type must be a BucketType or callable') - def get_tokens(self, current=None): if not current: current = time.time() @@ -128,15 +125,19 @@ class Cooldown: self._last = 0.0 def copy(self): - return Cooldown(self.rate, self.per, self.type) + return Cooldown(self.rate, self.per) def __repr__(self): return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>' class CooldownMapping: - def __init__(self, original): + def __init__(self, original, type): + if not callable(type): + raise TypeError('Cooldown type must be a BucketType or callable') + self._cache = {} self._cooldown = original + self._type = type def copy(self): ret = CooldownMapping(self._cooldown) @@ -152,7 +153,7 @@ class CooldownMapping: return cls(Cooldown(rate, per, type)) def _bucket_key(self, msg): - return self._cooldown.type(msg) + return self._type(msg) def _verify_cache_integrity(self, current=None): # we want to delete all cache objects that haven't been used @@ -163,15 +164,19 @@ class CooldownMapping: for k in dead_keys: del self._cache[k] + def create_bucket(self, message): + return self._cooldown.copy() + def get_bucket(self, message, current=None): - if self._cooldown.type is BucketType.default: + if self._type is BucketType.default: return self._cooldown self._verify_cache_integrity(current) key = self._bucket_key(message) if key not in self._cache: - bucket = self._cooldown.copy() - self._cache[key] = bucket + bucket = self.create_bucket(message) + if bucket is not None: + self._cache[key] = bucket else: bucket = self._cache[key] @@ -181,6 +186,19 @@ class CooldownMapping: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) +class DynamicCooldownMapping(CooldownMapping): + + def __init__(self, factory, type): + super().__init__(None, type) + self._factory = factory + + @property + def valid(self): + return True + + def create_bucket(self, message): + return self._factory(message) + class _Semaphore: """This class is a version of a semaphore. |