aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
authorRapptz <[email protected]>2021-04-26 06:02:43 -0400
committerRapptz <[email protected]>2021-05-27 00:53:14 -0400
commit4c0ebc922155c2d9ca3129e0dbfdcea10f3ad777 (patch)
treed72b6614c242ce34863a643fd0026aa1319ac2cd /discord
parentFix emoji not showing up in button component (diff)
downloaddiscord.py-4c0ebc922155c2d9ca3129e0dbfdcea10f3ad777.tar.xz
discord.py-4c0ebc922155c2d9ca3129e0dbfdcea10f3ad777.zip
Change the way callbacks are defined to allow deriving
This should hopefully make these work more consistently as other functions do.
Diffstat (limited to 'discord')
-rw-r--r--discord/ui/button.py33
-rw-r--r--discord/ui/item.py82
-rw-r--r--discord/ui/view.py19
3 files changed, 45 insertions, 89 deletions
diff --git a/discord/ui/button.py b/discord/ui/button.py
index afc69f7a..8ff4ce74 100644
--- a/discord/ui/button.py
+++ b/discord/ui/button.py
@@ -87,8 +87,6 @@ class Button(Item):
The emoji of the button, if available.
"""
- __slots__: Tuple[str, ...] = Item.__slots__ + ('_underlying',)
-
__item_repr_attributes__: Tuple[str, ...] = (
'style',
'url',
@@ -192,19 +190,6 @@ class Button(Item):
else:
self._underlying.emoji = None
- def copy(self: B) -> B:
- button = self.__class__(
- style=self.style,
- label=self.label,
- disabled=self.disabled,
- custom_id=self.custom_id,
- url=self.url,
- emoji=self.emoji,
- group=self.group_id,
- )
- button.callback = self.callback
- return button
-
@classmethod
def from_component(cls: Type[B], button: ButtonComponent) -> B:
return cls(
@@ -239,7 +224,7 @@ def button(
style: ButtonStyle = ButtonStyle.grey,
emoji: Optional[Union[str, PartialEmoji]] = None,
group: Optional[int] = None,
-) -> Callable[[ItemCallbackType], Button]:
+) -> Callable[[ItemCallbackType], ItemCallbackType]:
"""A decorator that attaches a button to a component.
The function being decorated should have three parameters, ``self`` representing
@@ -275,14 +260,22 @@ def button(
ordering.
"""
- def decorator(func: ItemCallbackType) -> Button:
+ def decorator(func: ItemCallbackType) -> ItemCallbackType:
nonlocal custom_id
if not inspect.iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function')
custom_id = custom_id or os.urandom(32).hex()
- button = Button(style=style, custom_id=custom_id, url=None, disabled=disabled, label=label, emoji=emoji, group=group)
- button.callback = func
- return button
+ func.__discord_ui_model_type__ = Button
+ func.__discord_ui_model_kwargs__ = {
+ 'style': style,
+ 'custom_id': custom_id,
+ 'url': None,
+ 'disabled': disabled,
+ 'label': label,
+ 'emoji': emoji,
+ 'group': group,
+ }
+ return func
return decorator
diff --git a/discord/ui/item.py b/discord/ui/item.py
index 7726407e..dc6c91a0 100644
--- a/discord/ui/item.py
+++ b/discord/ui/item.py
@@ -24,8 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
-from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
-import inspect
+from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from ..interactions import Interaction
@@ -50,25 +49,15 @@ class Item:
- :class:`discord.ui.Button`
"""
- __slots__: Tuple[str, ...] = (
- '_callback',
- '_pass_view_arg',
- 'group_id',
- )
-
__item_repr_attributes__: Tuple[str, ...] = ('group_id',)
def __init__(self):
- self._callback: Optional[ItemCallbackType] = None
- self._pass_view_arg = True
+ self._view: Optional[View] = None
self.group_id: Optional[int] = None
def to_component_dict(self) -> Dict[str, Any]:
raise NotImplementedError
- def copy(self: I) -> I:
- raise NotImplementedError
-
def refresh_state(self, component: Component) -> None:
return None
@@ -88,53 +77,20 @@ class Item:
return f'<{self.__class__.__name__} {attrs}>'
@property
- def callback(self) -> Optional[ItemCallbackType]:
- """Returns the underlying callback associated with this interaction."""
- return self._callback
-
- @callback.setter
- def callback(self, value: Optional[ItemCallbackType]):
- if value is None:
- self._callback = None
- return
-
- # Check if it's a partial function
- try:
- partial = value.func
- except AttributeError:
- pass
- else:
- if not inspect.iscoroutinefunction(value.func):
- raise TypeError(f'inner partial function must be a coroutine')
-
- # Check if the partial is bound
- try:
- bound_partial = partial.__self__
- except AttributeError:
- pass
- else:
- self._pass_view_arg = not hasattr(bound_partial, '__discord_ui_view__')
-
- self._callback = value
- return
-
- try:
- func_self = value.__self__
- except AttributeError:
- pass
- else:
- if not isinstance(func_self, Item):
- raise TypeError(f'callback bound method must be from Item not {func_self!r}')
- else:
- value = value.__func__
-
- if not inspect.iscoroutinefunction(value):
- raise TypeError(f'callback must be a coroutine not {value!r}')
-
- self._callback = value
-
- async def _do_call(self, view: View, interaction: Interaction):
- if self._pass_view_arg:
- await self._callback(view, self, interaction)
- else:
- await self._callback(self, interaction) # type: ignore
+ def view(self) -> Optional[View]:
+ """Optional[:class:`View`]: The underlying view for this item."""
+ return self._view
+
+ async def callback(self, interaction: Interaction):
+ """|coro|
+
+ The callback associated with this UI item.
+
+ This can be overriden by subclasses.
+
+ Parameters
+ -----------
+ interaction: :class:`Interaction`
+ The interaction that triggered this UI item.
+ """
+ pass
diff --git a/discord/ui/view.py b/discord/ui/view.py
index 273a45d0..712f787a 100644
--- a/discord/ui/view.py
+++ b/discord/ui/view.py
@@ -31,7 +31,7 @@ import asyncio
import sys
import time
import os
-from .item import Item
+from .item import Item, ItemCallbackType
from ..enums import ComponentType
from ..components import (
Component,
@@ -95,13 +95,13 @@ class View:
__discord_ui_view__: ClassVar[bool] = True
if TYPE_CHECKING:
- __view_children_items__: ClassVar[List[Item]]
+ __view_children_items__: ClassVar[List[ItemCallbackType]]
def __init_subclass__(cls) -> None:
- children: List[Item] = []
+ children: List[ItemCallbackType] = []
for base in reversed(cls.__mro__):
for member in base.__dict__.values():
- if isinstance(member, Item):
+ if hasattr(member, '__discord_ui_model_type__'):
children.append(member)
if len(children) > 25:
@@ -111,7 +111,13 @@ class View:
def __init__(self, timeout: Optional[float] = 180.0):
self.timeout = timeout
- self.children: List[Item] = [i.copy() for i in self.__view_children_items__]
+ self.children: List[Item] = []
+ for func in self.__view_children_items__:
+ item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
+ item.callback = partial(func, self, item)
+ item._view = self
+ self.children.append(item)
+
self.id = os.urandom(16).hex()
self._cancel_callback: Optional[Callable[[View], None]] = None
@@ -171,11 +177,12 @@ class View:
if not isinstance(item, Item):
raise TypeError(f'expected Item not {item.__class__!r}')
+ item._view = self
self.children.append(item)
async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction):
await state.http.create_interaction_response(interaction.id, interaction.token, type=6)
- await item._do_call(self, interaction)
+ await item.callback(interaction)
def dispatch(self, state: Any, item: Item, interaction: Interaction):
asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}')