diff options
| author | Rapptz <[email protected]> | 2017-01-07 21:55:47 -0500 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2017-01-07 23:19:39 -0500 |
| commit | 20041ea756305f20c86a621232639932c50f107c (patch) | |
| tree | fc9be7da66b1dffd274d96f85dd1cb7c605e56c2 /discord/gateway.py | |
| parent | Fix variable shadowing in READY parsing. (diff) | |
| download | discord.py-20041ea756305f20c86a621232639932c50f107c.tar.xz discord.py-20041ea756305f20c86a621232639932c50f107c.zip | |
Implement AutoShardedClient for transparent sharding.
This allows people to run their >2,500 guild bot in a single process
without the headaches of IPC/RPC or much difficulty.
Diffstat (limited to 'discord/gateway.py')
| -rw-r--r-- | discord/gateway.py | 80 |
1 files changed, 57 insertions, 23 deletions
diff --git a/discord/gateway.py b/discord/gateway.py index 2154cc98..fcba2dfc 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -47,11 +47,13 @@ __all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket', class ReconnectWebSocket(Exception): """Signals to handle the RECONNECT opcode.""" - pass + def __init__(self, shard_id): + self.shard_id = shard_id class ResumeWebSocket(Exception): """Signals to initialise via RESUME opcode instead of IDENTIFY.""" - pass + def __init__(self, shard_id): + self.shard_id = shard_id EventListener = namedtuple('EventListener', 'predicate event result future') @@ -81,7 +83,7 @@ class KeepAliveHandler(threading.Thread): def get_payload(self): return { 'op': self.ws.HEARTBEAT, - 'd': self.ws._connection.sequence + 'd': self.ws.sequence } def stop(self): @@ -165,9 +167,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # the keep alive self._keep_alive = None + # ws related stuff + self.session_id = None + self.sequence = None + @classmethod @asyncio.coroutine - def from_client(cls, client, *, resume=False): + def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False): """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -180,8 +186,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): ws._connection = client.connection ws._dispatch = client.dispatch ws.gateway = gateway - ws.shard_id = client.shard_id - ws.shard_count = client.shard_count + ws.shard_id = shard_id + ws.shard_count = client.connection.shard_count + ws.session_id = session + ws.sequence = sequence client.connection._update_references(ws) @@ -206,6 +214,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): else: return ws + @classmethod + @asyncio.coroutine + def from_sharded_client(cls, client): + if client.shard_count is None: + client.shard_count, gateway = yield from client.http.get_bot_gateway() + else: + gateway = yield from client.http.get_gateway() + + ret = [] + client.connection.shard_count = client.shard_count + + for shard_id in range(client.shard_count): + ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) + ws.token = client.http.token + ws._connection = client.connection + ws._dispatch = client.dispatch + ws.gateway = gateway + ws.shard_id = shard_id + ws.shard_count = client.shard_count + + # OP HELLO + yield from ws.poll_event() + yield from ws.identify() + ret.append(ws) + log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id) + yield from asyncio.sleep(5.0, loop=client.loop) + + return ret + def wait_for(self, event, predicate, result=None): """Waits for a DISPATCH'd event that meets the predicate. @@ -262,12 +299,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): @asyncio.coroutine def resume(self): """Sends the RESUME packet.""" - state = self._connection payload = { 'op': self.RESUME, 'd': { - 'seq': state.sequence, - 'session_id': state.session_id, + 'seq': self.sequence, + 'session_id': self.session_id, 'token': self.token } } @@ -283,16 +319,15 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): msg = msg.decode('utf-8') msg = json.loads(msg) - state = self._connection - log.debug('WebSocket Event: {}'.format(msg)) + log.debug('For Shard ID {}: WebSocket Event: {}'.format(self.shard_id, msg)) self._dispatch('socket_response', msg) op = msg.get('op') data = msg.get('d') seq = msg.get('s') if seq is not None: - state.sequence = seq + self.sequence = seq if op == self.RECONNECT: # "reconnect" can only be handled by the Client @@ -300,7 +335,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # internal exception signalling to reconnect. log.info('Received RECONNECT opcode.') yield from self.close() - raise ReconnectWebSocket() + raise ReconnectWebSocket(self.shard_id) if op == self.HEARTBEAT_ACK: return # disable noisy logging for now @@ -317,11 +352,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): return if op == self.INVALIDATE_SESSION: - state.sequence = None - state.session_id = None + self.sequence = None + self.session_id = None if data == True: yield from self.close() - raise ResumeWebSocket() + raise ResumeWebSocket(self.shard_id) yield from self.identify() return @@ -334,9 +369,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): is_ready = event == 'READY' if is_ready: - state.clear() - state.sequence = msg['s'] - state.session_id = data['session_id'] + self.sequence = msg['s'] + self.session_id = data['session_id'] parser = 'parse_' + event.lower() @@ -389,9 +423,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): except websockets.exceptions.ConnectionClosed as e: if self._can_handle_close(e.code): log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e)) - raise ResumeWebSocket() from e + raise ResumeWebSocket(self.shard_id) from e else: - raise ConnectionClosed(e) from e + raise ConnectionClosed(e, shard_id=self.shard_id) from e @asyncio.coroutine def send(self, data): @@ -404,7 +438,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): yield from super().send(utils.to_json(data)) except websockets.exceptions.ConnectionClosed as e: if not self._can_handle_close(e.code): - raise ConnectionClosed(e) from e + raise ConnectionClosed(e, shard_id=self.shard_id) from e @asyncio.coroutine def change_presence(self, *, game=None, status=None, afk=False, since=0.0, idle=None): @@ -615,7 +649,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): msg = yield from asyncio.wait_for(self.recv(), timeout=30.0, loop=self.loop) yield from self.received_message(json.loads(msg)) except websockets.exceptions.ConnectionClosed as e: - raise ConnectionClosed(e) from e + raise ConnectionClosed(e, shard_id=None) from e @asyncio.coroutine def close_connection(self, force=False): |