aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
authorRapptz <[email protected]>2021-07-02 09:17:32 -0400
committerRapptz <[email protected]>2021-07-02 09:17:32 -0400
commitd7ed88459341527a69e8d8a7a77ad1d9c5e2832c (patch)
tree92b461e4f15d0dbeeaa5a8f9a9a3c821c55bba49 /discord
parent[commands] Add back CommandOnCooldown.type (diff)
downloaddiscord.py-d7ed88459341527a69e8d8a7a77ad1d9c5e2832c.tar.xz
discord.py-d7ed88459341527a69e8d8a7a77ad1d9c5e2832c.zip
Rework view timeouts to work as documented
Diffstat (limited to 'discord')
-rw-r--r--discord/ui/view.py97
1 files changed, 63 insertions, 34 deletions
diff --git a/discord/ui/view.py b/discord/ui/view.py
index e6f8df34..27c76594 100644
--- a/discord/ui/view.py
+++ b/discord/ui/view.py
@@ -162,14 +162,32 @@ class View:
self.__weights = _ViewWeights(self.children)
loop = asyncio.get_running_loop()
- self.id = os.urandom(16).hex()
- self._cancel_callback: Optional[Callable[[View], None]] = None
- self._timeout_handler: Optional[asyncio.TimerHandle] = None
- self._stopped = loop.create_future()
+ self.id: str = os.urandom(16).hex()
+ self.__cancel_callback: Optional[Callable[[View], None]] = None
+ self.__timeout_expiry: Optional[float] = None
+ self.__timeout_task: Optional[asyncio.Task[None]] = None
+ self.__stopped: asyncio.Future[bool] = loop.create_future()
def __repr__(self) -> str:
return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>'
+ async def __timeout_task_impl(self) -> None:
+ while True:
+ # Guard just in case someone changes the value of the timeout at runtime
+ if self.timeout is None:
+ return
+
+ if self.__timeout_expiry is None:
+ return self._dispatch_timeout()
+
+ # Check if we've elapsed our currently set timeout
+ now = time.monotonic()
+ if now >= self.__timeout_expiry:
+ return self._dispatch_timeout()
+
+ # Wait N seconds to see if timeout data has been refreshed
+ await asyncio.sleep(self.__timeout_expiry - now)
+
def to_components(self) -> List[Dict[str, Any]]:
def key(item: Item) -> int:
return item._rendered_row or 0
@@ -328,8 +346,11 @@ class View:
print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
- async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction):
+ async def _scheduled_task(self, item: Item, interaction: Interaction):
try:
+ if self.timeout:
+ self.__timeout_expiry = time.monotonic() + self.timeout
+
allow = await self.interaction_check(interaction)
if not allow:
return
@@ -340,21 +361,28 @@ class View:
except Exception as e:
return await self.on_error(e, item, interaction)
- def _start_listening(self, store: ViewStore) -> None:
- self._cancel_callback = partial(store.remove_view)
+ def _start_listening_from_store(self, store: ViewStore) -> None:
+ self.__cancel_callback = partial(store.remove_view)
if self.timeout:
loop = asyncio.get_running_loop()
- self._timeout_handler = loop.call_later(self.timeout, self.dispatch_timeout)
+ if self.__timeout_task is not None:
+ self.__timeout_task.cancel()
+
+ self.__timeout_expiry = time.monotonic() + self.timeout
+ self.__timeout_task = loop.create_task(self.__timeout_task_impl())
- def dispatch_timeout(self):
- if self._stopped.done():
+ def _dispatch_timeout(self):
+ if self.__stopped.done():
return
- self._stopped.set_result(True)
+ self.__stopped.set_result(True)
asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}')
- def dispatch(self, state: Any, item: Item, interaction: Interaction):
- asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
+ def _dispatch_item(self, item: Item, interaction: Interaction):
+ if self.__stopped.done():
+ return
+
+ asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
def refresh(self, components: List[Component]):
# This is pretty hacky at the moment
@@ -382,23 +410,25 @@ class View:
This operation cannot be undone.
"""
- if not self._stopped.done():
- self._stopped.set_result(False)
+ if not self.__stopped.done():
+ self.__stopped.set_result(False)
- if self._timeout_handler:
- self._timeout_handler.cancel()
+ self.__timeout_expiry = None
+ if self.__timeout_task is not None:
+ self.__timeout_task.cancel()
+ self.__timeout_task = None
- if self._cancel_callback:
- self._cancel_callback(self)
- self._cancel_callback = None
+ if self.__cancel_callback:
+ self.__cancel_callback(self)
+ self.__cancel_callback = None
def is_finished(self) -> bool:
""":class:`bool`: Whether the view has finished interacting."""
- return self._stopped.done()
+ return self.__stopped.done()
def is_dispatching(self) -> bool:
""":class:`bool`: Whether the view has been added for dispatching purposes."""
- return self._cancel_callback is not None
+ return self.__cancel_callback is not None
def is_persistent(self) -> bool:
""":class:`bool`: Whether the view is set up as persistent.
@@ -420,13 +450,13 @@ class View:
If ``True``, then the view timed out. If ``False`` then
the view finished normally.
"""
- return await self._stopped
+ return await self.__stopped
class ViewStore:
def __init__(self, state: ConnectionState):
- # (component_type, custom_id): (View, Item, Expiry)
- self._views: Dict[Tuple[int, str], Tuple[View, Item, Optional[float]]] = {}
+ # (component_type, custom_id): (View, Item)
+ self._views: Dict[Tuple[int, str], Tuple[View, Item]] = {}
# message_id: View
self._synced_message_views: Dict[int, View] = {}
self._state: ConnectionState = state
@@ -436,7 +466,7 @@ class ViewStore:
# fmt: off
views = {
view.id: view
- for (_, (view, _, _)) in self._views.items()
+ for (_, (view, _)) in self._views.items()
if view.is_persistent()
}
# fmt: on
@@ -445,8 +475,8 @@ class ViewStore:
def __verify_integrity(self):
to_remove: List[Tuple[int, str]] = []
now = time.monotonic()
- for (k, (_, _, expiry)) in self._views.items():
- if expiry is not None and now >= expiry:
+ for (k, (view, _)) in self._views.items():
+ if view.is_finished():
to_remove.append(k)
for k in to_remove:
@@ -455,11 +485,10 @@ class ViewStore:
def add_view(self, view: View, message_id: Optional[int] = None):
self.__verify_integrity()
- expiry = view._expires_at
- view._start_listening(self)
+ view._start_listening_from_store(self)
for item in view.children:
if item.is_dispatchable():
- self._views[(item.type.value, item.custom_id)] = (view, item, expiry) # type: ignore
+ self._views[(item.type.value, item.custom_id)] = (view, item) # type: ignore
if message_id is not None:
self._synced_message_views[message_id] = view
@@ -481,10 +510,10 @@ class ViewStore:
if value is None:
return
- view, item, _ = value
- self._views[key] = (view, item, view._expires_at)
+ view, item = value
+ self._views[key] = (view, item)
item.refresh_state(interaction)
- view.dispatch(self._state, item, interaction)
+ view._dispatch_item(item, interaction)
def is_message_tracked(self, message_id: int):
return message_id in self._synced_message_views