aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2017-10-14 21:17:27 -0400
committerRapptz <[email protected]>2017-10-14 21:19:46 -0400
commit47a58d354d3c289ce8fcd56f817976a43029887f (patch)
tree7ff778c55d34b1155bad246bfec32870db6b726c
parentShow sha1 for development versions. (diff)
downloaddiscord.py-47a58d354d3c289ce8fcd56f817976a43029887f.tar.xz
discord.py-47a58d354d3c289ce8fcd56f817976a43029887f.zip
Reimplement zlib streaming.
This time with less bugs. It turned out that the crash was due to a synchronisation issue between the pending reads and the actual shard polling mechanism. Essentially the pending reads would be cancelled via a simple bool but there would still be a pass left and thus we would have a single pending read left before or after running the polling mechanism and this would cause a race condition. Now the pending read mechanism is properly waited for before returning control back to the caller.
-rw-r--r--discord/gateway.py15
-rw-r--r--discord/http.py16
-rw-r--r--discord/shard.py42
3 files changed, 55 insertions, 18 deletions
diff --git a/discord/gateway.py b/discord/gateway.py
index 0ab02760..547d3401 100644
--- a/discord/gateway.py
+++ b/discord/gateway.py
@@ -186,6 +186,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# ws related stuff
self.session_id = None
self.sequence = None
+ self._zlib = zlib.decompressobj()
+ self._buffer = bytearray()
@classmethod
@asyncio.coroutine
@@ -312,8 +314,17 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
self._dispatch('socket_raw_receive', msg)
if isinstance(msg, bytes):
- msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB
- msg = msg.decode('utf-8')
+ self._buffer.extend(msg)
+
+ if len(msg) >= 4:
+ if msg[-4:] == b'\x00\x00\xff\xff':
+ msg = self._zlib.decompress(self._buffer)
+ msg = msg.decode('utf-8')
+ self._buffer = bytearray()
+ else:
+ return
+ else:
+ return
msg = json.loads(msg)
diff --git a/discord/http.py b/discord/http.py
index fa6678ee..8c4ebb16 100644
--- a/discord/http.py
+++ b/discord/http.py
@@ -739,21 +739,29 @@ class HTTPClient:
return self.request(Route('GET', '/oauth2/applications/@me'))
@asyncio.coroutine
- def get_gateway(self):
+ def get_gateway(self, *, encoding='json', v=6, zlib=True):
try:
data = yield from self.request(Route('GET', '/gateway'))
except HTTPException as e:
raise GatewayNotFound() from e
- return data.get('url') + '?encoding=json&v=6'
+ if zlib:
+ value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
+ else:
+ value = '{0}?encoding={1}&v={2}'
+ return value.format(data['url'], encoding, v)
@asyncio.coroutine
- def get_bot_gateway(self):
+ def get_bot_gateway(self, *, encoding='json', v=6, zlib=True):
try:
data = yield from self.request(Route('GET', '/gateway/bot'))
except HTTPException as e:
raise GatewayNotFound() from e
+
+ if zlib:
+ value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
- return data['shards'], data['url'] + '?encoding=json&v=6'
+ value = '{0}?encoding={1}&v={2}'
+ return data['shards'], value.format(data['url'], encoding, v)
def get_user_info(self, user_id):
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))
diff --git a/discord/shard.py b/discord/shard.py
index 89463059..f7f230db 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -28,7 +28,7 @@ from .state import AutoShardedConnectionState
from .client import Client
from .gateway import *
from .errors import ClientException, InvalidArgument
-from . import compat
+from . import compat, utils
from .enums import Status
import asyncio
@@ -45,11 +45,32 @@ class Shard:
self.loop = self._client.loop
self._current = compat.create_future(self.loop)
self._current.set_result(None) # we just need an already done future
+ self._pending = asyncio.Event(loop=self.loop)
+ self._pending_task = None
@property
def id(self):
return self.ws.shard_id
+ def is_pending(self):
+ return not self._pending.is_set()
+
+ def complete_pending_reads(self):
+ self._pending.set()
+
+ def _pending_reads(self):
+ try:
+ while self.is_pending():
+ yield from self.poll()
+ except asyncio.CancelledError:
+ pass
+
+ def launch_pending_reads(self):
+ self._pending_task = compat.create_task(self._pending_reads(), loop=self.loop)
+
+ def wait(self):
+ return self._pending_task
+
@asyncio.coroutine
def poll(self):
try:
@@ -127,7 +148,6 @@ class AutoShardedClient(Client):
return self.shards[i].ws
self._connection._get_websocket = _get_websocket
- self._still_sharding = True
@asyncio.coroutine
def _chunker(self, guild, *, shard_id=None):
@@ -200,14 +220,6 @@ class AutoShardedClient(Client):
yield from self._connection.request_offline_members(sub_guilds, shard_id=shard_id)
@asyncio.coroutine
- def pending_reads(self, shard):
- try:
- while self._still_sharding:
- yield from shard.poll()
- except asyncio.CancelledError:
- pass
-
- @asyncio.coroutine
def launch_shard(self, gateway, shard_id):
try:
ws = yield from asyncio.wait_for(_ensure_coroutine_connect(gateway, self.loop), loop=self.loop, timeout=180.0)
@@ -235,7 +247,7 @@ class AutoShardedClient(Client):
# keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self)
- compat.create_task(self.pending_reads(ret), loop=self.loop)
+ ret.launch_pending_reads()
yield from asyncio.sleep(5.0, loop=self.loop)
@asyncio.coroutine
@@ -252,7 +264,13 @@ class AutoShardedClient(Client):
for shard_id in shard_ids:
yield from self.launch_shard(gateway, shard_id)
- self._still_sharding = False
+ shards_to_wait_for = []
+ for shard in self.shards.values():
+ shard.complete_pending_reads()
+ shards_to_wait_for.append(shard.wait())
+
+ # wait for all pending tasks to finish
+ yield from utils.sane_wait_for(shards_to_wait_for, timeout=300.0, loop=self.loop)
@asyncio.coroutine
def _connect(self):