aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2020-01-21 03:26:41 -0500
committerRapptz <[email protected]>2020-01-21 03:26:41 -0500
commitbf84c633963a63f81e88a67df784900759cdb4e0 (patch)
tree9a4402d8485a77cf88b854bbe6b6583862e1c898
parent[tasks] Use new sleep_until util instead of internal function (diff)
downloaddiscord.py-bf84c633963a63f81e88a67df784900759cdb4e0.tar.xz
discord.py-bf84c633963a63f81e88a67df784900759cdb4e0.zip
[commands] Add max_concurrency decorator
-rw-r--r--discord/ext/commands/cooldowns.py130
-rw-r--r--discord/ext/commands/core.py52
-rw-r--r--discord/ext/commands/errors.py23
3 files changed, 204 insertions, 1 deletions
diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py
index 5e7f2aa3..fe763fb6 100644
--- a/discord/ext/commands/cooldowns.py
+++ b/discord/ext/commands/cooldowns.py
@@ -26,13 +26,17 @@ DEALINGS IN THE SOFTWARE.
from discord.enums import Enum
import time
+import asyncio
+from collections import deque
from ...abc import PrivateChannel
+from .errors import MaxConcurrencyReached
__all__ = (
'BucketType',
'Cooldown',
'CooldownMapping',
+ 'MaxConcurrency',
)
class BucketType(Enum):
@@ -163,3 +167,129 @@ class CooldownMapping:
def update_rate_limit(self, message, current=None):
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
+
+class _Semaphore:
+ """This class is a version of a semaphore.
+
+ If you're wondering why asyncio.Semaphore isn't being used,
+ it's because it doesn't expose the internal value. This internal
+ value is necessary because I need to support both `wait=True` and
+ `wait=False`.
+
+ An asyncio.Queue could have been used to do this as well -- but it
+ not as inefficient since internally that uses two queues and is a bit
+ overkill for what is basically a counter.
+ """
+
+ __slots__ = ('value', 'loop', '_waiters')
+
+ def __init__(self, number):
+ self.value = number
+ self.loop = asyncio.get_event_loop()
+ self._waiters = deque()
+
+ def __repr__(self):
+ return '<_Semaphore value={0.value} waiters={1}>'.format(self, len(self._waiters))
+
+ def locked(self):
+ return self.value == 0
+
+ def wake_up(self):
+ while self._waiters:
+ future = self._waiters.popleft()
+ if not future.done():
+ future.set_result(None)
+ return
+
+ async def acquire(self, *, wait=False):
+ if not wait and self.value <= 0:
+ # signal that we're not acquiring
+ return False
+
+ while self.value <= 0:
+ future = self.loop.create_future()
+ self._waiters.append(future)
+ try:
+ await future
+ except:
+ future.cancel()
+ if self.value > 0 and not future.cancelled():
+ self.wake_up()
+ raise
+
+ self.value -= 1
+ return True
+
+ def release(self):
+ self.value += 1
+ self.wake_up()
+
+class MaxConcurrency:
+ __slots__ = ('number', 'per', 'wait', '_mapping')
+
+ def __init__(self, number, *, per, wait):
+ self._mapping = {}
+ self.per = per
+ self.number = number
+ self.wait = wait
+
+ if number <= 0:
+ raise ValueError('max_concurrency \'number\' cannot be less than 1')
+
+ if not isinstance(per, BucketType):
+ raise TypeError('max_concurrency \'per\' must be of type BucketType not %r' % type(per))
+
+ def copy(self):
+ return self.__class__(self.number, per=self.per, wait=self.wait)
+
+ def __repr__(self):
+ return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
+
+ def get_bucket(self, message):
+ bucket_type = self.per
+ if bucket_type is BucketType.default:
+ return 'global'
+ elif bucket_type is BucketType.user:
+ return message.author.id
+ elif bucket_type is BucketType.guild:
+ return (message.guild or message.author).id
+ elif bucket_type is BucketType.channel:
+ return message.channel.id
+ elif bucket_type is BucketType.member:
+ return ((message.guild and message.guild.id), message.author.id)
+ elif bucket_type is BucketType.category:
+ return (message.channel.category or message.channel).id
+ elif bucket_type is BucketType.role:
+ # we return the channel id of a private-channel as there are only roles in guilds
+ # and that yields the same result as for a guild with only the @everyone role
+ # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
+ # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
+ return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id
+
+ async def acquire(self, message):
+ key = self.get_bucket(message)
+
+ try:
+ sem = self._mapping[key]
+ except KeyError:
+ self._mapping[key] = sem = _Semaphore(self.number)
+
+ acquired = await sem.acquire(wait=self.wait)
+ if not acquired:
+ raise MaxConcurrencyReached(self.number, self.per)
+
+ async def release(self, message):
+ # Technically there's no reason for this function to be async
+ # But it might be more useful in the future
+ key = self.get_bucket(message)
+
+ try:
+ sem = self._mapping[key]
+ except KeyError:
+ # ...? peculiar
+ return
+ else:
+ sem.release()
+
+ if sem.value >= self.number:
+ del self._mapping[key]
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index aa4343f5..e587d3a7 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -33,7 +33,7 @@ import datetime
import discord
from .errors import *
-from .cooldowns import Cooldown, BucketType, CooldownMapping
+from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency
from . import converter as converters
from ._types import _BaseCommand
from .cog import Cog
@@ -53,6 +53,7 @@ __all__ = (
'bot_has_permissions',
'bot_has_any_role',
'cooldown',
+ 'max_concurrency',
'dm_only',
'guild_only',
'is_owner',
@@ -90,6 +91,9 @@ def hooked_wrapped_callback(command, ctx, coro):
ctx.command_failed = True
raise CommandInvokeError(exc) from exc
finally:
+ if command._max_concurrency is not None:
+ await command._max_concurrency.release(ctx)
+
await command.call_after_hooks(ctx)
return ret
return wrapped
@@ -248,6 +252,13 @@ class Command(_BaseCommand):
finally:
self._buckets = CooldownMapping(cooldown)
+ try:
+ max_concurrency = func.__commands_max_concurrency__
+ except AttributeError:
+ max_concurrency = kwargs.get('max_concurrency')
+ finally:
+ self._max_concurrency = max_concurrency
+
self.ignore_extra = kwargs.get('ignore_extra', True)
self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False)
self.cog = None
@@ -331,6 +342,9 @@ class Command(_BaseCommand):
other.checks = self.checks.copy()
if self._buckets.valid and not other._buckets.valid:
other._buckets = self._buckets.copy()
+ if self._max_concurrency != other._max_concurrency:
+ other._max_concurrency = self._max_concurrency.copy()
+
try:
other.on_error = self.on_error
except AttributeError:
@@ -718,6 +732,9 @@ class Command(_BaseCommand):
self._prepare_cooldowns(ctx)
await self._parse_arguments(ctx)
+ if self._max_concurrency is not None:
+ await self._max_concurrency.acquire(ctx)
+
await self.call_before_hooks(ctx)
def is_on_cooldown(self, ctx):
@@ -1800,3 +1817,36 @@ def cooldown(rate, per, type=BucketType.default):
func.__commands_cooldown__ = Cooldown(rate, per, type)
return func
return decorator
+
+def max_concurrency(number, per=BucketType.default, *, wait=False):
+ """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
+
+ This enables you to only allow a certain number of command invocations at the same time,
+ for example if a command takes too long or if only one user can use it at a time. This
+ differs from a cooldown in that there is no set waiting period or token bucket -- only
+ a set number of people can run the command.
+
+ .. versionadded:: 1.3.0
+
+ Parameters
+ -------------
+ number: :class:`int`
+ The maximum number of invocations of this command that can be running at the same time.
+ per: :class:`.BucketType`
+ The bucket that this concurrency is based on, e.g. ``BucketType.guild`` would allow
+ it to be used up to ``number`` times per guild.
+ wait: :class:`bool`
+ Whether the command should wait for the queue to be over. If this is set to ``False``
+ then instead of waiting until the command can run again, the command raises
+ :exc:`.MaxConcurrencyReached` to its error handler. If this is set to ``True``
+ then the command waits until it can be executed.
+ """
+
+ def decorator(func):
+ value = MaxConcurrency(number, per=per, wait=wait)
+ if isinstance(func, Command):
+ func._max_concurrency = value
+ else:
+ func.__commands_max_concurrency__ = value
+ return func
+ return decorator
diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py
index 6087c1df..0d5e0d0f 100644
--- a/discord/ext/commands/errors.py
+++ b/discord/ext/commands/errors.py
@@ -41,6 +41,7 @@ __all__ = (
'TooManyArguments',
'UserInputError',
'CommandOnCooldown',
+ 'MaxConcurrencyReached',
'NotOwner',
'MissingRole',
'BotMissingRole',
@@ -240,6 +241,28 @@ class CommandOnCooldown(CommandError):
self.retry_after = retry_after
super().__init__('You are on cooldown. Try again in {:.2f}s'.format(retry_after))
+class MaxConcurrencyReached(CommandError):
+ """Exception raised when the command being invoked has reached its maximum concurrency.
+
+ This inherits from :exc:`CommandError`.
+
+ Attributes
+ ------------
+ number: :class:`int`
+ The maximum number of concurrent invokers allowed.
+ per: :class:`BucketType`
+ The bucket type passed to the :func:`.max_concurrency` decorator.
+ """
+
+ def __init__(self, number, per):
+ self.number = number
+ self.per = per
+ name = per.name
+ suffix = 'per %s' % name if per.name != 'default' else 'globally'
+ plural = '%s times %s' if number > 1 else '%s time %s'
+ fmt = plural % (number, suffix)
+ super().__init__('Too many people using this command. It can only be used {}.'.format(fmt))
+
class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command.