aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/gateway.py29
-rw-r--r--discord/shard.py64
2 files changed, 58 insertions, 35 deletions
diff --git a/discord/gateway.py b/discord/gateway.py
index fcba2dfc..8180f4ec 100644
--- a/discord/gateway.py
+++ b/discord/gateway.py
@@ -214,35 +214,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
else:
return ws
- @classmethod
- @asyncio.coroutine
- def from_sharded_client(cls, client):
- if client.shard_count is None:
- client.shard_count, gateway = yield from client.http.get_bot_gateway()
- else:
- gateway = yield from client.http.get_gateway()
-
- ret = []
- client.connection.shard_count = client.shard_count
-
- for shard_id in range(client.shard_count):
- ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
- ws.token = client.http.token
- ws._connection = client.connection
- ws._dispatch = client.dispatch
- ws.gateway = gateway
- ws.shard_id = shard_id
- ws.shard_count = client.shard_count
-
- # OP HELLO
- yield from ws.poll_event()
- yield from ws.identify()
- ret.append(ws)
- log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
- yield from asyncio.sleep(5.0, loop=client.loop)
-
- return ret
-
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.
diff --git a/discord/shard.py b/discord/shard.py
index 2be0ea12..df0973b3 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -32,6 +32,7 @@ from . import compat
import asyncio
import logging
+import websockets
log = logging.getLogger(__name__)
@@ -93,8 +94,10 @@ class AutoShardedClient(Client):
syncer=self._syncer, http=self.http, loop=self.loop, **kwargs)
# instead of a single websocket, we have multiple
- # the index is the shard_id
- self.shards = []
+ # the key is the shard_id
+ self.shards = {}
+
+ self._still_sharding = True
@asyncio.coroutine
def request_offline_members(self, guild, *, shard_id=None):
@@ -136,6 +139,56 @@ class AutoShardedClient(Client):
yield from ws.send_as_json(payload)
@asyncio.coroutine
+ def pending_reads(self, shard):
+ try:
+ while self._still_sharding:
+ yield from shard.poll()
+ except asyncio.CancelledError:
+ pass
+
+ @asyncio.coroutine
+ def launch_shard(self, gateway, shard_id):
+ try:
+ ws = yield from websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket)
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ log.info('Failed to connect for shard_id: %s. Retrying...' % shard_id)
+ yield from asyncio.sleep(5.0, loop=self.loop)
+ yield from self.launch_shard(gateway, shard_id)
+
+ ws.token = self.http.token
+ ws._connection = self.connection
+ ws._dispatch = self.dispatch
+ ws.gateway = gateway
+ ws.shard_id = shard_id
+ ws.shard_count = self.shard_count
+
+ # OP HELLO
+ yield from ws.poll_event()
+ yield from ws.identify()
+ log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
+
+ # keep reading the shard while others connect
+ self.shards[shard_id] = ret = Shard(ws, self)
+ compat.create_task(self.pending_reads(ret), loop=self.loop)
+ yield from asyncio.sleep(5.0, loop=self.loop)
+
+ @asyncio.coroutine
+ def launch_shards(self):
+ if self.shard_count is None:
+ self.shard_count, gateway = yield from self.http.get_bot_gateway()
+ else:
+ gateway = yield from self.http.get_gateway()
+
+ self.connection.shard_count = self.shard_count
+
+ for shard_id in range(self.shard_count):
+ yield from self.launch_shard(gateway, shard_id)
+
+ self._still_sharding = False
+
+ @asyncio.coroutine
def connect(self):
"""|coro|
@@ -150,11 +203,10 @@ class AutoShardedClient(Client):
ConnectionClosed
The websocket connection has been terminated.
"""
- ret = yield from DiscordWebSocket.from_sharded_client(self)
- self.shards = [Shard(ws, self) for ws in ret]
+ yield from self.launch_shards()
while not self.is_closed:
- pollers = [shard.get_future() for shard in self.shards]
+ pollers = [shard.get_future() for shard in self.shards.values()]
yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED)
@asyncio.coroutine
@@ -166,7 +218,7 @@ class AutoShardedClient(Client):
if self.is_closed:
return
- for shard in self.shards:
+ for shard in self.shards.values():
yield from shard.ws.close()
yield from self.http.close()