aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Law <[email protected]>2021-05-09 20:27:43 -0700
committerGitHub <[email protected]>2021-05-09 23:27:43 -0400
commit8bc489dba8b8c7ca9141e4e7f00a0e916a7c0269 (patch)
tree98c4483f1f74eb52b9f136eb5e3dd792cd6fb5b4
parentTypehint Widget (diff)
downloaddiscord.py-8bc489dba8b8c7ca9141e4e7f00a0e916a7c0269.tar.xz
discord.py-8bc489dba8b8c7ca9141e4e7f00a0e916a7c0269.zip
[tasks] Add support for explicit time parameter
-rw-r--r--discord/ext/tasks/__init__.py233
-rw-r--r--discord/utils.py14
2 files changed, 215 insertions, 32 deletions
diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py
index d2ae2c75..81e8dc79 100644
--- a/discord/ext/tasks/__init__.py
+++ b/discord/ext/tasks/__init__.py
@@ -31,6 +31,7 @@ import logging
import sys
import traceback
+from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
log = logging.getLogger(__name__)
@@ -39,17 +40,43 @@ __all__ = (
'loop',
)
+class SleepHandle:
+ __slots__ = ('future', 'loop', 'handle')
+
+ def __init__(self, dt, *, loop):
+ self.loop = loop
+ self.future = future = loop.create_future()
+ relative_delta = discord.utils.compute_timedelta(dt)
+ self.handle = loop.call_later(relative_delta, future.set_result, True)
+
+ def recalculate(self, dt):
+ self.handle.cancel()
+ relative_delta = discord.utils.compute_timedelta(dt)
+ self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
+
+ def wait(self):
+ return self.future
+
+ def done(self):
+ return self.future.done()
+
+ def cancel(self):
+ self.handle.cancel()
+ self.future.cancel()
+
+
class Loop:
"""A background task helper that abstracts the loop and reconnection logic for you.
The main interface to create this is through :func:`loop`.
"""
- def __init__(self, coro, seconds, hours, minutes, count, reconnect, loop):
+ def __init__(self, coro, seconds, hours, minutes, time, count, reconnect, loop):
self.coro = coro
self.reconnect = reconnect
self.loop = loop
self.count = count
self._current_loop = 0
+ self._handle = None
self._task = None
self._injected = None
self._valid_exception = (
@@ -69,7 +96,7 @@ class Loop:
if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.')
- self.change_interval(seconds=seconds, minutes=minutes, hours=hours)
+ self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False
self._last_iteration = None
self._next_iteration = None
@@ -87,14 +114,23 @@ class Loop:
else:
await coro(*args, **kwargs)
+ def _try_sleep_until(self, dt):
+ self._handle = SleepHandle(dt=dt, loop=self.loop)
+ return self._handle.wait()
+
async def _loop(self, *args, **kwargs):
backoff = ExponentialBackoff()
await self._call_loop_function('before_loop')
sleep_until = discord.utils.sleep_until
self._last_iteration_failed = False
- self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
+ if self._time is not None:
+ # the time index should be prepared every time the internal loop is started
+ self._prepare_time_index()
+ self._next_iteration = self._get_next_sleep_time()
+ else:
+ self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
try:
- await asyncio.sleep(0) # allows canceling in before_loop
+ await self._try_sleep_until(self._next_iteration)
while True:
if not self._last_iteration_failed:
self._last_iteration = self._next_iteration
@@ -102,22 +138,26 @@ class Loop:
try:
await self.coro(*args, **kwargs)
self._last_iteration_failed = False
- now = datetime.datetime.now(datetime.timezone.utc)
- if now > self._next_iteration:
- self._next_iteration = now
except self._valid_exception:
self._last_iteration_failed = True
if not self.reconnect:
raise
await asyncio.sleep(backoff.delay())
else:
- await sleep_until(self._next_iteration)
+ await self._try_sleep_until(self._next_iteration)
if self._stop_next_iteration:
return
+
+ now = datetime.datetime.now(datetime.timezone.utc)
+ if now > self._next_iteration:
+ self._prepare_time_index(now)
+ self._next_iteration = now
+
self._current_loop += 1
if self._current_loop == self.count:
break
+
except asyncio.CancelledError:
self._is_being_cancelled = True
raise
@@ -127,6 +167,7 @@ class Loop:
raise exc
finally:
await self._call_loop_function('after_loop')
+ self._handle.cancel()
self._is_being_cancelled = False
self._current_loop = 0
self._stop_next_iteration = False
@@ -136,8 +177,16 @@ class Loop:
if obj is None:
return self
- copy = Loop(self.coro, seconds=self.seconds, hours=self.hours, minutes=self.minutes,
- count=self.count, reconnect=self.reconnect, loop=self.loop)
+ copy = Loop(
+ self.coro,
+ seconds=self._seconds,
+ hours=self._hours,
+ minutes=self._minutes,
+ count=self.count,
+ time=self._time,
+ reconnect=self.reconnect,
+ loop=self.loop,
+ )
copy._injected = obj
copy._before_loop = self._before_loop
copy._after_loop = self._after_loop
@@ -146,6 +195,43 @@ class Loop:
return copy
@property
+ def seconds(self):
+ """Optional[:class:`float`]: Read-only value for the number of seconds
+ between each iteration. ``None`` if an explicit ``time`` value was passed instead.
+
+ .. versionadded:: 2.0
+ """
+ return self._seconds
+
+ @property
+ def minutes(self):
+ """Optional[:class:`float`]: Read-only value for the number of minutes
+ between each iteration. ``None`` if an explicit ``time`` value was passed instead.
+
+ .. versionadded:: 2.0
+ """
+ return self._minutes
+
+ @property
+ def hours(self):
+ """Optional[:class:`float`]: Read-only value for the number of hours
+ between each iteration. ``None`` if an explicit ``time`` value was passed instead.
+
+ .. versionadded:: 2.0
+ """
+ return self._hours
+
+ @property
+ def time(self):
+ """Optional[List[:class:`datetime.time`]]: Read-only list for the exact times this loop runs at.
+ ``None`` if relative times were passed instead.
+
+ .. versionadded:: 2.0
+ """
+ if self._time is not None:
+ return self._time.copy()
+
+ @property
def current_loop(self):
""":class:`int`: The current iteration of the loop."""
return self._current_loop
@@ -430,16 +516,63 @@ class Loop:
return coro
def _get_next_sleep_time(self):
- return self._last_iteration + datetime.timedelta(seconds=self._sleep)
-
- def change_interval(self, *, seconds=0, minutes=0, hours=0):
+ if self._sleep is not None:
+ return self._last_iteration + datetime.timedelta(seconds=self._sleep)
+
+ if self._time_index >= len(self._time):
+ self._time_index = 0
+ if self._current_loop == 0:
+ # if we're at the last index on the first iteration, we need to sleep until tomorrow
+ return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0])
+
+ next_time = self._time[self._time_index]
+
+ if self._current_loop == 0:
+ self._time_index += 1
+ return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
+
+ next_date = self._last_iteration
+ if self._time_index == 0:
+ # we can assume that the earliest time should be scheduled for "tomorrow"
+ next_date += datetime.timedelta(days=1)
+
+ self._time_index += 1
+ return datetime.datetime.combine(next_date, next_time)
+
+ def _prepare_time_index(self, now=None):
+ # now kwarg should be a datetime.datetime representing the time "now"
+ # to calculate the next time index from
+
+ # pre-condition: self._time is set
+ time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz()
+ for idx, time in enumerate(self._time):
+ if time >= time_now:
+ self._time_index = idx
+ break
+ else:
+ self._time_index = 0
+
+ def _get_time_parameter(self, time, *, inst=isinstance, dt=datetime.time, utc=datetime.timezone.utc):
+ if inst(time, dt):
+ ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
+ return [ret]
+ if not inst(time, Sequence):
+ raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
+ if not time:
+ raise ValueError('time parameter must not be an empty sequence.')
+
+ ret = []
+ for index, t in enumerate(time):
+ if not inst(t, dt):
+ raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
+ ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
+
+ ret = sorted(set(ret)) # de-dupe and sort times
+ return ret
+
+ def change_interval(self, *, seconds=0, minutes=0, hours=0, time=None):
"""Changes the interval for the sleep time.
- .. note::
-
- This only applies on the next loop iteration. If it is desirable for the change of interval
- to be applied right away, cancel the task with :meth:`cancel`.
-
.. versionadded:: 1.2
Parameters
@@ -450,23 +583,54 @@ class Loop:
The number of minutes between every iteration.
hours: :class:`float`
The number of hours between every iteration.
+ time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
+ The exact times to run this loop at. Either a non-empty list or a single
+ value of :class:`datetime.time` should be passed.
+ This cannot be used in conjunction with the relative time parameters.
+
+ .. versionadded:: 2.0
+
+ .. note::
+
+ Duplicate times will be ignored, and only run once.
Raises
-------
ValueError
An invalid value was given.
+ TypeError
+ An invalid value for the ``time`` parameter was passed, or the
+ ``time`` parameter was passed in conjunction with relative time parameters.
"""
- sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
- if sleep < 0:
- raise ValueError('Total number of seconds cannot be less than zero.')
+ if time is None:
+ sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
+ if sleep < 0:
+ raise ValueError('Total number of seconds cannot be less than zero.')
+
+ self._sleep = sleep
+ self._seconds = float(seconds)
+ self._hours = float(hours)
+ self._minutes = float(minutes)
+ self._time = None
+ else:
+ if any((seconds, minutes, hours)):
+ raise TypeError('Cannot mix explicit time with relative time')
+ self._time = self._get_time_parameter(time)
+ self._sleep = self._seconds = self._minutes = self._hours = None
+
+ if self.is_running():
+ if self._time is not None:
+ # prepare the next time index starting from after the last iteration
+ self._prepare_time_index(now=self._last_iteration)
+
+ self._next_iteration = self._get_next_sleep_time()
+ if not self._handle.done():
+ # the loop is sleeping, recalculate based on new interval
+ self._handle.recalculate(self._next_iteration)
- self._sleep = sleep
- self.seconds = seconds
- self.hours = hours
- self.minutes = minutes
-def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None):
+def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True, loop=None):
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
@@ -478,6 +642,19 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None
The number of minutes between every iteration.
hours: :class:`float`
The number of hours between every iteration.
+ time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
+ The exact times to run this loop at. Either a non-empty list or a single
+ value of :class:`datetime.time` should be passed. Timezones are supported.
+ If no timezone is given for the times, it is assumed to represent UTC time.
+
+ This cannot be used in conjunction with the relative time parameters.
+
+ .. note::
+
+ Duplicate times will be ignored, and only run once.
+
+ .. versionadded:: 2.0
+
count: Optional[:class:`int`]
The number of loops to do, ``None`` if it should be an
infinite loop.
@@ -494,7 +671,8 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None
ValueError
An invalid value was given.
TypeError
- The function was not a coroutine.
+ The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
+ or ``time`` parameter was passed in conjunction with relative time parameters.
"""
def decorator(func):
kwargs = {
@@ -502,6 +680,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None
'minutes': minutes,
'hours': hours,
'count': count,
+ 'time': time,
'reconnect': reconnect,
'loop': loop
}
diff --git a/discord/utils.py b/discord/utils.py
index 32d8feb4..22a0c40c 100644
--- a/discord/utils.py
+++ b/discord/utils.py
@@ -503,6 +503,13 @@ async def sane_wait_for(futures, *, timeout):
return done
+def compute_timedelta(dt: datetime.datetime):
+ if dt.tzinfo is None:
+ dt = dt.astimezone()
+ now = datetime.datetime.now(datetime.timezone.utc)
+ return max((dt - now).total_seconds(), 0)
+
+
async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Optional[T]:
"""|coro|
@@ -520,11 +527,8 @@ async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Op
result: Any
If provided is returned to the caller when the coroutine completes.
"""
- if when.tzinfo is None:
- when = when.astimezone()
- now = datetime.datetime.now(datetime.timezone.utc)
- delta = (when - now).total_seconds()
- return await asyncio.sleep(max(delta, 0), result)
+ delta = compute_timedelta(when)
+ return await asyncio.sleep(delta, result)
def utcnow() -> datetime.datetime: