aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2020-04-07 21:53:55 -0400
committerRapptz <[email protected]>2020-07-25 09:59:38 -0400
commitb8154e365ff584438a8d42354e56881e550bb72e (patch)
tree799c9bb73c25731a87cb12c04a35b420260f8970
parentFix AttributeError on reconnection (diff)
downloaddiscord.py-b8154e365ff584438a8d42354e56881e550bb72e.tar.xz
discord.py-b8154e365ff584438a8d42354e56881e550bb72e.zip
Rewrite gateway to use aiohttp instead of websockets
-rw-r--r--discord/__main__.py2
-rw-r--r--discord/client.py11
-rw-r--r--discord/errors.py9
-rw-r--r--discord/ext/tasks/__init__.py3
-rw-r--r--discord/gateway.py132
-rw-r--r--discord/http.py11
-rw-r--r--discord/shard.py22
-rw-r--r--requirements.txt1
8 files changed, 98 insertions, 93 deletions
diff --git a/discord/__main__.py b/discord/__main__.py
index 102ca30c..70854748 100644
--- a/discord/__main__.py
+++ b/discord/__main__.py
@@ -31,7 +31,6 @@ from pathlib import Path
import discord
import pkg_resources
import aiohttp
-import websockets
import platform
def show_version():
@@ -46,7 +45,6 @@ def show_version():
entries.append(' - discord.py pkg_resources: v{0}'.format(pkg.version))
entries.append('- aiohttp v{0.__version__}'.format(aiohttp))
- entries.append('- websockets v{0.__version__}'.format(websockets))
uname = platform.uname()
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname))
print('\n'.join(entries))
diff --git a/discord/client.py b/discord/client.py
index 0fcdcd48..85931569 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -32,7 +32,6 @@ import sys
import traceback
import aiohttp
-import websockets
from .user import User, Profile
from .asset import Asset
@@ -497,9 +496,7 @@ class Client:
GatewayNotFound,
ConnectionClosed,
aiohttp.ClientError,
- asyncio.TimeoutError,
- websockets.InvalidHandshake,
- websockets.WebSocketProtocolError) as exc:
+ asyncio.TimeoutError) as exc:
self.dispatch('disconnect')
if not reconnect:
@@ -632,7 +629,11 @@ class Client:
_cleanup_loop(loop)
if not future.cancelled():
- return future.result()
+ try:
+ return future.result()
+ except KeyboardInterrupt:
+ # I am unsure why this gets raised here but suppress it anyway
+ return None
# properties
diff --git a/discord/errors.py b/discord/errors.py
index 7ab73e9d..f8da42d1 100644
--- a/discord/errors.py
+++ b/discord/errors.py
@@ -159,10 +159,11 @@ class ConnectionClosed(ClientException):
shard_id: Optional[:class:`int`]
The shard ID that got closed if applicable.
"""
- def __init__(self, original, *, shard_id):
+ def __init__(self, socket, *, shard_id):
# This exception is just the same exception except
# reconfigured to subclass ClientException for users
- self.code = original.code
- self.reason = original.reason
+ self.code = socket.close_code
+ # aiohttp doesn't seem to consistently provide close reason
+ self.reason = ''
self.shard_id = shard_id
- super().__init__(str(original))
+ super().__init__('Shard ID %s WebSocket closed with %s' % (self.shard_id, self.code))
diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py
index 3fa1cb01..7921b095 100644
--- a/discord/ext/tasks/__init__.py
+++ b/discord/ext/tasks/__init__.py
@@ -27,7 +27,6 @@ DEALINGS IN THE SOFTWARE.
import asyncio
import datetime
import aiohttp
-import websockets
import discord
import inspect
import logging
@@ -58,8 +57,6 @@ class Loop:
discord.ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError,
- websockets.InvalidHandshake,
- websockets.WebSocketProtocolError,
)
self._before_loop = None
diff --git a/discord/gateway.py b/discord/gateway.py
index e5cbfe53..59dd3c1a 100644
--- a/discord/gateway.py
+++ b/discord/gateway.py
@@ -36,7 +36,7 @@ import threading
import traceback
import zlib
-import websockets
+import aiohttp
from . import utils
from .activity import BaseActivity
@@ -60,6 +60,10 @@ class ReconnectWebSocket(Exception):
self.resume = resume
self.op = 'RESUME' if resume else 'IDENTIFY'
+class WebSocketClosure(Exception):
+ """An exception to make up for the fact that aiohttp doesn't signal closure."""
+ pass
+
EventListener = namedtuple('EventListener', 'predicate event result future')
class KeepAliveHandler(threading.Thread):
@@ -160,11 +164,17 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency)
-class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
- """Implements a WebSocket for Discord's gateway v6.
+# Monkey patch certain things from the aiohttp websocket code
+# Check this whenever we update dependencies.
+OLD_CLOSE = aiohttp.ClientWebSocketResponse.close
+
+async def _new_ws_close(self, *, code: int = 4000, message: bytes = b'') -> bool:
+ return await OLD_CLOSE(self, code=code, message=message)
- This is created through :func:`create_main_websocket`. Library
- users should never create this manually.
+aiohttp.ClientWebSocketResponse.close = _new_ws_close
+
+class DiscordWebSocket:
+ """Implements a WebSocket for Discord's gateway v6.
Attributes
-----------
@@ -217,9 +227,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
HEARTBEAT_ACK = 11
GUILD_SYNC = 12
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.max_size = None
+ def __init__(self, socket, *, loop):
+ self.socket = socket
+ self.loop = loop
+
# an empty dispatcher to prevent crashes
self._dispatch = lambda *args: None
# generic event listeners
@@ -234,14 +245,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
self._zlib = zlib.decompressobj()
self._buffer = bytearray()
+ @property
+ def open(self):
+ return not self.socket.closed
+
@classmethod
- async def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False):
+ async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
"""
- gateway = await client.http.get_gateway()
- ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None)
+ gateway = gateway or await client.http.get_gateway()
+ socket = await client.http.ws_connect(gateway)
+ ws = cls(socket, loop=client.loop)
# dynamically add attributes needed
ws.token = client.http.token
@@ -267,14 +283,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
return ws
await ws.resume()
- try:
- await ws.ensure_open()
- except websockets.exceptions.ConnectionClosed:
- # ws got closed so let's just do a regular IDENTIFY connect.
- log.warning('RESUME failed (the websocket decided to close) for Shard ID %s. Retrying.', shard_id)
- return await cls.from_client(client, shard_id=shard_id)
- else:
- return ws
+ return ws
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.
@@ -472,8 +481,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency
- def _can_handle_close(self, code):
- return code not in (1000, 4004, 4010, 4011)
+ def _can_handle_close(self):
+ return self.socket.close_code not in (1000, 4004, 4010, 4011)
async def poll_event(self):
"""Polls for a DISPATCH event and handles the general gateway loop.
@@ -484,26 +493,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
The websocket connection was terminated for unhandled reasons.
"""
try:
- msg = await self.recv()
- await self.received_message(msg)
- except websockets.exceptions.ConnectionClosed as exc:
- if self._can_handle_close(exc.code):
- log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason)
- raise ReconnectWebSocket(self.shard_id) from exc
- else:
- log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason)
- raise ConnectionClosed(exc, shard_id=self.shard_id) from exc
+ msg = await self.socket.receive()
+ if msg.type is aiohttp.WSMsgType.TEXT:
+ await self.received_message(msg.data)
+ elif msg.type is aiohttp.WSMsgType.BINARY:
+ await self.received_message(msg.data)
+ elif msg.type is aiohttp.WSMsgType.ERROR:
+ log.debug('Received %s', msg)
+ raise msg.data
+ elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
+ log.debug('Received %s', msg)
+ raise WebSocketClosure('Unexpected WebSocket closure.')
+ except WebSocketClosure as e:
+ if self._can_handle_close():
+ log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code)
+ raise ReconnectWebSocket(self.shard_id) from e
+ elif self.socket.close_code is not None:
+ log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code)
+ raise ConnectionClosed(self.socket, shard_id=self.shard_id) from e
async def send(self, data):
self._dispatch('socket_raw_send', data)
- await super().send(data)
+ await self.socket.send_str(data)
async def send_as_json(self, data):
try:
await self.send(utils.to_json(data))
- except websockets.exceptions.ConnectionClosed as exc:
- if not self._can_handle_close(exc.code):
- raise ConnectionClosed(exc, shard_id=self.shard_id) from exc
+ except RuntimeError as exc:
+ if not self._can_handle_close():
+ raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0):
if activity is not None:
@@ -570,19 +588,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
log.debug('Updating our voice state to %s.', payload)
await self.send_as_json(payload)
- async def close(self, code=4000, reason=''):
- if self._keep_alive:
- self._keep_alive.stop()
-
- await super().close(code, reason)
-
- async def close_connection(self, *args, **kwargs):
+ async def close(self, code=4000):
if self._keep_alive:
self._keep_alive.stop()
- await super().close_connection(*args, **kwargs)
+ await self.socket.close(code=code)
-class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
+class DiscordVoiceWebSocket:
"""Implements the websocket protocol for handling voice connections.
Attributes
@@ -626,14 +638,13 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.max_size = None
+ def __init__(self, socket):
+ self.ws = socket
self._keep_alive = None
async def send_as_json(self, data):
log.debug('Sending voice websocket frame: %s.', data)
- await self.send(utils.to_json(data))
+ await self.ws.send_str(utils.to_json(data))
async def resume(self):
state = self._connection
@@ -664,7 +675,9 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
async def from_client(cls, client, *, resume=False):
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
- ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None)
+ http = client._state.http
+ socket = await http.ws_connect(gateway)
+ ws = cls(socket)
ws.gateway = gateway
ws._connection = client
ws._max_heartbeat_timeout = 60.0
@@ -785,14 +798,19 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
await self.speak(False)
async def poll_event(self):
- try:
- msg = await asyncio.wait_for(self.recv(), timeout=30.0)
- await self.received_message(json.loads(msg))
- except websockets.exceptions.ConnectionClosed as exc:
- raise ConnectionClosed(exc, shard_id=None) from exc
-
- async def close_connection(self, *args, **kwargs):
- if self._keep_alive:
+ # This exception is handled up the chain
+ msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
+ if msg.type is aiohttp.WSMsgType.TEXT:
+ await self.received_message(json.loads(msg.data))
+ elif msg.type is aiohttp.WSMsgType.ERROR:
+ log.debug('Received %s', msg)
+ raise ConnectionClosed(self.ws, shard_id=None) from msg.data
+ elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
+ log.debug('Received %s', msg)
+ raise ConnectionClosed(self.ws, shard_id=None)
+
+ async def close(self, code=1000):
+ if self._keep_alive is not None:
self._keep_alive.stop()
- await super().close_connection(*args, **kwargs)
+ await self.ws.close(code=code)
diff --git a/discord/http.py b/discord/http.py
index 39fcaed0..a9da4267 100644
--- a/discord/http.py
+++ b/discord/http.py
@@ -111,6 +111,17 @@ class HTTPClient:
if self.__session.closed:
self.__session = aiohttp.ClientSession(connector=self.connector)
+ async def ws_connect(self, url):
+ kwargs = {
+ 'proxy_auth': self.proxy_auth,
+ 'proxy': self.proxy,
+ 'max_msg_size': 0,
+ 'timeout': 30.0,
+ 'autoclose': False,
+ }
+
+ return await self.__session.ws_connect(url, **kwargs)
+
async def request(self, route, *, files=None, **kwargs):
bucket = route.bucket
method = route.method
diff --git a/discord/shard.py b/discord/shard.py
index f2feaecb..1e34a56c 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -28,8 +28,6 @@ import asyncio
import itertools
import logging
-import websockets
-
from .state import AutoShardedConnectionState
from .client import Client
from .gateway import *
@@ -191,31 +189,13 @@ class AutoShardedClient(Client):
async def launch_shard(self, gateway, shard_id):
try:
- coro = websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket, compression=None)
+ coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0)
except Exception:
log.info('Failed to connect for shard_id: %s. Retrying...', shard_id)
await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id)
- ws.token = self.http.token
- ws._connection = self._connection
- ws._discord_parsers = self._connection.parsers
- ws._dispatch = self.dispatch
- ws.gateway = gateway
- ws.shard_id = shard_id
- ws.shard_count = self.shard_count
- ws._max_heartbeat_timeout = self._connection.heartbeat_timeout
-
- try:
- # OP HELLO
- await asyncio.wait_for(ws.poll_event(), timeout=180.0)
- await asyncio.wait_for(ws.identify(), timeout=180.0)
- except asyncio.TimeoutError:
- log.info('Timed out when connecting for shard_id: %s. Retrying...', shard_id)
- await asyncio.sleep(5.0)
- return await self.launch_shard(gateway, shard_id)
-
# keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self)
ret.launch()
diff --git a/requirements.txt b/requirements.txt
index 8dfc5301..25c9da58 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1 @@
aiohttp>=3.6.0,<3.7.0
-websockets>=6.0,!=7.0,!=8.0,!=8.0.1,<9.0