aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/__init__.py2
-rw-r--r--discord/abc.py19
-rw-r--r--discord/client.py6
-rw-r--r--discord/ext/commands/context.py2
-rw-r--r--discord/guild.py2
-rw-r--r--discord/shard.py1
-rw-r--r--discord/state.py14
-rw-r--r--discord/voice_client.py287
-rw-r--r--docs/api.rst3
9 files changed, 230 insertions, 106 deletions
diff --git a/discord/__init__.py b/discord/__init__.py
index 78b25e31..c6b21593 100644
--- a/discord/__init__.py
+++ b/discord/__init__.py
@@ -54,7 +54,7 @@ from .mentions import AllowedMentions
from .shard import AutoShardedClient, ShardInfo
from .player import *
from .webhook import *
-from .voice_client import VoiceClient
+from .voice_client import VoiceClient, VoiceProtocol
from .audit_logs import AuditLogChanges, AuditLogEntry, AuditLogDiff
from .raw_models import *
from .team import *
diff --git a/discord/abc.py b/discord/abc.py
index 4024334d..75624e18 100644
--- a/discord/abc.py
+++ b/discord/abc.py
@@ -36,7 +36,7 @@ from .permissions import PermissionOverwrite, Permissions
from .role import Role
from .invite import Invite
from .file import File
-from .voice_client import VoiceClient
+from .voice_client import VoiceClient, VoiceProtocol
from . import utils
class _Undefined:
@@ -1053,7 +1053,6 @@ class Messageable(metaclass=abc.ABCMeta):
"""
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
-
class Connectable(metaclass=abc.ABCMeta):
"""An ABC that details the common operations on a channel that can
connect to a voice server.
@@ -1072,7 +1071,7 @@ class Connectable(metaclass=abc.ABCMeta):
def _get_voice_state_pair(self):
raise NotImplementedError
- async def connect(self, *, timeout=60.0, reconnect=True):
+ async def connect(self, *, timeout=60.0, reconnect=True, cls=VoiceClient):
"""|coro|
Connects to voice and creates a :class:`VoiceClient` to establish
@@ -1086,6 +1085,9 @@ class Connectable(metaclass=abc.ABCMeta):
Whether the bot should automatically attempt
a reconnect if a part of the handshake fails
or the gateway goes down.
+ cls: Type[:class:`VoiceProtocol`]
+ A type that subclasses :class:`~discord.VoiceProtocol` to connect with.
+ Defaults to :class:`~discord.VoiceClient`.
Raises
-------
@@ -1098,20 +1100,25 @@ class Connectable(metaclass=abc.ABCMeta):
Returns
--------
- :class:`~discord.VoiceClient`
+ :class:`~discord.VoiceProtocol`
A voice client that is fully connected to the voice server.
"""
+
+ if not issubclass(cls, VoiceProtocol):
+ raise TypeError('Type must meet VoiceProtocol abstract base class.')
+
key_id, _ = self._get_voice_client_key()
state = self._state
if state._get_voice_client(key_id):
raise ClientException('Already connected to a voice channel.')
- voice = VoiceClient(state=state, timeout=timeout, channel=self)
+ client = state._get_client()
+ voice = cls(client, self)
state._add_voice_client(key_id, voice)
try:
- await voice.connect(reconnect=reconnect)
+ await voice.connect(timeout=timeout, reconnect=reconnect)
except asyncio.TimeoutError:
try:
await voice.disconnect(force=True)
diff --git a/discord/client.py b/discord/client.py
index 9bc5dd12..407fd47f 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -238,6 +238,7 @@ class Client:
self._closed = False
self._ready = asyncio.Event()
self._connection._get_websocket = self._get_websocket
+ self._connection._get_client = lambda: self
if VoiceClient.warn_nacl:
VoiceClient.warn_nacl = False
@@ -299,7 +300,10 @@ class Client:
@property
def voice_clients(self):
- """List[:class:`.VoiceClient`]: Represents a list of voice connections."""
+ """List[:class:`.VoiceProtocol`]: Represents a list of voice connections.
+
+ These are usually :class:`.VoiceClient` instances.
+ """
return self._connection.voice_clients
def is_ready(self):
diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py
index 8b8cf4bc..3cf851c6 100644
--- a/discord/ext/commands/context.py
+++ b/discord/ext/commands/context.py
@@ -238,7 +238,7 @@ class Context(discord.abc.Messageable):
@property
def voice_client(self):
- r"""Optional[:class:`.VoiceClient`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
+ r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
g = self.guild
return g.voice_client if g else None
diff --git a/discord/guild.py b/discord/guild.py
index 4c6013a3..0bf94a28 100644
--- a/discord/guild.py
+++ b/discord/guild.py
@@ -377,7 +377,7 @@ class Guild(Hashable):
@property
def voice_client(self):
- """Optional[:class:`VoiceClient`]: Returns the :class:`VoiceClient` associated with this guild, if any."""
+ """Optional[:class:`VoiceProtocol`]: Returns the :class:`VoiceProtocol` associated with this guild, if any."""
return self._state._get_voice_client(self.id)
@property
diff --git a/discord/shard.py b/discord/shard.py
index f6320678..ef29d590 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -292,6 +292,7 @@ class AutoShardedClient(Client):
# the key is the shard_id
self.__shards = {}
self._connection._get_websocket = self._get_websocket
+ self._connection._get_client = lambda: self
self.__queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id=None, *, shard_id=None):
diff --git a/discord/state.py b/discord/state.py
index f0e93d35..fc297d03 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -63,6 +63,12 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
log = logging.getLogger(__name__)
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
+async def logging_coroutine(coroutine, *, info):
+ try:
+ await coroutine
+ except Exception:
+ log.exception('Exception occurred during %s', info)
+
class ConnectionState:
def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options):
self.loop = loop
@@ -939,9 +945,8 @@ class ConnectionState:
if int(data['user_id']) == self.user.id:
voice = self._get_voice_client(guild.id)
if voice is not None:
- ch = guild.get_channel(channel_id)
- if ch is not None:
- voice.channel = ch
+ coro = voice.on_voice_state_update(data)
+ asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice state update handler'))
member, before, after = guild._update_voice_state(data, channel_id)
if member is not None:
@@ -962,7 +967,8 @@ class ConnectionState:
vc = self._get_voice_client(key_id)
if vc is not None:
- asyncio.ensure_future(vc._create_socket(key_id, data))
+ coro = vc.on_voice_server_update(data)
+ asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
def parse_typing_start(self, data):
channel, guild = self._get_guild_channel(data)
diff --git a/discord/voice_client.py b/discord/voice_client.py
index ab9a6406..a1a7109a 100644
--- a/discord/voice_client.py
+++ b/discord/voice_client.py
@@ -45,7 +45,7 @@ import logging
import struct
import threading
-from . import opus
+from . import opus, utils
from .backoff import ExponentialBackoff
from .gateway import *
from .errors import ClientException, ConnectionClosed
@@ -59,7 +59,110 @@ except ImportError:
log = logging.getLogger(__name__)
-class VoiceClient:
+class VoiceProtocol:
+ """A class that represents the Discord voice protocol.
+
+ This is an abstract class. The library provides a concrete implementation
+ under :class:`VoiceClient`.
+
+ This class allows you to implement a protocol to allow for an external
+ method of sending voice, such as Lavalink_ or a native library implementation.
+
+ These classes are passed to :meth:`abc.Connectable.connect`.
+
+ .. _Lavalink: https://github.com/Frederikam/Lavalink
+
+ Parameters
+ ------------
+ client: :class:`Client`
+ The client (or its subclasses) that started the connection request.
+ channel: :class:`abc.Connectable`
+ The voice channel that is being connected to.
+ """
+
+ def __init__(self, client, channel):
+ self.client = client
+ self.channel = channel
+
+ async def on_voice_state_update(self, data):
+ """|coro|
+
+ An abstract method that is called when the client's voice state
+ has changed. This corresponds to ``VOICE_STATE_UPDATE``.
+
+ Parameters
+ ------------
+ data: :class:`dict`
+ The raw `voice state payload`_.
+
+ .. _voice state payload: https://discord.com/developers/docs/resources/voice#voice-state-object
+ """
+ raise NotImplementedError
+
+ async def on_voice_server_update(self, data):
+ """|coro|
+
+ An abstract method that is called when initially connecting to voice.
+ This corresponds to ``VOICE_SERVER_UPDATE``.
+
+ Parameters
+ ------------
+ data: :class:`dict`
+ The raw `voice server update payload`__.
+
+ .. _VSU: https://discord.com/developers/docs/topics/gateway#voice-server-update-voice-server-update-event-fields
+
+ __ VSU_
+ """
+ raise NotImplementedError
+
+ async def connect(self, *, timeout, reconnect):
+ """|coro|
+
+ An abstract method called when the client initiates the connection request.
+
+ When a connection is requested initially, the library calls the following functions
+ in order:
+
+ - ``__init__``
+
+ Parameters
+ ------------
+ timeout: :class:`float`
+ The timeout for the connection.
+ reconnect: :class:`bool`
+ Whether reconnection is expected.
+ """
+ raise NotImplementedError
+
+ async def disconnect(self, *, force):
+ """|coro|
+
+ An abstract method called when the client terminates the connection.
+
+ See :meth:`cleanup`.
+
+ Parameters
+ ------------
+ force: :class:`bool`
+ Whether the disconnection was forced.
+ """
+ raise NotImplementedError
+
+ def cleanup(self):
+ """This method *must* be called to ensure proper clean-up during a disconnect.
+
+ It is advisable to call this from within :meth:`disconnect` when you are
+ completely done with the voice protocol instance.
+
+ This method removes it from the internal state cache that keeps track of
+ currently alive voice clients. Failure to clean-up will cause subsequent
+ connections to report that it's still connected.
+ """
+ key_id, _ = self.channel._get_voice_client_key()
+ self.client._connection._remove_voice_client(key_id)
+
+class VoiceClient(VoiceProtocol):
"""Represents a Discord voice connection.
You do not create these, you typically get them from
@@ -85,14 +188,13 @@ class VoiceClient:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
"""
- def __init__(self, state, timeout, channel):
+ def __init__(self, client, channel):
if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice")
- self.channel = channel
- self.main_ws = None
- self.timeout = timeout
- self.ws = None
+ super().__init__(client, channel)
+ state = client._connection
+ self.token = None
self.socket = None
self.loop = state.loop
self._state = state
@@ -100,8 +202,8 @@ class VoiceClient:
self._connected = threading.Event()
self._handshaking = False
- self._handshake_check = asyncio.Lock()
- self._handshake_complete = asyncio.Event()
+ self._voice_state_complete = asyncio.Event()
+ self._voice_server_complete = asyncio.Event()
self.mode = None
self._connections = 0
@@ -138,48 +240,24 @@ class VoiceClient:
# connection related
- async def start_handshake(self):
- log.info('Starting voice handshake...')
-
- guild_id, channel_id = self.channel._get_voice_state_pair()
- state = self._state
- self.main_ws = ws = state._get_websocket(guild_id)
- self._connections += 1
-
- # request joining
- await ws.voice_state(guild_id, channel_id)
-
- try:
- await asyncio.wait_for(self._handshake_complete.wait(), timeout=self.timeout)
- except asyncio.TimeoutError:
- await self.terminate_handshake(remove=True)
- raise
-
- log.info('Voice handshake complete. Endpoint found %s (IP: %s)', self.endpoint, self.endpoint_ip)
+ async def on_voice_state_update(self, data):
+ self.session_id = data['session_id']
+ channel_id = data['channel_id']
- async def terminate_handshake(self, *, remove=False):
- guild_id, channel_id = self.channel._get_voice_state_pair()
- self._handshake_complete.clear()
- await self.main_ws.voice_state(guild_id, None, self_mute=True)
- self._handshaking = False
+ if not self._handshaking:
+ # If we're done handshaking then we just need to update ourselves
+ guild = self.guild
+ self.channel = channel_id and guild and guild.get_channel(int(channel_id))
+ else:
+ self._voice_state_complete.set()
- log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', channel_id, guild_id)
- if remove:
- log.info('The voice client has been removed for Channel ID %s (Guild ID %s)', channel_id, guild_id)
- key_id, _ = self.channel._get_voice_client_key()
- self._state._remove_voice_client(key_id)
-
- async def _create_socket(self, server_id, data):
- async with self._handshake_check:
- if self._handshaking:
- log.info("Ignoring voice server update while handshake is in progress")
- return
- self._handshaking = True
+ async def on_voice_server_update(self, data):
+ if self._voice_server_complete.is_set():
+ log.info('Ignoring extraneous voice server update.')
+ return
- self._connected.clear()
- self.session_id = self.main_ws.session_id
- self.server_id = server_id
self.token = data.get('token')
+ self.server_id = int(data['guild_id'])
endpoint = data.get('endpoint')
if endpoint is None or self.token is None:
@@ -195,23 +273,77 @@ class VoiceClient:
# This gets set later
self.endpoint_ip = None
- if self.socket:
- try:
- self.socket.close()
- except Exception:
- pass
-
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
- if self._handshake_complete.is_set():
- # terminate the websocket and handle the reconnect loop if necessary.
- self._handshake_complete.clear()
- self._handshaking = False
+ if not self._handshaking:
+ # If we're not handshaking then we need to terminate our previous connection in the websocket
await self.ws.close(4000)
return
- self._handshake_complete.set()
+ self._voice_server_complete.set()
+
+ async def voice_connect(self):
+ self._connections += 1
+ await self.channel.guild.change_voice_state(channel=self.channel)
+
+ async def voice_disconnect(self):
+ log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
+ await self.channel.guild.change_voice_state(channel=None)
+
+ async def connect(self, *, reconnect, timeout):
+ log.info('Connecting to voice...')
+ self.timeout = timeout
+ try:
+ del self.secret_key
+ except AttributeError:
+ pass
+
+
+ for i in range(5):
+ self._voice_state_complete.clear()
+ self._voice_server_complete.clear()
+ self._handshaking = True
+
+ # This has to be created before we start the flow.
+ futures = [
+ self._voice_state_complete.wait(),
+ self._voice_server_complete.wait(),
+ ]
+
+ # Start the connection flow
+ log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
+ await self.voice_connect()
+
+ try:
+ await utils.sane_wait_for(futures, timeout=timeout)
+ except asyncio.TimeoutError:
+ await self.disconnect(force=True)
+ raise
+
+ log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
+ self._handshaking = False
+ self._voice_server_complete.clear()
+ self._voice_state_complete.clear()
+
+ try:
+ self.ws = await DiscordVoiceWebSocket.from_client(self)
+ self._connected.clear()
+ while not hasattr(self, 'secret_key'):
+ await self.ws.poll_event()
+ self._connected.set()
+ break
+ except (ConnectionClosed, asyncio.TimeoutError):
+ if reconnect:
+ log.exception('Failed to connect to voice... Retrying...')
+ await asyncio.sleep(1 + i * 2.0)
+ await self.voice_disconnect()
+ continue
+ else:
+ raise
+
+ if self._runner is None:
+ self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
@property
def latency(self):
@@ -234,35 +366,6 @@ class VoiceClient:
ws = self.ws
return float("inf") if not ws else ws.average_latency
- async def connect(self, *, reconnect=True, _tries=0, do_handshake=True):
- log.info('Connecting to voice...')
- try:
- del self.secret_key
- except AttributeError:
- pass
-
- if do_handshake:
- await self.start_handshake()
-
- try:
- self.ws = await DiscordVoiceWebSocket.from_client(self)
- self._handshaking = False
- self._connected.clear()
- while not hasattr(self, 'secret_key'):
- await self.ws.poll_event()
- self._connected.set()
- except (ConnectionClosed, asyncio.TimeoutError):
- if reconnect and _tries < 5:
- log.exception('Failed to connect to voice... Retrying...')
- await asyncio.sleep(1 + _tries * 2.0)
- await self.terminate_handshake()
- await self.connect(reconnect=reconnect, _tries=_tries + 1)
- else:
- raise
-
- if self._runner is None:
- self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
-
async def poll_voice_ws(self, reconnect):
backoff = ExponentialBackoff()
while True:
@@ -287,9 +390,9 @@ class VoiceClient:
log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
self._connected.clear()
await asyncio.sleep(retry)
- await self.terminate_handshake()
+ await self.voice_disconnect()
try:
- await self.connect(reconnect=True)
+ await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop.
log.warning('Could not connect to voice... Retrying...')
@@ -310,8 +413,9 @@ class VoiceClient:
if self.ws:
await self.ws.close()
- await self.terminate_handshake(remove=True)
+ await self.voice_disconnect()
finally:
+ self.cleanup()
if self.socket:
self.socket.close()
@@ -325,8 +429,7 @@ class VoiceClient:
channel: :class:`abc.Snowflake`
The channel to move to. Must be a voice channel.
"""
- guild_id, _ = self.channel._get_voice_state_pair()
- await self.main_ws.voice_state(guild_id, channel.id)
+ await self.channel.guild.change_voice_state(channel=channel)
def is_connected(self):
"""Indicates if the voice client is connected to voice."""
diff --git a/docs/api.rst b/docs/api.rst
index 6b843bd1..d4af5ff1 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -54,6 +54,9 @@ Voice
.. autoclass:: VoiceClient()
:members:
+.. autoclass:: VoiceProtocol
+ :members:
+
.. autoclass:: AudioSource
:members: