diff options
| author | Rapptz <[email protected]> | 2017-01-08 01:31:46 -0500 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2017-01-08 01:31:46 -0500 |
| commit | b5bed9ef33ab9eeefc6f9c4f9006f9d2916ed4eb (patch) | |
| tree | ea473bb492b74b7d02a34367d0df717872e81167 /discord/shard.py | |
| parent | Add Guild.chunked property. (diff) | |
| download | discord.py-b5bed9ef33ab9eeefc6f9c4f9006f9d2916ed4eb.tar.xz discord.py-b5bed9ef33ab9eeefc6f9c4f9006f9d2916ed4eb.zip | |
Change the way shards are launched in AutoShardedClient.
Diffstat (limited to 'discord/shard.py')
| -rw-r--r-- | discord/shard.py | 64 |
1 files changed, 58 insertions, 6 deletions
diff --git a/discord/shard.py b/discord/shard.py index 2be0ea12..df0973b3 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -32,6 +32,7 @@ from . import compat import asyncio import logging +import websockets log = logging.getLogger(__name__) @@ -93,8 +94,10 @@ class AutoShardedClient(Client): syncer=self._syncer, http=self.http, loop=self.loop, **kwargs) # instead of a single websocket, we have multiple - # the index is the shard_id - self.shards = [] + # the key is the shard_id + self.shards = {} + + self._still_sharding = True @asyncio.coroutine def request_offline_members(self, guild, *, shard_id=None): @@ -136,6 +139,56 @@ class AutoShardedClient(Client): yield from ws.send_as_json(payload) @asyncio.coroutine + def pending_reads(self, shard): + try: + while self._still_sharding: + yield from shard.poll() + except asyncio.CancelledError: + pass + + @asyncio.coroutine + def launch_shard(self, gateway, shard_id): + try: + ws = yield from websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket) + except Exception as e: + import traceback + traceback.print_exc() + log.info('Failed to connect for shard_id: %s. Retrying...' % shard_id) + yield from asyncio.sleep(5.0, loop=self.loop) + yield from self.launch_shard(gateway, shard_id) + + ws.token = self.http.token + ws._connection = self.connection + ws._dispatch = self.dispatch + ws.gateway = gateway + ws.shard_id = shard_id + ws.shard_count = self.shard_count + + # OP HELLO + yield from ws.poll_event() + yield from ws.identify() + log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id) + + # keep reading the shard while others connect + self.shards[shard_id] = ret = Shard(ws, self) + compat.create_task(self.pending_reads(ret), loop=self.loop) + yield from asyncio.sleep(5.0, loop=self.loop) + + @asyncio.coroutine + def launch_shards(self): + if self.shard_count is None: + self.shard_count, gateway = yield from self.http.get_bot_gateway() + else: + gateway = yield from self.http.get_gateway() + + self.connection.shard_count = self.shard_count + + for shard_id in range(self.shard_count): + yield from self.launch_shard(gateway, shard_id) + + self._still_sharding = False + + @asyncio.coroutine def connect(self): """|coro| @@ -150,11 +203,10 @@ class AutoShardedClient(Client): ConnectionClosed The websocket connection has been terminated. """ - ret = yield from DiscordWebSocket.from_sharded_client(self) - self.shards = [Shard(ws, self) for ws in ret] + yield from self.launch_shards() while not self.is_closed: - pollers = [shard.get_future() for shard in self.shards] + pollers = [shard.get_future() for shard in self.shards.values()] yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED) @asyncio.coroutine @@ -166,7 +218,7 @@ class AutoShardedClient(Client): if self.is_closed: return - for shard in self.shards: + for shard in self.shards.values(): yield from shard.ws.close() yield from self.http.close() |