diff options
| author | Rapptz <[email protected]> | 2016-04-27 17:37:25 -0400 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2016-04-27 18:36:24 -0400 |
| commit | c1b5a528230765ce67fe885e0fd50058e0a820c9 (patch) | |
| tree | 2c4c519d323fcf7c7f3a3e1ddab96a50dfbed881 /discord/gateway.py | |
| parent | Begin working on gateway v4 support. (diff) | |
| download | discord.py-c1b5a528230765ce67fe885e0fd50058e0a820c9.tar.xz discord.py-c1b5a528230765ce67fe885e0fd50058e0a820c9.zip | |
Refactor voice websocket into gateway.py
Diffstat (limited to 'discord/gateway.py')
| -rw-r--r-- | discord/gateway.py | 190 |
1 files changed, 184 insertions, 6 deletions
diff --git a/discord/gateway.py b/discord/gateway.py index 2b4fc4dc..ccfc98df 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -36,11 +36,13 @@ import logging import zlib, time, json from collections import namedtuple import threading +import struct log = logging.getLogger(__name__) __all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket', - 'KeepAliveHandler' ] + 'KeepAliveHandler', 'VoiceKeepAliveHandler', + 'DiscordVoiceWebSocket' ] class ReconnectWebSocket(Exception): """Signals to handle the RECONNECT opcode.""" @@ -56,13 +58,13 @@ class KeepAliveHandler(threading.Thread): self.ws = ws self.interval = interval self.daemon = True + self.msg = 'Keeping websocket alive with sequence {0[d]}' self._stop = threading.Event() def run(self): while not self._stop.wait(self.interval): data = self.get_payload() - msg = 'Keeping websocket alive with sequence {0[d]}'.format(data) - log.debug(msg) + log.debug(self.msg.format(data)) coro = self.ws.send_as_json(data) f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop) try: @@ -80,6 +82,17 @@ class KeepAliveHandler(threading.Thread): def stop(self): self._stop.set() +class VoiceKeepAliveHandler(KeepAliveHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.msg = 'Keeping voice websocket alive with timestamp {0[d]}' + + def get_payload(self): + return { + 'op': self.ws.HEARTBEAT, + 'd': int(time.time() * 1000) + } + @asyncio.coroutine def get_gateway(token, *, loop=None): @@ -212,7 +225,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): connection=client.connection, loop=client.loop) - def wait_for(self, event, predicate, result): + def wait_for(self, event, predicate, result=None): """Waits for a DISPATCH'd event that meets the predicate. Parameters @@ -224,7 +237,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): properties. The data parameter is the 'd' key in the JSON message. result A function that takes the same data parameter and executes to send - the result to the future. + the result to the future. If None, returns the data. Returns -------- @@ -281,6 +294,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # "reconnect" can only be handled by the Client # so we terminate our connection and raise an # internal exception signalling to reconnect. + log.info('Receivede RECONNECT opcode.') yield from self.close() raise ReconnectWebSocket() @@ -332,7 +346,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): removed.append(index) else: if valid: - future.set_result(entry.result) + ret = data if entry.result is None else entry.result(data) + future.set_result(ret) removed.append(index) for index in reversed(removed): @@ -352,6 +367,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): yield from self.received_message(msg) except websockets.exceptions.ConnectionClosed as e: if e.code in (4008, 4009) or e.code in range(1001, 1015): + log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e)) raise ReconnectWebSocket() from e else: raise ConnectionClosed(e) from e @@ -395,8 +411,170 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): me.status = status @asyncio.coroutine + def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): + payload = { + 'op': self.VOICE_STATE, + 'd': { + 'guild_id': guild_id, + 'channel_id': channel_id, + 'self_mute': self_mute, + 'self_deaf': self_deaf + } + } + + yield from self.send_as_json(payload) + + @asyncio.coroutine + def close(self, code=1000, reason=''): + if self._keep_alive: + self._keep_alive.stop() + + yield from super().close(code, reason) + +class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): + """Implements the websocket protocol for handling voice connections. + + Attributes + ----------- + IDENTIFY + Send only. Starts a new voice session. + SELECT_PROTOCOL + Send only. Tells discord what encryption mode and how to connect for voice. + READY + Receive only. Tells the websocket that the initial connection has completed. + HEARTBEAT + Send only. Keeps your websocket connection alive. + SESSION_DESCRIPTION + Receive only. Gives you the secret key required for voice. + SPEAKING + Send only. Notifies the client if you are currently speaking. + """ + + IDENTIFY = 0 + SELECT_PROTOCOL = 1 + READY = 2 + HEARTBEAT = 3 + SESSION_DESCRIPTION = 4 + SPEAKING = 5 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_size = None + self._keep_alive = None + + @asyncio.coroutine + def send_as_json(self, data): + yield from self.send(utils.to_json(data)) + + @classmethod + @asyncio.coroutine + def from_client(cls, client): + """Creates a voice websocket for the :class:`VoiceClient`.""" + gateway = 'wss://' + client.endpoint + ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) + ws.gateway = gateway + ws._connection = client + + identify = { + 'op': cls.IDENTIFY, + 'd': { + 'server_id': client.guild_id, + 'user_id': client.user.id, + 'session_id': client.session_id, + 'token': client.token + } + } + + yield from ws.send_as_json(identify) + return ws + + @asyncio.coroutine + def select_protocol(self, ip, port): + payload = { + 'op': self.SELECT_PROTOCOL, + 'd': { + 'protocol': 'udp', + 'data': { + 'address': ip, + 'port': port, + 'mode': 'xsalsa20_poly1305' + } + } + } + + yield from self.send_as_json(payload) + log.debug('Selected protocol as {}'.format(payload)) + + @asyncio.coroutine + def speak(self, is_speaking=True): + payload = { + 'op': self.SPEAKING, + 'd': { + 'speaking': is_speaking, + 'delay': 0 + } + } + + yield from self.send_as_json(payload) + log.debug('Voice speaking now set to {}'.format(is_speaking)) + + @asyncio.coroutine + def received_message(self, msg): + log.debug('Voice websocket frame received: {}'.format(msg)) + op = msg.get('op') + data = msg.get('d') + + if op == self.READY: + interval = (data['heartbeat_interval'] / 100.0) - 5 + self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=interval) + self._keep_alive.start() + yield from self.initial_connection(data) + elif op == self.SESSION_DESCRIPTION: + yield from self.load_secret_key(data) + + @asyncio.coroutine + def initial_connection(self, data): + state = self._connection + state.ssrc = data.get('ssrc') + state.voice_port = data.get('port') + packet = bytearray(70) + struct.pack_into('>I', packet, 0, state.ssrc) + state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) + recv = yield from self.loop.sock_recv(state.socket, 70) + log.debug('received packet in initial_connection: {}'.format(recv)) + + # the ip is ascii starting at the 4th byte and ending at the first null + ip_start = 4 + ip_end = recv.index(0, ip_start) + state.ip = recv[ip_start:ip_end].decode('ascii') + + # the port is a little endian unsigned short in the last two bytes + # yes, this is different endianness from everything else + state.port = struct.unpack_from('<H', recv, len(recv) - 2)[0] + + log.debug('detected ip: {0.ip} port: {0.port}'.format(state)) + yield from self.select_protocol(state.ip, state.port) + log.info('selected the voice protocol for use') + + @asyncio.coroutine + def load_secret_key(self, data): + log.info('received secret key for voice connection') + self._connection.secret_key = data.get('secret_key') + yield from self.speak() + + @asyncio.coroutine + def poll_event(self): + try: + msg = yield from self.recv() + yield from self.received_message(json.loads(msg)) + except websockets.exceptions.ConnectionClosed as e: + raise ConnectionClosed(e) from e + + @asyncio.coroutine def close(self, code=1000, reason=''): if self._keep_alive: self._keep_alive.stop() yield from super().close(code, reason) + + |