aboutsummaryrefslogtreecommitdiff
path: root/discord/voice_client.py
diff options
context:
space:
mode:
authorJosh <[email protected]>2021-06-28 14:59:14 +1000
committerGitHub <[email protected]>2021-06-28 00:59:14 -0400
commit5acea453cccc1501e587811f3223f4888697f553 (patch)
tree11973c4eac5793d8185b7c6b83239d83333e1199 /discord/voice_client.py
parentTypehint Activity (diff)
downloaddiscord.py-5acea453cccc1501e587811f3223f4888697f553.tar.xz
discord.py-5acea453cccc1501e587811f3223f4888697f553.zip
Type-hint voice_client / player
Diffstat (limited to 'discord/voice_client.py')
-rw-r--r--discord/voice_client.py164
1 files changed, 97 insertions, 67 deletions
diff --git a/discord/voice_client.py b/discord/voice_client.py
index 2ae2a8b1..18cbb732 100644
--- a/discord/voice_client.py
+++ b/discord/voice_client.py
@@ -20,9 +20,9 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
-"""
-"""Some documentation to refer to:
+
+Some documentation to refer to:
- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID.
- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE.
@@ -37,21 +37,41 @@ DEALINGS IN THE SOFTWARE.
- Finally we can transmit data to endpoint:port.
"""
+from __future__ import annotations
+
import asyncio
import socket
import logging
import struct
import threading
-from typing import Any, Callable
+from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple
from . import opus, utils
from .backoff import ExponentialBackoff
from .gateway import *
from .errors import ClientException, ConnectionClosed
from .player import AudioPlayer, AudioSource
+from .utils import MISSING
+
+if TYPE_CHECKING:
+ from .client import Client
+ from .guild import Guild
+ from .state import ConnectionState
+ from .user import ClientUser
+ from .opus import Encoder
+ from . import abc
+
+ from .types.voice import (
+ GuildVoiceState as GuildVoiceStatePayload,
+ VoiceServerUpdate as VoiceServerUpdatePayload,
+ SupportedModes,
+ )
+
+
+has_nacl: bool
try:
- import nacl.secret
+ import nacl.secret # type: ignore
has_nacl = True
except ImportError:
has_nacl = False
@@ -61,7 +81,10 @@ __all__ = (
'VoiceClient',
)
-log = logging.getLogger(__name__)
+
+
+
+log: logging.Logger = logging.getLogger(__name__)
class VoiceProtocol:
"""A class that represents the Discord voice protocol.
@@ -84,11 +107,11 @@ class VoiceProtocol:
The voice channel that is being connected to.
"""
- def __init__(self, client, channel):
- self.client = client
- self.channel = channel
+ def __init__(self, client: Client, channel: abc.Connectable) -> None:
+ self.client: Client = client
+ self.channel: abc.Connectable = channel
- async def on_voice_state_update(self, data):
+ async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
"""|coro|
An abstract method that is called when the client's voice state
@@ -105,7 +128,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
- async def on_voice_server_update(self, data):
+ async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
"""|coro|
An abstract method that is called when initially connecting to voice.
@@ -122,7 +145,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
- async def connect(self, *, timeout: float, reconnect: bool):
+ async def connect(self, *, timeout: float, reconnect: bool) -> None:
"""|coro|
An abstract method called when the client initiates the connection request.
@@ -145,7 +168,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
- async def disconnect(self, *, force: bool):
+ async def disconnect(self, *, force: bool) -> None:
"""|coro|
An abstract method called when the client terminates the connection.
@@ -159,7 +182,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
- def cleanup(self):
+ def cleanup(self) -> None:
"""This method *must* be called to ensure proper clean-up during a disconnect.
It is advisable to call this from within :meth:`disconnect` when you are
@@ -198,48 +221,55 @@ class VoiceClient(VoiceProtocol):
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
"""
- def __init__(self, client, channel):
+ endpoint_ip: str
+ voice_port: int
+ secret_key: List[int]
+ ssrc: int
+
+
+ def __init__(self, client: Client, channel: abc.Connectable):
if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice")
super().__init__(client, channel)
state = client._connection
- self.token = None
- self.socket = None
- self.loop = state.loop
- self._state = state
+ self.token: str = MISSING
+ self.socket = MISSING
+ self.loop: asyncio.AbstractEventLoop = state.loop
+ self._state: ConnectionState = state
# this will be used in the AudioPlayer thread
- self._connected = threading.Event()
-
- self._handshaking = False
- self._potentially_reconnecting = False
- self._voice_state_complete = asyncio.Event()
- self._voice_server_complete = asyncio.Event()
-
- self.mode = None
- self._connections = 0
- self.sequence = 0
- self.timestamp = 0
- self._runner = None
- self._player = None
- self.encoder = None
- self._lite_nonce = 0
- self.ws = None
+ self._connected: threading.Event = threading.Event()
+
+ self._handshaking: bool = False
+ self._potentially_reconnecting: bool = False
+ self._voice_state_complete: asyncio.Event = asyncio.Event()
+ self._voice_server_complete: asyncio.Event = asyncio.Event()
+
+ self.mode: str = MISSING
+ self._connections: int = 0
+ self.sequence: int = 0
+ self.timestamp: int = 0
+ self.timeout: float = 0
+ self._runner: asyncio.Task = MISSING
+ self._player: Optional[AudioPlayer] = None
+ self.encoder: Encoder = MISSING
+ self._lite_nonce: int = 0
+ self.ws: DiscordVoiceWebSocket = MISSING
warn_nacl = not has_nacl
- supported_modes = (
+ supported_modes: Tuple[SupportedModes, ...] = (
'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix',
'xsalsa20_poly1305',
)
@property
- def guild(self):
+ def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild we're connected to, if applicable."""
return getattr(self.channel, 'guild', None)
@property
- def user(self):
+ def user(self) -> ClientUser:
""":class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
return self._state.user
@@ -252,7 +282,7 @@ class VoiceClient(VoiceProtocol):
# connection related
- async def on_voice_state_update(self, data):
+ async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data['session_id']
channel_id = data['channel_id']
@@ -265,11 +295,11 @@ class VoiceClient(VoiceProtocol):
await self.disconnect()
else:
guild = self.guild
- self.channel = channel_id and guild and guild.get_channel(int(channel_id))
+ self.channel = channel_id and guild and guild.get_channel(int(channel_id)) # type: ignore
else:
self._voice_state_complete.set()
- async def on_voice_server_update(self, data):
+ async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
if self._voice_server_complete.is_set():
log.info('Ignoring extraneous voice server update.')
return
@@ -289,7 +319,7 @@ class VoiceClient(VoiceProtocol):
self.endpoint = self.endpoint[6:]
# This gets set later
- self.endpoint_ip = None
+ self.endpoint_ip = MISSING
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
@@ -301,27 +331,27 @@ class VoiceClient(VoiceProtocol):
self._voice_server_complete.set()
- async def voice_connect(self):
+ async def voice_connect(self) -> None:
await self.channel.guild.change_voice_state(channel=self.channel)
- async def voice_disconnect(self):
+ async def voice_disconnect(self) -> None:
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
await self.channel.guild.change_voice_state(channel=None)
- def prepare_handshake(self):
+ def prepare_handshake(self) -> None:
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
self._connections += 1
- def finish_handshake(self):
+ def finish_handshake(self) -> None:
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
- async def connect_websocket(self):
+ async def connect_websocket(self) -> DiscordVoiceWebSocket:
ws = await DiscordVoiceWebSocket.from_client(self)
self._connected.clear()
while ws.secret_key is None:
@@ -329,7 +359,7 @@ class VoiceClient(VoiceProtocol):
self._connected.set()
return ws
- async def connect(self, *, reconnect: bool, timeout: bool):
+ async def connect(self, *, reconnect: bool, timeout: float) ->None:
log.info('Connecting to voice...')
self.timeout = timeout
@@ -365,10 +395,10 @@ class VoiceClient(VoiceProtocol):
else:
raise
- if self._runner is None:
+ if self._runner is MISSING:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
- async def potential_reconnect(self):
+ async def potential_reconnect(self) -> bool:
# Attempt to stop the player thread from playing early
self._connected.clear()
self.prepare_handshake()
@@ -391,7 +421,7 @@ class VoiceClient(VoiceProtocol):
return True
@property
- def latency(self):
+ def latency(self) -> float:
""":class:`float`: Latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This could be referred to as the Discord Voice WebSocket latency and is
@@ -403,7 +433,7 @@ class VoiceClient(VoiceProtocol):
return float("inf") if not ws else ws.latency
@property
- def average_latency(self):
+ def average_latency(self) -> float:
""":class:`float`: Average of most recent 20 HEARTBEAT latencies in seconds.
.. versionadded:: 1.4
@@ -411,7 +441,7 @@ class VoiceClient(VoiceProtocol):
ws = self.ws
return float("inf") if not ws else ws.average_latency
- async def poll_voice_ws(self, reconnect):
+ async def poll_voice_ws(self, reconnect: bool) -> None:
backoff = ExponentialBackoff()
while True:
try:
@@ -452,7 +482,7 @@ class VoiceClient(VoiceProtocol):
log.warning('Could not connect to voice... Retrying...')
continue
- async def disconnect(self, *, force: bool = False):
+ async def disconnect(self, *, force: bool = False) -> None:
"""|coro|
Disconnects this voice client from voice.
@@ -473,7 +503,7 @@ class VoiceClient(VoiceProtocol):
if self.socket:
self.socket.close()
- async def move_to(self, channel):
+ async def move_to(self, channel: abc.Snowflake) -> None:
"""|coro|
Moves you to a different voice channel.
@@ -485,7 +515,7 @@ class VoiceClient(VoiceProtocol):
"""
await self.channel.guild.change_voice_state(channel=channel)
- def is_connected(self):
+ def is_connected(self) -> bool:
"""Indicates if the voice client is connected to voice."""
return self._connected.is_set()
@@ -504,20 +534,20 @@ class VoiceClient(VoiceProtocol):
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
return encrypt_packet(header, data)
- def _encrypt_xsalsa20_poly1305(self, header, data):
+ def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes:
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
nonce[:12] = header
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext
- def _encrypt_xsalsa20_poly1305_suffix(self, header, data):
+ def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data) -> bytes:
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE)
return header + box.encrypt(bytes(data), nonce).ciphertext + nonce
- def _encrypt_xsalsa20_poly1305_lite(self, header, data):
+ def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data) -> bytes:
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
@@ -526,7 +556,7 @@ class VoiceClient(VoiceProtocol):
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
- def play(self, source: AudioSource, *, after: Callable[[Exception], Any]=None):
+ def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any]=None) -> None:
"""Plays an :class:`AudioSource`.
The finalizer, ``after`` is called after the source has been exhausted
@@ -570,32 +600,32 @@ class VoiceClient(VoiceProtocol):
self._player = AudioPlayer(source, self, after=after)
self._player.start()
- def is_playing(self):
+ def is_playing(self) -> bool:
"""Indicates if we're currently playing audio."""
return self._player is not None and self._player.is_playing()
- def is_paused(self):
+ def is_paused(self) -> bool:
"""Indicates if we're playing audio, but if we're paused."""
return self._player is not None and self._player.is_paused()
- def stop(self):
+ def stop(self) -> None:
"""Stops playing audio."""
if self._player:
self._player.stop()
self._player = None
- def pause(self):
+ def pause(self) -> None:
"""Pauses the audio playing."""
if self._player:
self._player.pause()
- def resume(self):
+ def resume(self) -> None:
"""Resumes the audio playing."""
if self._player:
self._player.resume()
@property
- def source(self):
+ def source(self) -> Optional[AudioSource]:
"""Optional[:class:`AudioSource`]: The audio source being played, if playing.
This property can also be used to change the audio source currently being played.
@@ -603,7 +633,7 @@ class VoiceClient(VoiceProtocol):
return self._player.source if self._player else None
@source.setter
- def source(self, value):
+ def source(self, value: AudioSource) -> None:
if not isinstance(value, AudioSource):
raise TypeError(f'expected AudioSource not {value.__class__.__name__}.')
@@ -612,7 +642,7 @@ class VoiceClient(VoiceProtocol):
self._player._set_source(value)
- def send_audio_packet(self, data, *, encode=True):
+ def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None:
"""Sends an audio packet composed of the data.
You must be connected to play audio.