aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2017-01-25 22:26:49 -0500
committerRapptz <[email protected]>2017-01-25 22:26:49 -0500
commite5cb7d295c9c8ea5ca52308b4286452a64729b83 (patch)
treef582bc40d0a12516f553f5ed8fc033218e54d225
parentAdd compatibility shim for asyncio.Future creation. (diff)
downloaddiscord.py-e5cb7d295c9c8ea5ca52308b4286452a64729b83.tar.xz
discord.py-e5cb7d295c9c8ea5ca52308b4286452a64729b83.zip
Replace wait_for_* with a generic Client.wait_for
-rw-r--r--discord/client.py372
1 files changed, 84 insertions, 288 deletions
diff --git a/discord/client.py b/discord/client.py
index ec05cb20..76509016 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -41,7 +41,7 @@ import aiohttp
import websockets
import logging, traceback
-import sys, re, io, enum
+import sys, re, io
import itertools
import datetime
from collections import namedtuple
@@ -51,7 +51,6 @@ PY35 = sys.version_info >= (3, 5)
log = logging.getLogger(__name__)
AppInfo = namedtuple('AppInfo', 'id name description icon owner')
-WaitedReaction = namedtuple('WaitedReaction', 'reaction user')
def app_info_icon_url(self):
"""Retrieves the application's icon_url if it exists. Empty string otherwise."""
@@ -62,10 +61,6 @@ def app_info_icon_url(self):
AppInfo.icon_url = property(app_info_icon_url)
-class WaitForType(enum.Enum):
- message = 0
- reaction = 1
-
class Client:
"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API.
@@ -113,7 +108,7 @@ class Client:
self.ws = None
self.email = None
self.loop = asyncio.get_event_loop() if loop is None else loop
- self._listeners = []
+ self._listeners = {}
self.shard_id = options.get('shard_id')
self.shard_count = options.get('shard_count')
@@ -125,8 +120,6 @@ class Client:
self.connection.shard_count = self.shard_count
self._closed = asyncio.Event(loop=self.loop)
- self._is_logged_in = asyncio.Event(loop=self.loop)
- self._is_ready = asyncio.Event(loop=self.loop)
# if VoiceClient.warn_nacl:
# VoiceClient.warn_nacl = False
@@ -156,57 +149,6 @@ class Client:
yield from self.ws.send_as_json(payload)
- def handle_reaction_add(self, reaction, user):
- removed = []
- for i, (condition, future, event_type) in enumerate(self._listeners):
- if event_type is not WaitForType.reaction:
- continue
-
- if future.cancelled():
- removed.append(i)
- continue
-
- try:
- result = condition(reaction, user)
- except Exception as e:
- future.set_exception(e)
- removed.append(i)
- else:
- if result:
- future.set_result(WaitedReaction(reaction, user))
- removed.append(i)
-
-
- for idx in reversed(removed):
- del self._listeners[idx]
-
- def handle_message(self, message):
- removed = []
- for i, (condition, future, event_type) in enumerate(self._listeners):
- if event_type is not WaitForType.message:
- continue
-
- if future.cancelled():
- removed.append(i)
- continue
-
- try:
- result = condition(message)
- except Exception as e:
- future.set_exception(e)
- removed.append(i)
- else:
- if result:
- future.set_result(message)
- removed.append(i)
-
-
- for idx in reversed(removed):
- del self._listeners[idx]
-
- def handle_ready(self):
- self._is_ready.set()
-
def _resolve_invite(self, invite):
if isinstance(invite, Invite) or isinstance(invite, Object):
return invite.id
@@ -264,6 +206,35 @@ class Client:
method = 'on_' + event
handler = 'handle_' + event
+ listeners = self._listeners.get(event)
+ if listeners:
+ removed = []
+ for i, (future, condition) in enumerate(listeners):
+ if future.cancelled():
+ removed.append(i)
+ continue
+
+ try:
+ result = condition(*args)
+ except Exception as e:
+ future.set_exception(e)
+ removed.append(i)
+ else:
+ if result:
+ if len(args) == 0:
+ future.set_result(None)
+ elif len(args) == 1:
+ future.set_result(args[0])
+ else:
+ future.set_result(args)
+ removed.append(i)
+
+ if len(removed) == len(listeners):
+ self._listeners.pop(event)
+ else:
+ for idx in reversed(removed):
+ del listeners[idx]
+
try:
actual_handler = getattr(self, handler)
except AttributeError:
@@ -353,7 +324,6 @@ class Client:
data = yield from self.http.static_login(token, bot=bot)
self.email = data.get('email', None)
self.connection.is_bot = bot
- self._is_logged_in.set()
@asyncio.coroutine
def logout(self):
@@ -362,7 +332,6 @@ class Client:
Logs out of Discord and closes all connections.
"""
yield from self.close()
- self._is_logged_in.clear()
@asyncio.coroutine
def connect(self):
@@ -420,7 +389,6 @@ class Client:
yield from self.http.close()
self._closed.set()
- self._is_ready.clear()
@asyncio.coroutine
def start(self, *args, **kwargs):
@@ -474,12 +442,7 @@ class Client:
finally:
self.loop.close()
- # properties
-
- @property
- def is_logged_in(self):
- """bool: Indicates if the client has logged in successfully."""
- return self._is_logged_in.is_set()
+ # properties
@property
def is_closed(self):
@@ -550,250 +513,83 @@ class Client:
# listeners/waiters
- @asyncio.coroutine
- def wait_until_ready(self):
- """|coro|
-
- This coroutine waits until the client is all ready. This could be considered
- another way of asking for :func:`discord.on_ready` except meant for your own
- background tasks.
- """
- yield from self._is_ready.wait()
-
- @asyncio.coroutine
- def wait_until_login(self):
+ def wait_for(self, event, *, check=None, timeout=None):
"""|coro|
- This coroutine waits until the client is logged on successfully. This
- is different from waiting until the client's state is all ready. For
- that check :func:`discord.on_ready` and :meth:`wait_until_ready`.
- """
- yield from self._is_logged_in.wait()
-
- @asyncio.coroutine
- def wait_for_message(self, timeout=None, *, author=None, channel=None, content=None, check=None):
- """|coro|
-
- Waits for a message reply from Discord. This could be seen as another
- :func:`discord.on_message` event outside of the actual event. This could
- also be used for follow-ups and easier user interactions.
-
- The keyword arguments passed into this function are combined using the logical and
- operator. The ``check`` keyword argument can be used to pass in more complicated
- checks and must be a regular function (not a coroutine).
+ Waits for a WebSocket event to be dispatched.
- The ``timeout`` parameter is passed into `asyncio.wait_for`_. By default, it
- does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine
- catches the exception and returns ``None`` instead of a :class:`Message`.
+ This could be used to wait for a user to reply to a message,
+ or to react to a message, or to edit a message in a self-contained
+ way.
- If the ``check`` predicate throws an exception, then the exception is propagated.
+ The ``timeout`` parameter is passed onto `asyncio.wait_for`_. By default,
+ it does not timeout. Note that this does propagate the
+ ``asyncio.TimeoutError`` for you in case of timeout and is provided for
+ ease of use.
- This function returns the **first message that meets the requirements**.
+ In case the event returns multiple arguments, a tuple containing those
+ arguments is returned instead. Please check the
+ :ref:`documentation <discord-api-events>` for a list of events and their
+ parameters.
- .. _asyncio.wait_for: https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for
+ This function returns the **first event that meets the requirements**.
Examples
- ----------
-
- Basic example:
+ ---------
- .. code-block:: python
- :emphasize-lines: 5
+ Waiting for a user reply: ::
@client.event
async def on_message(message):
if message.content.startswith('$greet'):
- await message.channel.send('Say hello')
- msg = await client.wait_for_message(author=message.author, content='hello')
- await message.channel.send('Hello.')
-
- Asking for a follow-up question:
-
- .. code-block:: python
- :emphasize-lines: 6
-
- @client.event
- async def on_message(message):
- if message.content.startswith('$start'):
- await message.channel.send('Type $stop 4 times.')
- for i in range(4):
- msg = await client.wait_for_message(author=message.author, content='$stop')
- fmt = '{} left to go...'
- await message.channel.send(fmt.format(3 - i))
-
- await message.channel.send('Good job!')
-
- Advanced filters using ``check``:
-
- .. code-block:: python
- :emphasize-lines: 9
-
- @client.event
- async def on_message(message):
- if message.content.startswith('$cool'):
- await message.channel.send('Who is cool? Type $name namehere')
-
- def check(msg):
- return msg.content.startswith('$name')
+ await message.channel.send('Say hello!')
- message = await client.wait_for_message(author=message.author, check=check)
- name = message.content[len('$name'):].strip()
- await message.channel.send('{} is cool indeed'.format(name))
+ def check(m):
+ return m.content == 'hello' and m.channel == message.channel
+ msg = await client.wait_for('message', check=check)
+ await message.channel.send('Hello {.author}!'.format(msg))
Parameters
- -----------
- timeout : float
- The number of seconds to wait before returning ``None``.
- author : :class:`Member` or :class:`User`
- The author the message must be from.
- channel : :class:`Channel` or :class:`PrivateChannel` or :class:`Object`
- The channel the message must be from.
- content : str
- The exact content the message must have.
- check : function
- A predicate for other complicated checks. The predicate must take
- a :class:`Message` as its only parameter.
-
- Returns
- --------
- :class:`Message`
- The message that you requested for.
- """
-
- def predicate(message):
- result = True
- if author is not None:
- result = result and message.author == author
-
- if content is not None:
- result = result and message.content == content
-
- if channel is not None:
- result = result and message.channel.id == channel.id
-
- if callable(check):
- # the exception thrown by check is propagated through the future.
- result = result and check(message)
-
- return result
-
- future = compat.create_future(self.loop)
- self._listeners.append((predicate, future, WaitForType.message))
- try:
- message = yield from asyncio.wait_for(future, timeout, loop=self.loop)
- except asyncio.TimeoutError:
- message = None
- return message
-
-
- @asyncio.coroutine
- def wait_for_reaction(self, emoji=None, *, user=None, timeout=None, message=None, check=None):
- """|coro|
-
- Waits for a message reaction from Discord. This is similar to :meth:`wait_for_message`
- and could be seen as another :func:`on_reaction_add` event outside of the actual event.
- This could be used for follow up situations.
-
- Similar to :meth:`wait_for_message`, the keyword arguments are combined using logical
- AND operator. The ``check`` keyword argument can be used to pass in more complicated
- checks and must a regular function taking in two arguments, ``(reaction, user)``. It
- must not be a coroutine.
-
- The ``timeout`` parameter is passed into asyncio.wait_for. By default, it
- does not timeout. Instead of throwing ``asyncio.TimeoutError`` the coroutine
- catches the exception and returns ``None`` instead of a the ``(reaction, user)``
- tuple.
-
- If the ``check`` predicate throws an exception, then the exception is propagated.
-
- The ``emoji`` parameter can be either a :class:`Emoji`, a ``str`` representing
- an emoji, or a sequence of either type. If the ``emoji`` parameter is a sequence
- then the first reaction emoji that is in the list is returned. If ``None`` is
- passed then the first reaction emoji used is returned.
-
- This function returns the **first reaction that meets the requirements**.
-
- Examples
- ---------
-
- Basic Example:
-
- .. code-block:: python
-
- @client.event
- async def on_message(message):
- if message.content.startswith('$react'):
- msg = await message.channel.send('React with thumbs up or thumbs down.')
- res = await client.wait_for_reaction(['\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}'], message=msg)
- await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res))
-
- Checking for reaction emoji regardless of skin tone:
-
- .. code-block:: python
-
- @client.event
- async def on_message(message):
- if message.content.startswith('$react'):
- msg = await message.channel.send('React with thumbs up or thumbs down.')
-
- def check(reaction, user):
- e = str(reaction.emoji)
- return e.startswith(('\N{THUMBS UP SIGN}', '\N{THUMBS DOWN SIGN}'))
-
- res = await client.wait_for_reaction(message=msg, check=check)
- await message.channel.send('{0.user} reacted with {0.reaction.emoji}!'.format(res))
+ ------------
+ event: str
+ The event name, similar to the :ref:`event reference <discord-api-events>`,
+ but without the ``on_`` prefix, to wait for.
+ check: Optional[predicate]
+ A predicate to check what to wait for. The arguments must meet the
+ parameters of the event being waited for.
+ timeout: Optional[float]
+ The number of seconds to wait before timing out and raising
+ ``asyncio.TimeoutError``\.
- Parameters
- -----------
- timeout: float
- The number of seconds to wait before returning ``None``.
- user: :class:`Member` or :class:`User`
- The user the reaction must be from.
- emoji: str or :class:`Emoji` or sequence
- The emoji that we are waiting to react with.
- message: :class:`Message`
- The message that we want the reaction to be from.
- check: function
- A predicate for other complicated checks. The predicate must take
- ``(reaction, user)`` as its two parameters, which ``reaction`` being a
- :class:`Reaction` and ``user`` being either a :class:`User` or a
- :class:`Member`.
+ Raises
+ -------
+ asyncio.TimeoutError
+ If a timeout is provided and it was reached.
Returns
--------
- namedtuple
- A namedtuple with attributes ``reaction`` and ``user`` similar to :func:`on_reaction_add`.
+ Any
+ Returns no arguments, a single argument, or a tuple of multiple
+ arguments that mirrors the parameters passed in the
+ :ref:`event reference <discord-api-events>`.
"""
- if emoji is None:
- emoji_check = lambda r: True
- elif isinstance(emoji, (str, Emoji)):
- emoji_check = lambda r: r.emoji == emoji
- else:
- emoji_check = lambda r: r.emoji in emoji
-
- def predicate(reaction, reaction_user):
- result = emoji_check(reaction)
-
- if message is not None:
- result = result and message.id == reaction.message.id
-
- if user is not None:
- result = result and user.id == reaction_user.id
-
- if callable(check):
- # the exception thrown by check is propagated through the future.
- result = result and check(reaction, reaction_user)
-
- return result
-
future = compat.create_future(self.loop)
- self._listeners.append((predicate, future, WaitForType.reaction))
+ if check is None:
+ def _check(*args):
+ return True
+ check = _check
+
+ ev = event.lower()
try:
- return (yield from asyncio.wait_for(future, timeout, loop=self.loop))
- except asyncio.TimeoutError:
- return None
+ listeners = self._listeners[ev]
+ except KeyError:
+ listeners = []
+ self._listeners[ev] = listeners
+
+ listeners.append((future, check))
+ return asyncio.wait_for(future, timeout, loop=self.loop)
# event registration