aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/ext/tasks/__init__.py95
-rw-r--r--docs/ext/tasks/index.rst48
2 files changed, 124 insertions, 19 deletions
diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py
index 91e49c3a..dab4b5ea 100644
--- a/discord/ext/tasks/__init__.py
+++ b/discord/ext/tasks/__init__.py
@@ -35,6 +35,9 @@ class Loop:
websockets.WebSocketProtocolError,
)
+ self._before_loop = None
+ self._after_loop = None
+
if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.')
@@ -47,25 +50,42 @@ class Loop:
raise ValueError('Total number of seconds cannot be less than zero.')
if not inspect.iscoroutinefunction(self.coro):
- raise TypeError('Expected coroutine function, not {0!r}.'.format(type(self.coro)))
+ raise TypeError('Expected coroutine function, not {0.__name__!r}.'.format(type(self.coro)))
- async def _loop(self, *args, **kwargs):
- backoff = ExponentialBackoff()
- while True:
- try:
- await self.coro(*args, **kwargs)
- except asyncio.CancelledError:
- return
- except self._valid_exception as exc:
- if not self.reconnect:
- raise
- await asyncio.sleep(backoff.delay())
+ async def _call_loop_function(self, name):
+ coro = getattr(self, '_' + name)
+ if coro is None:
+ return
+
+ if inspect.iscoroutinefunction(coro):
+ if self._injected is not None:
+ await coro(self._injected)
else:
- self._current_loop += 1
- if self._current_loop == self.count:
- return
+ await coro()
+ else:
+ await coro
- await asyncio.sleep(self._sleep)
+ async def _loop(self, *args, **kwargs):
+ backoff = ExponentialBackoff()
+ await self._call_loop_function('before_loop')
+ try:
+ while True:
+ try:
+ await self.coro(*args, **kwargs)
+ except asyncio.CancelledError:
+ break
+ except self._valid_exception as exc:
+ if not self.reconnect:
+ raise
+ await asyncio.sleep(backoff.delay())
+ else:
+ self._current_loop += 1
+ if self._current_loop == self.count:
+ break
+
+ await asyncio.sleep(self._sleep)
+ finally:
+ await self._call_loop_function('after_loop')
def __get__(self, obj, objtype):
if obj is None:
@@ -171,6 +191,49 @@ class Loop:
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
return self._task
+ def before_loop(self, coro):
+ """A function that also acts as a decorator to register a coroutine to be
+ called before the loop starts running. This is useful if you want to wait
+ for some bot state before the loop starts,
+ such as :meth:`discord.Client.wait_until_ready`.
+
+ Parameters
+ ------------
+ coro: :term:`py:awaitable`
+ The coroutine to register before the loop runs.
+
+ Raises
+ -------
+ TypeError
+ The function was not a coroutine.
+ """
+
+ if not (inspect.iscoroutinefunction(coro) or inspect.isawaitable(coro)):
+ raise TypeError('Expected coroutine or awaitable, received {0.__name__!r}.'.format(type(coro)))
+
+ self._before_loop = coro
+
+
+ def after_loop(self, coro):
+ """A function that also acts as a decorator to register a coroutine to be
+ called after the loop finished running.
+
+ Parameters
+ ------------
+ coro: :term:`py:awaitable`
+ The coroutine to register after the loop finishes.
+
+ Raises
+ -------
+ TypeError
+ The function was not a coroutine.
+ """
+
+ if not (inspect.iscoroutinefunction(coro) or inspect.isawaitable(coro)):
+ raise TypeError('Expected coroutine or awaitable, received {0.__name__!r}.'.format(type(coro)))
+
+ self._after_loop = coro
+
def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None):
"""A decorator that schedules a task in the background for you with
optional reconnect logic.
diff --git a/docs/ext/tasks/index.rst b/docs/ext/tasks/index.rst
index 94d1320d..93e7b3f8 100644
--- a/docs/ext/tasks/index.rst
+++ b/docs/ext/tasks/index.rst
@@ -66,14 +66,56 @@ Looping a certain amount of times before exiting:
async def slow_count():
print(slow_count.current_loop)
+ @slow_count.after_loop
+ async def after_slow_count():
+ print('done!')
+
slow_count.start()
-Doing something after a task finishes is as simple as using :meth:`asyncio.Task.add_done_callback`:
+Waiting until the bot is ready before the loop starts:
.. code-block:: python3
- afterwards = lambda f: print('done!')
- slow_count.get_task().add_done_callback(afterwards)
+ from discord.ext import tasks, commands
+
+ class MyCog(commands.Cog):
+ def __init__(self, bot):
+ self.index = 0
+ self.printer.before_loop(bot.wait_until_ready())
+ self.printer.start()
+
+ def cog_unload(self):
+ self.printer.cancel()
+
+ @tasks.loop(seconds=5.0)
+ async def printer(self):
+ print(self.index)
+ self.index += 1
+
+:meth:`~.tasks.Loop.before_loop` can be used as a decorator as well:
+
+.. code-block:: python3
+
+ from discord.ext import tasks, commands
+
+ class MyCog(commands.Cog):
+ def __init__(self, bot):
+ self.index = 0
+ self.bot = bot
+ self.printer.start()
+
+ def cog_unload(self):
+ self.printer.cancel()
+
+ @tasks.loop(seconds=5.0)
+ async def printer(self):
+ print(self.index)
+ self.index += 1
+
+ @printer.before_loop
+ async def before_printer(self):
+ print('waiting...')
+ await self.bot.wait_until_ready()
API Reference
---------------