aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/ext/tasks/__init__.py209
-rw-r--r--docs/ext/tasks/index.rst82
-rw-r--r--docs/index.rst1
3 files changed, 292 insertions, 0 deletions
diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py
new file mode 100644
index 00000000..623513d1
--- /dev/null
+++ b/discord/ext/tasks/__init__.py
@@ -0,0 +1,209 @@
+import asyncio
+import aiohttp
+import websockets
+import discord
+import inspect
+
+from discord.backoff import ExponentialBackoff
+
+MAX_ASYNCIO_SECONDS = 3456000
+
+class Loop:
+ """A background task helper that abstracts the loop and reconnection logic for you.
+
+ The main interface to create this is through :func:`loop`.
+ """
+ def __init__(self, coro, seconds, hours, minutes, count, reconnect, loop):
+ self.coro = coro
+ self.seconds = seconds
+ self.hours = hours
+ self.minutes = minutes
+ self.reconnect = reconnect
+ self.loop = loop or asyncio.get_event_loop()
+ self.count = count
+ self._current_loop = 0
+ self._task = None
+ self._injected = None
+ self._valid_exception = (
+ OSError,
+ discord.HTTPException,
+ discord.GatewayNotFound,
+ discord.ConnectionClosed,
+ aiohttp.ClientError,
+ asyncio.TimeoutError,
+ websockets.InvalidHandshake,
+ websockets.WebSocketProtocolError,
+ )
+
+ if self.count is not None and self.count <= 0:
+ raise ValueError('count must be greater than 0 or None.')
+
+ self._sleep = sleep = self.seconds + (self.minutes * 60.0) + (self.hours * 3600.0)
+ if sleep >= MAX_ASYNCIO_SECONDS:
+ raise ValueError('Total time exceeds asyncio imposed limit of {0} seconds.'.format(MAX_ASYNCIO_SECONDS))
+
+ if not inspect.iscoroutinefunction(self.coro):
+ raise TypeError('Expected coroutine function, not {0!r}.'.format(type(self.coro)))
+
+ async def _loop(self, *args, **kwargs):
+ backoff = ExponentialBackoff()
+ while True:
+ try:
+ await self.coro(*args, **kwargs)
+ except asyncio.CancelledError:
+ return
+ except self._valid_exception as exc:
+ if not self.reconnect:
+ raise
+ await asyncio.sleep(backoff.delay())
+ else:
+ self._current_loop += 1
+ if self._current_loop == self.count:
+ return
+
+ await asyncio.sleep(self._sleep)
+
+ def __get__(self, obj, objtype):
+ if obj is None:
+ return self
+ self._injected = obj
+ return self
+
+ @property
+ def current_loop(self):
+ """:class:`int`: The current iteration of the loop."""
+ return self._current_loop
+
+
+ def run(self, *args, **kwargs):
+ r"""Runs the internal task in the event loop.
+
+ Parameters
+ ------------
+ \*args
+ The arguments to to use.
+ \*\*kwargs
+ The keyword arguments to use.
+
+ Raises
+ --------
+ RuntimeError
+ A task has already been launched.
+
+ Returns
+ ---------
+ :class:`asyncio.Task`
+ The task that has been registered.
+ """
+
+ if self._task is not None:
+ raise RuntimeError('Task is already launched.')
+
+ if self._injected is not None:
+ args = (self._injected, *args)
+
+ self._task = self.loop.create_task(self._loop(*args, **kwargs))
+ return self._task
+
+ def cancel(self):
+ """Cancels the internal task, if any are running."""
+ if self._task:
+ self._task.cancel()
+
+ def add_exception_type(self, exc):
+ r"""Adds an exception type to be handled during the reconnect logic.
+
+ By default the exception types handled are those handled by
+ :meth:`discord.Client.connect`\, which includes a lot of internet disconnection
+ errors.
+
+ This function is useful if you're interacting with a 3rd party library that
+ raises its own set of exceptions.
+
+ Parameters
+ ------------
+ exc: Type[:class:`BaseException`]
+ The exception class to handle.
+
+ Raises
+ --------
+ TypeError
+ The exception passed is either not a class or not inherited from :class:`BaseException`.
+ """
+
+ if not inspect.isclass(exc):
+ raise TypeError('{0!r} must be a class.'.format(exc))
+ if not issubclass(exc, BaseException):
+ raise TypeError('{0!r} must inherit from BaseException.'.format(exc))
+
+ self._valid_exception = tuple(*self._valid_exception, exc)
+
+ def clear_exception_types(self):
+ """Removes all exception types that are handled.
+
+ .. note::
+
+ This operation obviously cannot be undone!
+ """
+ self._valid_exception = tuple()
+
+ def remove_exception_type(self, exc):
+ """Removes an exception type from being handled during the reconnect logic.
+
+ Parameters
+ ------------
+ exc: Type[:class:`BaseException`]
+ The exception class to handle.
+
+ Returns
+ ---------
+ :class:`bool`
+ Whether it was successfully removed.
+ """
+ old_length = len(self._valid_exception)
+ self._valid_exception = tuple(x for x in self._valid_exception if x is not exc)
+ return len(self._valid_exception) != old_length
+
+ def get_task(self):
+ """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
+ return self._task
+
+def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None):
+ """A decorator that schedules a task in the background for you with
+ optional reconnect logic.
+
+ Parameters
+ ------------
+ seconds: :class:`float`
+ The number of seconds between every iteration.
+ minutes: :class:`float`
+ The number of minutes between every iteration.
+ hours: :class:`float`
+ The number of hours between every iteration.
+ count: Optional[:class:`int`]
+ The number of loops to do, ``None`` if it should be an
+ infinite loop.
+ reconnect: :class:`bool`
+ Whether to handle errors and restart the task
+ using an exponential back-off algorithm similar to the
+ one used in :meth:`discord.Client.connect`.
+ loop: :class:`asyncio.AbstractEventLoop`
+ The loop to use to register the task, if not given
+ defaults to :func:`asyncio.get_event_loop`.
+
+ Raises
+ --------
+ ValueError
+ An invalid value was given.
+ TypeError
+ The function was not a coroutine.
+
+ Returns
+ ---------
+ :class:`Loop`
+ The loop helper that handles the background task.
+ """
+ def decorator(func):
+ return Loop(func, seconds=seconds, minutes=minutes, hours=hours,
+ count=count, reconnect=reconnect, loop=loop)
+ return decorator
diff --git a/docs/ext/tasks/index.rst b/docs/ext/tasks/index.rst
new file mode 100644
index 00000000..bb242679
--- /dev/null
+++ b/docs/ext/tasks/index.rst
@@ -0,0 +1,82 @@
+``discord.ext.tasks`` -- asyncio.Task helpers
+====================================================
+
+One of the most common operations when making a bot is having a loop run in the background at a specified interval. This pattern is very common but has a lot of things you need to look out for:
+
+- How do I handle :exc:`asyncio.CancelledError`?
+- What do I do if the internet goes out?
+- What is the maximum number of seconds I can sleep anyway?
+
+The goal of this discord.py extension is to abstract all these worries away from you.
+
+Recipes
+---------
+
+A simple background task in a :class:`~discord.ext.commands.Cog`:
+
+.. code-block:: python3
+
+ from discord.ext import tasks, commands
+
+ class MyCog(commands.Cog):
+ def __init__(self):
+ self.index = 0
+ self.printer.run()
+
+ def cog_unload(self):
+ self.printer.cancel()
+
+ @tasks.loop(seconds=5.0)
+ async def printer(self):
+ print(self.index)
+ self.index += 1
+
+Adding an exception to handle during reconnect:
+
+.. code-block:: python3
+
+ import asyncpg
+ from discord.ext import tasks, commands
+
+ class MyCog(commands.Cog):
+ def __init__(self, bot):
+ self.bot = bot
+ self.data = []
+ self.batch_update.add_exception_type(asyncpg.PostgresConnectionError)
+ self.batch_update.run()
+
+ def cog_unload(self):
+ self.batch_update.cancel()
+
+ @tasks.loop(minutes=5.0)
+ async def batch_update(self):
+ async with self.bot.pool.acquire() as con:
+ # batch update here...
+ pass
+
+Looping a certain amount of times before exiting:
+
+.. code-block:: python3
+
+ from discord.ext import tasks
+
+ @tasks.loop(seconds=5.0, count=5)
+ async def slow_count():
+ print(slow_count.current_loop)
+
+ slow_count.run()
+
+Doing something after a task finishes is as simple as using :meth:`asyncio.Task.add_done_callback`:
+
+.. code-block:: python3
+
+ afterwards = lambda f: print('done!')
+ slow_count.get_task().add_done_callback(afterwards)
+
+API Reference
+---------------
+
+.. autoclass:: discord.ext.tasks.Loop()
+ :members:
+
+.. autofunction:: discord.ext.tasks.loop
diff --git a/docs/index.rst b/docs/index.rst
index 4cb19364..77463d9f 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -39,6 +39,7 @@ Extensions
:maxdepth: 3
ext/commands/index.rst
+ ext/tasks/index.rst
Additional Information