diff options
| author | Rapptz <[email protected]> | 2017-10-14 21:17:27 -0400 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2017-10-14 21:19:46 -0400 |
| commit | 47a58d354d3c289ce8fcd56f817976a43029887f (patch) | |
| tree | 7ff778c55d34b1155bad246bfec32870db6b726c | |
| parent | Show sha1 for development versions. (diff) | |
| download | discord.py-47a58d354d3c289ce8fcd56f817976a43029887f.tar.xz discord.py-47a58d354d3c289ce8fcd56f817976a43029887f.zip | |
Reimplement zlib streaming.
This time with less bugs. It turned out that the crash was due to a
synchronisation issue between the pending reads and the actual shard
polling mechanism.
Essentially the pending reads would be cancelled via a simple bool but
there would still be a pass left and thus we would have a single
pending read left before or after running the polling mechanism and
this would cause a race condition.
Now the pending read mechanism is properly waited for before returning
control back to the caller.
| -rw-r--r-- | discord/gateway.py | 15 | ||||
| -rw-r--r-- | discord/http.py | 16 | ||||
| -rw-r--r-- | discord/shard.py | 42 |
3 files changed, 55 insertions, 18 deletions
diff --git a/discord/gateway.py b/discord/gateway.py index 0ab02760..547d3401 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -186,6 +186,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # ws related stuff self.session_id = None self.sequence = None + self._zlib = zlib.decompressobj() + self._buffer = bytearray() @classmethod @asyncio.coroutine @@ -312,8 +314,17 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): self._dispatch('socket_raw_receive', msg) if isinstance(msg, bytes): - msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB - msg = msg.decode('utf-8') + self._buffer.extend(msg) + + if len(msg) >= 4: + if msg[-4:] == b'\x00\x00\xff\xff': + msg = self._zlib.decompress(self._buffer) + msg = msg.decode('utf-8') + self._buffer = bytearray() + else: + return + else: + return msg = json.loads(msg) diff --git a/discord/http.py b/discord/http.py index fa6678ee..8c4ebb16 100644 --- a/discord/http.py +++ b/discord/http.py @@ -739,21 +739,29 @@ class HTTPClient: return self.request(Route('GET', '/oauth2/applications/@me')) @asyncio.coroutine - def get_gateway(self): + def get_gateway(self, *, encoding='json', v=6, zlib=True): try: data = yield from self.request(Route('GET', '/gateway')) except HTTPException as e: raise GatewayNotFound() from e - return data.get('url') + '?encoding=json&v=6' + if zlib: + value = '{0}?encoding={1}&v={2}&compress=zlib-stream' + else: + value = '{0}?encoding={1}&v={2}' + return value.format(data['url'], encoding, v) @asyncio.coroutine - def get_bot_gateway(self): + def get_bot_gateway(self, *, encoding='json', v=6, zlib=True): try: data = yield from self.request(Route('GET', '/gateway/bot')) except HTTPException as e: raise GatewayNotFound() from e + + if zlib: + value = '{0}?encoding={1}&v={2}&compress=zlib-stream' else: - return data['shards'], data['url'] + '?encoding=json&v=6' + value = '{0}?encoding={1}&v={2}' + return data['shards'], value.format(data['url'], encoding, v) def get_user_info(self, user_id): return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) diff --git a/discord/shard.py b/discord/shard.py index 89463059..f7f230db 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -28,7 +28,7 @@ from .state import AutoShardedConnectionState from .client import Client from .gateway import * from .errors import ClientException, InvalidArgument -from . import compat +from . import compat, utils from .enums import Status import asyncio @@ -45,11 +45,32 @@ class Shard: self.loop = self._client.loop self._current = compat.create_future(self.loop) self._current.set_result(None) # we just need an already done future + self._pending = asyncio.Event(loop=self.loop) + self._pending_task = None @property def id(self): return self.ws.shard_id + def is_pending(self): + return not self._pending.is_set() + + def complete_pending_reads(self): + self._pending.set() + + def _pending_reads(self): + try: + while self.is_pending(): + yield from self.poll() + except asyncio.CancelledError: + pass + + def launch_pending_reads(self): + self._pending_task = compat.create_task(self._pending_reads(), loop=self.loop) + + def wait(self): + return self._pending_task + @asyncio.coroutine def poll(self): try: @@ -127,7 +148,6 @@ class AutoShardedClient(Client): return self.shards[i].ws self._connection._get_websocket = _get_websocket - self._still_sharding = True @asyncio.coroutine def _chunker(self, guild, *, shard_id=None): @@ -200,14 +220,6 @@ class AutoShardedClient(Client): yield from self._connection.request_offline_members(sub_guilds, shard_id=shard_id) @asyncio.coroutine - def pending_reads(self, shard): - try: - while self._still_sharding: - yield from shard.poll() - except asyncio.CancelledError: - pass - - @asyncio.coroutine def launch_shard(self, gateway, shard_id): try: ws = yield from asyncio.wait_for(_ensure_coroutine_connect(gateway, self.loop), loop=self.loop, timeout=180.0) @@ -235,7 +247,7 @@ class AutoShardedClient(Client): # keep reading the shard while others connect self.shards[shard_id] = ret = Shard(ws, self) - compat.create_task(self.pending_reads(ret), loop=self.loop) + ret.launch_pending_reads() yield from asyncio.sleep(5.0, loop=self.loop) @asyncio.coroutine @@ -252,7 +264,13 @@ class AutoShardedClient(Client): for shard_id in shard_ids: yield from self.launch_shard(gateway, shard_id) - self._still_sharding = False + shards_to_wait_for = [] + for shard in self.shards.values(): + shard.complete_pending_reads() + shards_to_wait_for.append(shard.wait()) + + # wait for all pending tasks to finish + yield from utils.sane_wait_for(shards_to_wait_for, timeout=300.0, loop=self.loop) @asyncio.coroutine def _connect(self): |