aboutsummaryrefslogtreecommitdiff
path: root/discord/gateway.py
diff options
context:
space:
mode:
authorRapptz <[email protected]>2017-01-07 21:55:47 -0500
committerRapptz <[email protected]>2017-01-07 23:19:39 -0500
commit20041ea756305f20c86a621232639932c50f107c (patch)
treefc9be7da66b1dffd274d96f85dd1cb7c605e56c2 /discord/gateway.py
parentFix variable shadowing in READY parsing. (diff)
downloaddiscord.py-20041ea756305f20c86a621232639932c50f107c.tar.xz
discord.py-20041ea756305f20c86a621232639932c50f107c.zip
Implement AutoShardedClient for transparent sharding.
This allows people to run their >2,500 guild bot in a single process without the headaches of IPC/RPC or much difficulty.
Diffstat (limited to 'discord/gateway.py')
-rw-r--r--discord/gateway.py80
1 files changed, 57 insertions, 23 deletions
diff --git a/discord/gateway.py b/discord/gateway.py
index 2154cc98..fcba2dfc 100644
--- a/discord/gateway.py
+++ b/discord/gateway.py
@@ -47,11 +47,13 @@ __all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket',
class ReconnectWebSocket(Exception):
"""Signals to handle the RECONNECT opcode."""
- pass
+ def __init__(self, shard_id):
+ self.shard_id = shard_id
class ResumeWebSocket(Exception):
"""Signals to initialise via RESUME opcode instead of IDENTIFY."""
- pass
+ def __init__(self, shard_id):
+ self.shard_id = shard_id
EventListener = namedtuple('EventListener', 'predicate event result future')
@@ -81,7 +83,7 @@ class KeepAliveHandler(threading.Thread):
def get_payload(self):
return {
'op': self.ws.HEARTBEAT,
- 'd': self.ws._connection.sequence
+ 'd': self.ws.sequence
}
def stop(self):
@@ -165,9 +167,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# the keep alive
self._keep_alive = None
+ # ws related stuff
+ self.session_id = None
+ self.sequence = None
+
@classmethod
@asyncio.coroutine
- def from_client(cls, client, *, resume=False):
+ def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@@ -180,8 +186,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
ws._connection = client.connection
ws._dispatch = client.dispatch
ws.gateway = gateway
- ws.shard_id = client.shard_id
- ws.shard_count = client.shard_count
+ ws.shard_id = shard_id
+ ws.shard_count = client.connection.shard_count
+ ws.session_id = session
+ ws.sequence = sequence
client.connection._update_references(ws)
@@ -206,6 +214,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
else:
return ws
+ @classmethod
+ @asyncio.coroutine
+ def from_sharded_client(cls, client):
+ if client.shard_count is None:
+ client.shard_count, gateway = yield from client.http.get_bot_gateway()
+ else:
+ gateway = yield from client.http.get_gateway()
+
+ ret = []
+ client.connection.shard_count = client.shard_count
+
+ for shard_id in range(client.shard_count):
+ ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
+ ws.token = client.http.token
+ ws._connection = client.connection
+ ws._dispatch = client.dispatch
+ ws.gateway = gateway
+ ws.shard_id = shard_id
+ ws.shard_count = client.shard_count
+
+ # OP HELLO
+ yield from ws.poll_event()
+ yield from ws.identify()
+ ret.append(ws)
+ log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
+ yield from asyncio.sleep(5.0, loop=client.loop)
+
+ return ret
+
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.
@@ -262,12 +299,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
@asyncio.coroutine
def resume(self):
"""Sends the RESUME packet."""
- state = self._connection
payload = {
'op': self.RESUME,
'd': {
- 'seq': state.sequence,
- 'session_id': state.session_id,
+ 'seq': self.sequence,
+ 'session_id': self.session_id,
'token': self.token
}
}
@@ -283,16 +319,15 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
msg = msg.decode('utf-8')
msg = json.loads(msg)
- state = self._connection
- log.debug('WebSocket Event: {}'.format(msg))
+ log.debug('For Shard ID {}: WebSocket Event: {}'.format(self.shard_id, msg))
self._dispatch('socket_response', msg)
op = msg.get('op')
data = msg.get('d')
seq = msg.get('s')
if seq is not None:
- state.sequence = seq
+ self.sequence = seq
if op == self.RECONNECT:
# "reconnect" can only be handled by the Client
@@ -300,7 +335,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# internal exception signalling to reconnect.
log.info('Received RECONNECT opcode.')
yield from self.close()
- raise ReconnectWebSocket()
+ raise ReconnectWebSocket(self.shard_id)
if op == self.HEARTBEAT_ACK:
return # disable noisy logging for now
@@ -317,11 +352,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
return
if op == self.INVALIDATE_SESSION:
- state.sequence = None
- state.session_id = None
+ self.sequence = None
+ self.session_id = None
if data == True:
yield from self.close()
- raise ResumeWebSocket()
+ raise ResumeWebSocket(self.shard_id)
yield from self.identify()
return
@@ -334,9 +369,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
is_ready = event == 'READY'
if is_ready:
- state.clear()
- state.sequence = msg['s']
- state.session_id = data['session_id']
+ self.sequence = msg['s']
+ self.session_id = data['session_id']
parser = 'parse_' + event.lower()
@@ -389,9 +423,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
except websockets.exceptions.ConnectionClosed as e:
if self._can_handle_close(e.code):
log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e))
- raise ResumeWebSocket() from e
+ raise ResumeWebSocket(self.shard_id) from e
else:
- raise ConnectionClosed(e) from e
+ raise ConnectionClosed(e, shard_id=self.shard_id) from e
@asyncio.coroutine
def send(self, data):
@@ -404,7 +438,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
yield from super().send(utils.to_json(data))
except websockets.exceptions.ConnectionClosed as e:
if not self._can_handle_close(e.code):
- raise ConnectionClosed(e) from e
+ raise ConnectionClosed(e, shard_id=self.shard_id) from e
@asyncio.coroutine
def change_presence(self, *, game=None, status=None, afk=False, since=0.0, idle=None):
@@ -615,7 +649,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
msg = yield from asyncio.wait_for(self.recv(), timeout=30.0, loop=self.loop)
yield from self.received_message(json.loads(msg))
except websockets.exceptions.ConnectionClosed as e:
- raise ConnectionClosed(e) from e
+ raise ConnectionClosed(e, shard_id=None) from e
@asyncio.coroutine
def close_connection(self, force=False):