From b5bed9ef33ab9eeefc6f9c4f9006f9d2916ed4eb Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 8 Jan 2017 01:31:46 -0500 Subject: Change the way shards are launched in AutoShardedClient. --- discord/shard.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 6 deletions(-) (limited to 'discord/shard.py') diff --git a/discord/shard.py b/discord/shard.py index 2be0ea12..df0973b3 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -32,6 +32,7 @@ from . import compat import asyncio import logging +import websockets log = logging.getLogger(__name__) @@ -93,8 +94,10 @@ class AutoShardedClient(Client): syncer=self._syncer, http=self.http, loop=self.loop, **kwargs) # instead of a single websocket, we have multiple - # the index is the shard_id - self.shards = [] + # the key is the shard_id + self.shards = {} + + self._still_sharding = True @asyncio.coroutine def request_offline_members(self, guild, *, shard_id=None): @@ -135,6 +138,56 @@ class AutoShardedClient(Client): ws = self.shards[shard_id].ws yield from ws.send_as_json(payload) + @asyncio.coroutine + def pending_reads(self, shard): + try: + while self._still_sharding: + yield from shard.poll() + except asyncio.CancelledError: + pass + + @asyncio.coroutine + def launch_shard(self, gateway, shard_id): + try: + ws = yield from websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket) + except Exception as e: + import traceback + traceback.print_exc() + log.info('Failed to connect for shard_id: %s. Retrying...' % shard_id) + yield from asyncio.sleep(5.0, loop=self.loop) + yield from self.launch_shard(gateway, shard_id) + + ws.token = self.http.token + ws._connection = self.connection + ws._dispatch = self.dispatch + ws.gateway = gateway + ws.shard_id = shard_id + ws.shard_count = self.shard_count + + # OP HELLO + yield from ws.poll_event() + yield from ws.identify() + log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id) + + # keep reading the shard while others connect + self.shards[shard_id] = ret = Shard(ws, self) + compat.create_task(self.pending_reads(ret), loop=self.loop) + yield from asyncio.sleep(5.0, loop=self.loop) + + @asyncio.coroutine + def launch_shards(self): + if self.shard_count is None: + self.shard_count, gateway = yield from self.http.get_bot_gateway() + else: + gateway = yield from self.http.get_gateway() + + self.connection.shard_count = self.shard_count + + for shard_id in range(self.shard_count): + yield from self.launch_shard(gateway, shard_id) + + self._still_sharding = False + @asyncio.coroutine def connect(self): """|coro| @@ -150,11 +203,10 @@ class AutoShardedClient(Client): ConnectionClosed The websocket connection has been terminated. """ - ret = yield from DiscordWebSocket.from_sharded_client(self) - self.shards = [Shard(ws, self) for ws in ret] + yield from self.launch_shards() while not self.is_closed: - pollers = [shard.get_future() for shard in self.shards] + pollers = [shard.get_future() for shard in self.shards.values()] yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED) @asyncio.coroutine @@ -166,7 +218,7 @@ class AutoShardedClient(Client): if self.is_closed: return - for shard in self.shards: + for shard in self.shards.values(): yield from shard.ws.close() yield from self.http.close() -- cgit v1.2.3