aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
authorRapptz <[email protected]>2021-04-25 04:57:29 -0400
committerRapptz <[email protected]>2021-05-27 00:53:13 -0400
commit98570793e4af8dcfca25621e86fecd85297b0d59 (patch)
treeefcbe97935fa61126c642f00fdd3ebbc15b61375 /discord
parentFix bug in Embed.__len__ caused by footer without text (diff)
downloaddiscord.py-98570793e4af8dcfca25621e86fecd85297b0d59.tar.xz
discord.py-98570793e4af8dcfca25621e86fecd85297b0d59.zip
Add initial support for buttons and components
Diffstat (limited to 'discord')
-rw-r--r--discord/__init__.py3
-rw-r--r--discord/abc.py23
-rw-r--r--discord/components.py158
-rw-r--r--discord/enums.py28
-rw-r--r--discord/http.py20
-rw-r--r--discord/interactions.py28
-rw-r--r--discord/message.py33
-rw-r--r--discord/state.py13
-rw-r--r--discord/types/components.py52
-rw-r--r--discord/types/interactions.py15
-rw-r--r--discord/types/message.py2
-rw-r--r--discord/ui/__init__.py14
-rw-r--r--discord/ui/button.py288
-rw-r--r--discord/ui/item.py140
-rw-r--r--discord/ui/view.py270
15 files changed, 1075 insertions, 12 deletions
diff --git a/discord/__init__.py b/discord/__init__.py
index dbbb54c9..3b057e36 100644
--- a/discord/__init__.py
+++ b/discord/__init__.py
@@ -43,7 +43,7 @@ from .template import *
from .widget import *
from .object import *
from .reaction import *
-from . import utils, opus, abc
+from . import utils, opus, abc, ui
from .enums import *
from .embeds import *
from .mentions import *
@@ -56,6 +56,7 @@ from .raw_models import *
from .team import *
from .sticker import *
from .interactions import *
+from .components import *
VersionInfo = namedtuple('VersionInfo', 'major minor micro releaselevel serial')
diff --git a/discord/abc.py b/discord/abc.py
index a16afe8e..961545ed 100644
--- a/discord/abc.py
+++ b/discord/abc.py
@@ -1154,7 +1154,7 @@ class Messageable(Protocol):
async def send(self, content=None, *, tts=False, embed=None, file=None,
files=None, delete_after=None, nonce=None,
allowed_mentions=None, reference=None,
- mention_author=None):
+ mention_author=None, view=None):
"""|coro|
Sends a message to the destination with the content given.
@@ -1212,6 +1212,10 @@ class Messageable(Protocol):
If set, overrides the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions``.
.. versionadded:: 1.6
+ view: :class:`discord.ui.View`
+ A Discord UI View to add to the message.
+
+ .. versionadded:: 2.0
Raises
--------
@@ -1255,6 +1259,14 @@ class Messageable(Protocol):
except AttributeError:
raise InvalidArgument('reference parameter must be Message or MessageReference') from None
+ if view:
+ if not hasattr(view, '__discord_ui_view__'):
+ raise InvalidArgument(f'view parameter must be View not {view.__class__!r}')
+
+ components = view.to_components()
+ else:
+ components = None
+
if file is not None and files is not None:
raise InvalidArgument('cannot pass both file and files parameter to send()')
@@ -1265,7 +1277,7 @@ class Messageable(Protocol):
try:
data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions,
content=content, tts=tts, embed=embed, nonce=nonce,
- message_reference=reference)
+ message_reference=reference, components=components)
finally:
file.close()
@@ -1278,16 +1290,19 @@ class Messageable(Protocol):
try:
data = await state.http.send_files(channel.id, files=files, content=content, tts=tts,
embed=embed, nonce=nonce, allowed_mentions=allowed_mentions,
- message_reference=reference)
+ message_reference=reference, components=components)
finally:
for f in files:
f.close()
else:
data = await state.http.send_message(channel.id, content, tts=tts, embed=embed,
nonce=nonce, allowed_mentions=allowed_mentions,
- message_reference=reference)
+ message_reference=reference, components=components)
ret = state.create_message(channel=channel, data=data)
+ if view:
+ state.store_view(view, ret.id)
+
if delete_after is not None:
await ret.delete(delay=delete_after)
return ret
diff --git a/discord/components.py b/discord/components.py
new file mode 100644
index 00000000..714876a7
--- /dev/null
+++ b/discord/components.py
@@ -0,0 +1,158 @@
+"""
+The MIT License (MIT)
+
+Copyright (c) 2015-present Rapptz
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+
+from __future__ import annotations
+
+from typing import List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
+from .enums import try_enum, ComponentType, ButtonStyle
+from .partial_emoji import PartialEmoji
+
+if TYPE_CHECKING:
+ from .types.components import (
+ Component as ComponentPayload,
+ ButtonComponent as ButtonComponentPayload,
+ ComponentContainer as ComponentContainerPayload,
+ )
+
+
+__all__ = (
+ 'Component',
+ 'Button',
+)
+
+C = TypeVar('C', bound='Component')
+
+class Component:
+ """Represents a Discord Bot UI Kit Component.
+
+ Currently, the only components supported by Discord are buttons and button groups.
+
+ .. versionadded:: 2.0
+
+ Attributes
+ ------------
+ type: :class:`ComponentType`
+ The type of component.
+ children: List[:class:`Component`]
+ The children components that this holds, if any.
+ """
+
+ __slots__: Tuple[str, ...] = (
+ 'type',
+ 'children',
+ )
+
+ def __init__(self, data: ComponentPayload):
+ self.type: ComponentType = try_enum(ComponentType, data['type'])
+ self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])]
+
+ def __repr__(self) -> str:
+ attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__slots__)
+ return f'<{self.__class__.__name__} type={self.type!r} {attrs}>'
+
+ def to_dict(self) -> ComponentContainerPayload:
+ return {
+ 'type': int(self.type),
+ 'components': [child.to_dict() for child in self.children],
+ } # type: ignore
+
+
+ @classmethod
+ def _raw_construct(cls: Type[C], **kwargs) -> C:
+ self: C = cls.__new__(cls)
+ slots = cls.__slots__
+ for attr, value in kwargs.items():
+ if attr in slots:
+ setattr(self, attr, value)
+ return self
+
+
+class Button(Component):
+ """Represents a button from the Discord Bot UI Kit.
+
+ This inherits from :class:`Component`.
+
+ .. versionadded:: 2.0
+
+ Attributes
+ -----------
+ style: :class:`ComponentButtonStyle`
+ The style of the button.
+ custom_id: Optional[:class:`str`]
+ The ID of the button that gets received during an interaction.
+ If this button is for a URL, it does not have a custom ID.
+ url: Optional[:class:`str`]
+ The URL this button sends you to.
+ disabled: :class:`bool`
+ Whether the button is disabled or not.
+ label: :class:`str`
+ The label of the button.
+ emoji: Optional[:class:`PartialEmoji`]
+ The emoji of the button, if available.
+ """
+
+ __slots__: Tuple[str, ...] = Component.__slots__ + (
+ 'style',
+ 'custom_id',
+ 'url',
+ 'disabled',
+ 'label',
+ 'emoji',
+ )
+
+ def __init__(self, data: ButtonComponentPayload):
+ self.type: ComponentType = try_enum(ComponentType, data['type'])
+ self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
+ self.custom_id: Optional[str] = data.get('custom_id')
+ self.url: Optional[str] = data.get('url')
+ self.disabled: bool = data.get('disabled', False)
+ self.label: str = data['label']
+ self.emoji: Optional[PartialEmoji]
+ try:
+ self.emoji = PartialEmoji.from_dict(data['emoji'])
+ except KeyError:
+ self.emoji = None
+
+ def to_dict(self) -> ButtonComponentPayload:
+ payload = {
+ 'type': 2,
+ 'style': int(self.style),
+ 'label': self.label,
+ 'disabled': self.disabled,
+ }
+ if self.custom_id:
+ payload['custom_id'] = self.custom_id
+ if self.url:
+ payload['url'] = self.url
+
+ return payload # type: ignore
+
+def _component_factory(data: ComponentPayload) -> Component:
+ component_type = data['type']
+ if component_type == 1:
+ return Component(data)
+ elif component_type == 2:
+ return Button(data) # type: ignore
+ else:
+ return Component(data)
diff --git a/discord/enums.py b/discord/enums.py
index 1fc0b29e..a06e473e 100644
--- a/discord/enums.py
+++ b/discord/enums.py
@@ -48,6 +48,8 @@ __all__ = (
'StickerType',
'InviteTarget',
'VideoQualityMode',
+ 'ComponentType',
+ 'ButtonStyle',
)
def _create_value_cls(name):
@@ -435,6 +437,15 @@ class InviteTarget(Enum):
class InteractionType(Enum):
ping = 1
application_command = 2
+ component = 3
+
+class InteractionResponseType(Enum):
+ pong = 1
+ # ack = 2 (deprecated)
+ # channel_message = 3 (deprecated)
+ channel_message = 4 # (with source)
+ deferred_channel_message = 5 # (with source)
+ ack = 6 # for components?
class VideoQualityMode(Enum):
auto = 1
@@ -443,6 +454,23 @@ class VideoQualityMode(Enum):
def __int__(self):
return self.value
+class ComponentType(Enum):
+ group = 1
+ button = 2
+
+ def __int__(self):
+ return self.value
+
+class ButtonStyle(Enum):
+ blurple = 1
+ grey = 2
+ green = 3
+ red = 4
+ hyperlink = 5
+
+ def __int__(self):
+ return self.value
+
T = TypeVar('T')
def create_unknown_value(cls: Type[T], val: Any) -> T:
diff --git a/discord/http.py b/discord/http.py
index 1b93cc35..67529495 100644
--- a/discord/http.py
+++ b/discord/http.py
@@ -354,6 +354,7 @@ class HTTPClient:
nonce=None,
allowed_mentions=None,
message_reference=None,
+ components=None,
):
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
payload = {}
@@ -376,6 +377,9 @@ class HTTPClient:
if message_reference:
payload['message_reference'] = message_reference
+ if components:
+ payload['components'] = components
+
return self.request(r, json=payload)
def send_typing(self, channel_id):
@@ -393,6 +397,7 @@ class HTTPClient:
nonce=None,
allowed_mentions=None,
message_reference=None,
+ components=None,
):
form = []
@@ -409,6 +414,8 @@ class HTTPClient:
payload['allowed_mentions'] = allowed_mentions
if message_reference:
payload['message_reference'] = message_reference
+ if components:
+ payload['components'] = components
form.append({'name': 'payload_json', 'value': utils.to_json(payload)})
if len(files) == 1:
@@ -445,6 +452,7 @@ class HTTPClient:
nonce=None,
allowed_mentions=None,
message_reference=None,
+ components=None,
):
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
return self.send_multipart_helper(
@@ -456,6 +464,7 @@ class HTTPClient:
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=message_reference,
+ components=components,
)
def delete_message(self, channel_id, message_id, *, reason=None):
@@ -1210,14 +1219,21 @@ class HTTPClient:
return self.request(route, form=form, files=[file])
- def create_interaction_response(self, interaction_id, token):
+ def create_interaction_response(self, interaction_id, token, *, type, data=None):
r = Route(
'POST',
'/interactions/{interaction_id}/{interaction_token}/callback',
interaction_id=interaction_id,
interaction_token=token,
)
- return self.request(r)
+ payload = {
+ 'type': type,
+ }
+
+ if data is not None:
+ payload['data'] = data
+
+ return self.request(r, json=payload)
def get_original_interaction_response(
self,
diff --git a/discord/interactions.py b/discord/interactions.py
index ccdac792..b38d9a49 100644
--- a/discord/interactions.py
+++ b/discord/interactions.py
@@ -30,6 +30,11 @@ from typing import Optional, TYPE_CHECKING
from . import utils
from .enums import try_enum, InteractionType
+from .user import User
+from .member import Member
+from .message import Message
+from .object import Object
+
__all__ = (
'Interaction',
)
@@ -65,6 +70,8 @@ class Interaction:
The application ID that the interaction was for.
user: Optional[Union[:class:`User`, :class:`Member`]]
The user or member that sent the interaction.
+ message: Optional[:class:`Message`]
+ The message that sent this interaction.
token: :class:`str`
The token to continue the interaction. These are valid
for 15 minutes.
@@ -77,6 +84,7 @@ class Interaction:
'channel_id',
'data',
'application_id',
+ 'message',
'user',
'token',
'version',
@@ -97,10 +105,28 @@ class Interaction:
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.application_id = utils._get_as_snowflake(data, 'application_id')
+ channel = self.channel or Object(id=self.channel_id)
+ try:
+ self.message = Message(state=self._state, channel=channel, data=data['message'])
+ except KeyError:
+ self.message = None
+
+ try:
+ self.user = User(state=self._state, data=data['user'])
+ except KeyError:
+ self.user = None
+
+ # TODO: there's a potential data loss here
+ guild = self.guild or Object(id=self.guild_id)
+ try:
+ self.user = Member(state=self._state, guild=guild, data=data['member'])
+ except KeyError:
+ pass
+
@property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild the interaction was sent from."""
- return self._state and self._state.get_guild(self.guild_id)
+ return self._state and self._state._get_guild(self.guild_id)
@property
def channel(self) -> Optional[GuildChannel]:
diff --git a/discord/message.py b/discord/message.py
index a1191d68..89e5dabe 100644
--- a/discord/message.py
+++ b/discord/message.py
@@ -37,6 +37,7 @@ from .emoji import Emoji
from .partial_emoji import PartialEmoji
from .enums import MessageType, ChannelType, try_enum
from .errors import InvalidArgument, HTTPException
+from .components import _component_factory
from .embeds import Embed
from .member import Member
from .flags import MessageFlags
@@ -56,6 +57,8 @@ if TYPE_CHECKING:
Reaction as ReactionPayload,
)
+ from .types.components import Component as ComponentPayload
+
from .types.member import Member as MemberPayload
from .types.user import User as UserPayload
from .types.embed import Embed as EmbedPayload
@@ -581,6 +584,10 @@ class Message(Hashable):
A list of stickers given to the message.
.. versionadded:: 1.6
+ components: List[:class:`Component`]
+ A list of components in the message.
+
+ .. versionadded:: 2.0
"""
__slots__ = (
@@ -613,6 +620,7 @@ class Message(Hashable):
'application',
'activity',
'stickers',
+ 'components',
)
if TYPE_CHECKING:
@@ -643,7 +651,8 @@ class Message(Hashable):
self.tts = data['tts']
self.content = data['content']
self.nonce = data.get('nonce')
- self.stickers = [Sticker(data=data, state=state) for data in data.get('stickers', [])]
+ self.stickers = [Sticker(data=d, state=state) for d in data.get('stickers', [])]
+ self.components = [_component_factory(d) for d in data.get('components', [])]
try:
ref = data['message_reference']
@@ -837,6 +846,9 @@ class Message(Hashable):
if role is not None:
self.role_mentions.append(role)
+ def _handle_components(self, components: List[ComponentPayload]):
+ self.components = [_component_factory(d) for d in components]
+
def _rebind_channel_reference(self, new_channel: Union[TextChannel, DMChannel, GroupChannel]) -> None:
self.channel = new_channel
@@ -1134,6 +1146,11 @@ class Message(Hashable):
are used instead.
.. versionadded:: 1.4
+ view: Optional[:class:`~discord.ui.View`]
+ The updated view to update this message with. If ``None`` is passed then
+ the view is removed.
+
+ .. versionadded:: 2.0
Raises
-------
@@ -1191,10 +1208,24 @@ class Message(Hashable):
else:
fields['attachments'] = [a.to_dict() for a in attachments]
+ try:
+ view = fields.pop('view')
+ except KeyError:
+ # To check for the view afterwards
+ view = None
+ else:
+ if view:
+ fields['components'] = view.to_components()
+ else:
+ fields['components'] = []
+
if fields:
data = await self._state.http.edit_message(self.channel.id, self.id, **fields)
self._update(data)
+ if view:
+ self._state.store_view(view, self.id)
+
if delete_after is not None:
await self.delete(delay=delete_after)
diff --git a/discord/state.py b/discord/state.py
index dd09634d..e4ea28ef 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -52,6 +52,7 @@ from .flags import ApplicationFlags, Intents, MemberCacheFlags
from .object import Object
from .invite import Invite
from .interactions import Interaction
+from .ui.view import ViewStore
class ChunkRequest:
def __init__(self, guild_id, loop, resolver, *, cache=True):
@@ -187,6 +188,7 @@ class ConnectionState:
self._users = weakref.WeakValueDictionary()
self._emojis = {}
self._guilds = {}
+ self._view_store = ViewStore(self)
self._voice_clients = {}
# LRU of max size 128
@@ -278,6 +280,9 @@ class ConnectionState:
self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data)
return emoji
+ def store_view(self, view, message_id=None):
+ self._view_store.add_view(view, message_id)
+
@property
def guilds(self):
return list(self._guilds.values())
@@ -509,6 +514,9 @@ class ConnectionState:
else:
self.dispatch('raw_message_edit', raw)
+ if 'components' in data and self._view_store.is_message_tracked(raw.message_id):
+ self._view_store.update_view(raw.message_id, data['components'])
+
def parse_message_reaction_add(self, data):
emoji = data['emoji']
emoji_id = utils._get_as_snowflake(emoji, 'id')
@@ -581,6 +589,11 @@ class ConnectionState:
def parse_interaction_create(self, data):
interaction = Interaction(data=data, state=self)
+ if data['type'] == 3: # interaction component
+ custom_id = interaction.data['custom_id'] # type: ignore
+ component_type = interaction.data['component_type'] # type: ignore
+ self._view_store.dispatch(component_type, custom_id, interaction)
+
self.dispatch('interaction', interaction)
def parse_presence_update(self, data):
diff --git a/discord/types/components.py b/discord/types/components.py
new file mode 100644
index 00000000..d652c711
--- /dev/null
+++ b/discord/types/components.py
@@ -0,0 +1,52 @@
+"""
+The MIT License (MIT)
+
+Copyright (c) 2015-present Rapptz
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+
+from __future__ import annotations
+
+from typing import Literal, TypedDict, Union
+from .emoji import PartialEmoji
+
+ComponentType = Literal[1, 2]
+ButtonStyle = Literal[1, 2, 3, 4, 5]
+
+
+class ComponentContainer(TypedDict):
+ type: Literal[1]
+ components: Component
+
+
+class _ButtonComponentOptional(TypedDict, total=False):
+ custom_id: str
+ url: str
+ disabled: bool
+ emoji: PartialEmoji
+
+
+class ButtonComponent(_ButtonComponentOptional):
+ type: Literal[2]
+ style: ButtonStyle
+ label: str
+
+
+Component = Union[ComponentContainer, ButtonComponent]
diff --git a/discord/types/interactions.py b/discord/types/interactions.py
index 07cb1932..dabc77c9 100644
--- a/discord/types/interactions.py
+++ b/discord/types/interactions.py
@@ -24,15 +24,18 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
-from typing import Dict, TypedDict, Union, List, Literal
+from typing import TYPE_CHECKING, Dict, TypedDict, Union, List, Literal
from .snowflake import Snowflake
-from .message import AllowedMentions
+from .components import ComponentType
from .channel import PartialChannel
from .embed import Embed
from .member import Member
from .role import Role
from .user import User
+if TYPE_CHECKING:
+ from .message import AllowedMentions, Message
+
class _ApplicationCommandOptional(TypedDict, total=False):
options: List[ApplicationCommandOption]
@@ -114,12 +117,18 @@ class ApplicationCommandInteractionData(_ApplicationCommandInteractionDataOption
name: str
+class ComponentInteractionData(TypedDict):
+ custom_id: str
+ component_type: ComponentType
+
+
class _InteractionOptional(TypedDict, total=False):
- data: ApplicationCommandInteractionData
+ data: Union[ApplicationCommandInteractionData, ComponentInteractionData]
guild_id: Snowflake
channel_id: Snowflake
member: Member
user: User
+ message: Message
class Interaction(_InteractionOptional):
diff --git a/discord/types/message.py b/discord/types/message.py
index 1aa8259b..47c080ff 100644
--- a/discord/types/message.py
+++ b/discord/types/message.py
@@ -31,6 +31,7 @@ from .user import User
from .emoji import PartialEmoji
from .embed import Embed
from .channel import ChannelType
+from .components import Component
from .interactions import MessageInteraction
@@ -119,6 +120,7 @@ class _MessageOptional(TypedDict, total=False):
stickers: List[Sticker]
referenced_message: Optional[Message]
interaction: MessageInteraction
+ components: List[Component]
MessageType = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 19, 20]
diff --git a/discord/ui/__init__.py b/discord/ui/__init__.py
new file mode 100644
index 00000000..9aa9bea5
--- /dev/null
+++ b/discord/ui/__init__.py
@@ -0,0 +1,14 @@
+"""
+discord.ui
+~~~~~~~~~~~
+
+Bot UI Kit helper for the Discord API
+
+:copyright: (c) 2015-present Rapptz
+:license: MIT, see LICENSE for more details.
+
+"""
+
+from .view import *
+from .item import *
+from .button import *
diff --git a/discord/ui/button.py b/discord/ui/button.py
new file mode 100644
index 00000000..afc69f7a
--- /dev/null
+++ b/discord/ui/button.py
@@ -0,0 +1,288 @@
+"""
+The MIT License (MIT)
+
+Copyright (c) 2015-present Rapptz
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+
+from __future__ import annotations
+
+from typing import Callable, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
+import inspect
+import re
+import os
+
+
+from .item import Item, ItemCallbackType
+from ..enums import ButtonStyle, ComponentType
+from ..partial_emoji import PartialEmoji
+from ..components import Button as ButtonComponent
+
+__all__ = (
+ 'Button',
+ 'button',
+)
+
+if TYPE_CHECKING:
+ from ..components import Component
+
+_custom_emoji = re.compile(r'<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?')
+
+
+def _to_partial_emoji(obj: Union[str, PartialEmoji], *, _custom_emoji=_custom_emoji) -> PartialEmoji:
+ if isinstance(obj, PartialEmoji):
+ return obj
+
+ obj = str(obj)
+ match = _custom_emoji.match(obj)
+ if match is not None:
+ groups = match.groupdict()
+ animated = bool(groups['animated'])
+ emoji_id = int(groups['id'])
+ name = groups['name']
+ return PartialEmoji(name=name, animated=animated, id=emoji_id)
+
+ return PartialEmoji(name=obj, id=None, animated=False)
+
+
+B = TypeVar('B', bound='Button')
+
+
+class Button(Item):
+ """Represents a UI button.
+
+ .. versionadded:: 2.0
+
+ Parameters
+ ------------
+ style: :class:`discord.ButtonStyle`
+ The style of the button.
+ custom_id: Optional[:class:`str`]
+ The ID of the button that gets received during an interaction.
+ If this button is for a URL, it does not have a custom ID.
+ url: Optional[:class:`str`]
+ The URL this button sends you to.
+ disabled: :class:`bool`
+ Whether the button is disabled or not.
+ label: :class:`str`
+ The label of the button.
+ emoji: Optional[:class:`PartialEmoji`]
+ The emoji of the button, if available.
+ """
+
+ __slots__: Tuple[str, ...] = Item.__slots__ + ('_underlying',)
+
+ __item_repr_attributes__: Tuple[str, ...] = (
+ 'style',
+ 'url',
+ 'disabled',
+ 'label',
+ 'emoji',
+ 'group_id',
+ )
+
+ def __init__(
+ self,
+ *,
+ style: ButtonStyle,
+ label: str,
+ disabled: bool = False,
+ custom_id: Optional[str] = None,
+ url: Optional[str] = None,
+ emoji: Optional[Union[str, PartialEmoji]] = None,
+ group: Optional[int] = None,
+ ):
+ super().__init__()
+ if custom_id is not None and url is not None:
+ raise TypeError('cannot mix both url and custom_id with Button')
+
+ if url is None and custom_id is None:
+ custom_id = os.urandom(16).hex()
+
+ self._underlying = ButtonComponent._raw_construct(
+ type=ComponentType.button,
+ custom_id=custom_id,
+ url=url,
+ disabled=disabled,
+ label=label,
+ style=style,
+ emoji=None if emoji is None else _to_partial_emoji(emoji),
+ )
+ self.group_id = group
+
+ @property
+ def style(self) -> ButtonStyle:
+ """:class:`discord.ButtonStyle`: The style of the button."""
+ return self._underlying.style
+
+ @style.setter
+ def style(self, value: ButtonStyle):
+ self._underlying.style = value
+
+ @property
+ def custom_id(self) -> Optional[str]:
+ """Optional[:class:`str`]: The ID of the button that gets received during an interaction.
+
+ If this button is for a URL, it does not have a custom ID.
+ """
+ return self._underlying.custom_id
+
+ @custom_id.setter
+ def custom_id(self, value: Optional[str]):
+ if value is not None and not isinstance(value, str):
+ raise TypeError('custom_id must be None or str')
+
+ self._underlying.custom_id = value
+
+ @property
+ def url(self) -> Optional[str]:
+ """Optional[:class:`str`]: The URL this button sends you to."""
+ return self._underlying.url
+
+ @url.setter
+ def url(self, value: Optional[str]):
+ if value is not None and not isinstance(value, str):
+ raise TypeError('url must be None or str')
+ self._underlying.url = value
+
+ @property
+ def disabled(self) -> bool:
+ """:class:`bool`: Whether the button is disabled or not."""
+ return self._underlying.disabled
+
+ @disabled.setter
+ def disabled(self, value: bool):
+ self._underlying.disabled = bool(value)
+
+ @property
+ def label(self) -> str:
+ """:class:`str`: The label of the button."""
+ return self._underlying.label
+
+ @label.setter
+ def label(self, value: str):
+ self._underlying.label = str(value)
+
+ @property
+ def emoji(self) -> Optional[PartialEmoji]:
+ """Optional[:class:`PartialEmoji`]: The emoji of the button, if available."""
+ return self._underlying.emoji
+
+ @emoji.setter
+ def emoji(self, value: Optional[Union[str, PartialEmoji]]): # type: ignore
+ if value is not None:
+ self._underlying.emoji = _to_partial_emoji(value)
+ else:
+ self._underlying.emoji = None
+
+ def copy(self: B) -> B:
+ button = self.__class__(
+ style=self.style,
+ label=self.label,
+ disabled=self.disabled,
+ custom_id=self.custom_id,
+ url=self.url,
+ emoji=self.emoji,
+ group=self.group_id,
+ )
+ button.callback = self.callback
+ return button
+
+ @classmethod
+ def from_component(cls: Type[B], button: ButtonComponent) -> B:
+ return cls(
+ style=button.style,
+ label=button.label,
+ disabled=button.disabled,
+ custom_id=button.custom_id,
+ url=button.url,
+ emoji=button.emoji,
+ group=None,
+ )
+
+ @property
+ def type(self) -> ComponentType:
+ return self._underlying.type
+
+ def to_component_dict(self):
+ return self._underlying.to_dict()
+
+ def is_dispatchable(self) -> bool:
+ return True
+
+ def refresh_state(self, button: ButtonComponent) -> None:
+ self._underlying = button
+
+
+def button(
+ label: str,
+ *,
+ custom_id: Optional[str] = None,
+ disabled: bool = False,
+ style: ButtonStyle = ButtonStyle.grey,
+ emoji: Optional[Union[str, PartialEmoji]] = None,
+ group: Optional[int] = None,
+) -> Callable[[ItemCallbackType], Button]:
+ """A decorator that attaches a button to a component.
+
+ The function being decorated should have three parameters, ``self`` representing
+ the :class:`discord.ui.View`, the :class:`discord.ui.Button` being pressed and
+ the :class:`discord.Interaction` you receive.
+
+ .. note::
+
+ Buttons with a URL cannot be created with this function.
+ Consider creating a :class:`Button` manually instead.
+ This is because buttons with a URL do not have a callback
+ associated with them since Discord does not do any processing
+ with it.
+
+ Parameters
+ ------------
+ label: :class:`str`
+ The label of the button.
+ custom_id: Optional[:class:`str`]
+ The ID of the button that gets received during an interaction.
+ It is recommended not to set this parameter to prevent conflicts.
+ style: :class:`ButtonStyle`
+ The style of the button. Defaults to :attr:`ButtonStyle.grey`.
+ disabled: :class:`bool`
+ Whether the button is disabled or not. Defaults to ``False``.
+ emoji: Optional[Union[:class:`str`, :class:`PartialEmoji`]]
+ The emoji of the button. This can be in string form or a :class:`PartialEmoji`.
+ group: Optional[:class:`int`]
+ The relative group this button belongs to. A Discord component can only have 5
+ groups. By default, items are arranged automatically into those 5 groups. If you'd
+ like to control the relative positioning of the group then passing an index is advised.
+ For example, group=1 will show up before group=2. Defaults to ``None``, which is automatic
+ ordering.
+ """
+
+ def decorator(func: ItemCallbackType) -> Button:
+ nonlocal custom_id
+ if not inspect.iscoroutinefunction(func):
+ raise TypeError('button function must be a coroutine function')
+
+ custom_id = custom_id or os.urandom(32).hex()
+ button = Button(style=style, custom_id=custom_id, url=None, disabled=disabled, label=label, emoji=emoji, group=group)
+ button.callback = func
+ return button
+
+ return decorator
diff --git a/discord/ui/item.py b/discord/ui/item.py
new file mode 100644
index 00000000..7726407e
--- /dev/null
+++ b/discord/ui/item.py
@@ -0,0 +1,140 @@
+"""
+The MIT License (MIT)
+
+Copyright (c) 2015-present Rapptz
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable, Coroutine, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
+import inspect
+
+from ..interactions import Interaction
+
+__all__ = (
+ 'Item',
+)
+
+if TYPE_CHECKING:
+ from ..enums import ComponentType
+ from .view import View
+ from ..components import Component
+
+I = TypeVar('I', bound='Item')
+ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]]
+
+
+class Item:
+ """Represents the base UI item that all UI components inherit from.
+
+ The current UI items supported are:
+
+ - :class:`discord.ui.Button`
+ """
+
+ __slots__: Tuple[str, ...] = (
+ '_callback',
+ '_pass_view_arg',
+ 'group_id',
+ )
+
+ __item_repr_attributes__: Tuple[str, ...] = ('group_id',)
+
+ def __init__(self):
+ self._callback: Optional[ItemCallbackType] = None
+ self._pass_view_arg = True
+ self.group_id: Optional[int] = None
+
+ def to_component_dict(self) -> Dict[str, Any]:
+ raise NotImplementedError
+
+ def copy(self: I) -> I:
+ raise NotImplementedError
+
+ def refresh_state(self, component: Component) -> None:
+ return None
+
+ @classmethod
+ def from_component(cls: Type[I], component: Component) -> I:
+ return cls()
+
+ @property
+ def type(self) -> ComponentType:
+ raise NotImplementedError
+
+ def is_dispatchable(self) -> bool:
+ return False
+
+ def __repr__(self) -> str:
+ attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__item_repr_attributes__)
+ return f'<{self.__class__.__name__} {attrs}>'
+
+ @property
+ def callback(self) -> Optional[ItemCallbackType]:
+ """Returns the underlying callback associated with this interaction."""
+ return self._callback
+
+ @callback.setter
+ def callback(self, value: Optional[ItemCallbackType]):
+ if value is None:
+ self._callback = None
+ return
+
+ # Check if it's a partial function
+ try:
+ partial = value.func
+ except AttributeError:
+ pass
+ else:
+ if not inspect.iscoroutinefunction(value.func):
+ raise TypeError(f'inner partial function must be a coroutine')
+
+ # Check if the partial is bound
+ try:
+ bound_partial = partial.__self__
+ except AttributeError:
+ pass
+ else:
+ self._pass_view_arg = not hasattr(bound_partial, '__discord_ui_view__')
+
+ self._callback = value
+ return
+
+ try:
+ func_self = value.__self__
+ except AttributeError:
+ pass
+ else:
+ if not isinstance(func_self, Item):
+ raise TypeError(f'callback bound method must be from Item not {func_self!r}')
+ else:
+ value = value.__func__
+
+ if not inspect.iscoroutinefunction(value):
+ raise TypeError(f'callback must be a coroutine not {value!r}')
+
+ self._callback = value
+
+ async def _do_call(self, view: View, interaction: Interaction):
+ if self._pass_view_arg:
+ await self._callback(view, self, interaction)
+ else:
+ await self._callback(self, interaction) # type: ignore
diff --git a/discord/ui/view.py b/discord/ui/view.py
new file mode 100644
index 00000000..273a45d0
--- /dev/null
+++ b/discord/ui/view.py
@@ -0,0 +1,270 @@
+"""
+The MIT License (MIT)
+
+Copyright (c) 2015-present Rapptz
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+
+from __future__ import annotations
+from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple
+from functools import partial
+from itertools import groupby
+
+import asyncio
+import sys
+import time
+import os
+from .item import Item
+from ..enums import ComponentType
+from ..components import (
+ Component,
+ _component_factory,
+ Button as ButtonComponent,
+)
+
+__all__ = (
+ 'View',
+)
+
+
+if TYPE_CHECKING:
+ from ..interactions import Interaction
+ from ..types.components import Component as ComponentPayload
+
+
+def _walk_all_components(components: List[Component]) -> Iterator[Component]:
+ for item in components:
+ if item.type is ComponentType.group:
+ yield from item.children
+ else:
+ yield item
+
+
+def _component_to_item(component: Component) -> Item:
+ if isinstance(component, ButtonComponent):
+ from .button import Button
+
+ return Button.from_component(component)
+ return Item.from_component(component)
+
+
+class View:
+ """Represents a UI view.
+
+ This object must be inherited to create a UI within Discord.
+
+ Parameters
+ -----------
+ timeout: Optional[:class:`float`]
+ Timeout from last interaction with the UI before no longer accepting input.
+ If ``None`` then there is no timeout.
+
+ Attributes
+ ------------
+ timeout: Optional[:class:`float`]
+ Timeout from last interaction with the UI before no longer accepting input.
+ If ``None`` then there is no timeout.
+ children: List[:class:`Item`]
+ The list of children attached to this view.
+ """
+
+ __slots__ = (
+ 'timeout',
+ 'children',
+ 'id',
+ '_cancel_callback',
+ )
+
+ __discord_ui_view__: ClassVar[bool] = True
+
+ if TYPE_CHECKING:
+ __view_children_items__: ClassVar[List[Item]]
+
+ def __init_subclass__(cls) -> None:
+ children: List[Item] = []
+ for base in reversed(cls.__mro__):
+ for member in base.__dict__.values():
+ if isinstance(member, Item):
+ children.append(member)
+
+ if len(children) > 25:
+ raise TypeError('View cannot have more than 25 children')
+
+ cls.__view_children_items__ = children
+
+ def __init__(self, timeout: Optional[float] = 180.0):
+ self.timeout = timeout
+ self.children: List[Item] = [i.copy() for i in self.__view_children_items__]
+ self.id = os.urandom(16).hex()
+ self._cancel_callback: Optional[Callable[[View], None]] = None
+
+ def to_components(self) -> List[Dict[str, Any]]:
+ def key(item: Item) -> int:
+ if item.group_id is None:
+ return sys.maxsize
+ return item.group_id
+
+ children = sorted(self.children, key=key)
+ components: List[Dict[str, Any]] = []
+ for _, group in groupby(children, key=key):
+ group = list(group)
+ if len(group) <= 5:
+ components.append(
+ {
+ 'type': 1,
+ 'components': [item.to_component_dict() for item in group],
+ }
+ )
+ else:
+ components.extend(
+ {
+ 'type': 1,
+ 'components': [item.to_component_dict() for item in group[index : index + 5]],
+ }
+ for index in range(0, len(group), 5)
+ )
+
+ return components
+
+ @property
+ def _expires_at(self) -> Optional[float]:
+ if self.timeout:
+ return time.monotonic() + self.timeout
+ return None
+
+ def add_item(self, item: Item) -> None:
+ """Adds an item to the view.
+
+ Parameters
+ -----------
+ item: :class:`Item`
+ The item to add to the view.
+
+ Raises
+ --------
+ TypeError
+ A :class:`Item` was not passed.
+ ValueError
+ Maximum number of children has been exceeded (25).
+ """
+
+ if len(self.children) > 25:
+ raise ValueError('maximum number of children exceeded')
+
+ if not isinstance(item, Item):
+ raise TypeError(f'expected Item not {item.__class__!r}')
+
+ self.children.append(item)
+
+ async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction):
+ await state.http.create_interaction_response(interaction.id, interaction.token, type=6)
+ await item._do_call(self, interaction)
+
+ def dispatch(self, state: Any, item: Item, interaction: Interaction):
+ asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
+
+ def refresh(self, components: List[Component]):
+ # This is pretty hacky at the moment
+ # fmt: off
+ old_state: Dict[Tuple[int, str], Item] = {
+ (item.type.value, item.custom_id): item # type: ignore
+ for item in self.children
+ if item.is_dispatchable()
+ }
+ # fmt: on
+ children: List[Item] = []
+ for component in _walk_all_components(components):
+ try:
+ older = old_state[(component.type.value, component.custom_id)] # type: ignore
+ except (KeyError, AttributeError):
+ children.append(_component_to_item(component))
+ else:
+ older.refresh_state(component)
+ children.append(older)
+
+ self.children = children
+
+ def stop(self) -> None:
+ """Stops listening to interaction events from this view.
+
+ This operation cannot be undone.
+ """
+ if self._cancel_callback:
+ self._cancel_callback(self)
+
+
+class ViewStore:
+ def __init__(self, state):
+ # (component_type, custom_id): (View, Item, Expiry)
+ self._views: Dict[Tuple[int, str], Tuple[View, Item, Optional[float]]] = {}
+ # message_id: View
+ self._synced_message_views: Dict[int, View] = {}
+ self._state = state
+
+ def __verify_integrity(self):
+ to_remove: List[Tuple[int, str]] = []
+ now = time.monotonic()
+ for (k, (_, _, expiry)) in self._views.items():
+ if expiry is not None and now >= expiry:
+ to_remove.append(k)
+
+ for k in to_remove:
+ del self._views[k]
+
+ def add_view(self, view: View, message_id: Optional[int] = None):
+ self.__verify_integrity()
+
+ expiry = view._expires_at
+ view._cancel_callback = partial(self.remove_view)
+ for item in view.children:
+ if item.is_dispatchable():
+ self._views[(item.type.value, item.custom_id)] = (view, item, expiry) # type: ignore
+
+ if message_id is not None:
+ self._synced_message_views[message_id] = view
+
+ def remove_view(self, view: View):
+ for item in view.children:
+ if item.is_dispatchable():
+ self._views.pop((item.type.value, item.custom_id)) # type: ignore
+
+ for key, value in self._synced_message_views.items():
+ if value.id == view.id:
+ del self._synced_message_views[key]
+ break
+
+ def dispatch(self, component_type: int, custom_id: str, interaction: Interaction):
+ self.__verify_integrity()
+ key = (component_type, custom_id)
+ value = self._views.get(key)
+ if value is None:
+ return
+
+ view, item, _ = value
+ self._views[key] = (view, item, view._expires_at)
+ view.dispatch(self._state, item, interaction)
+
+ def is_message_tracked(self, message_id: int):
+ return message_id in self._synced_message_views
+
+ def update_view(self, message_id: int, components: List[ComponentPayload]):
+ # pre-req: is_message_tracked == true
+ view = self._synced_message_views[message_id]
+ view.refresh([_component_factory(d) for d in components])