aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2017-01-07 21:55:47 -0500
committerRapptz <[email protected]>2017-01-07 23:19:39 -0500
commit20041ea756305f20c86a621232639932c50f107c (patch)
treefc9be7da66b1dffd274d96f85dd1cb7c605e56c2
parentFix variable shadowing in READY parsing. (diff)
downloaddiscord.py-20041ea756305f20c86a621232639932c50f107c.tar.xz
discord.py-20041ea756305f20c86a621232639932c50f107c.zip
Implement AutoShardedClient for transparent sharding.
This allows people to run their >2,500 guild bot in a single process without the headaches of IPC/RPC or much difficulty.
-rw-r--r--discord/__init__.py1
-rw-r--r--discord/client.py8
-rw-r--r--discord/errors.py9
-rw-r--r--discord/gateway.py80
-rw-r--r--discord/guild.py8
-rw-r--r--discord/http.py9
-rw-r--r--discord/shard.py174
-rw-r--r--discord/state.py80
-rw-r--r--docs/api.rst3
9 files changed, 341 insertions, 31 deletions
diff --git a/discord/__init__.py b/discord/__init__.py
index e8a29e45..4d04aeb8 100644
--- a/discord/__init__.py
+++ b/discord/__init__.py
@@ -37,6 +37,7 @@ from . import utils, opus, compat, abc
from .enums import ChannelType, GuildRegion, Status, MessageType, VerificationLevel
from collections import namedtuple
from .embeds import Embed
+from .shard import AutoShardedClient
import logging
diff --git a/discord/client.py b/discord/client.py
index 2e0696c9..f8f45870 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -142,6 +142,7 @@ class Client:
self.connection = ConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
syncer=self._syncer, http=self.http, loop=self.loop, **options)
+ self.connection.shard_count = self.shard_count
self._closed = asyncio.Event(loop=self.loop)
self._is_logged_in = asyncio.Event(loop=self.loop)
self._is_ready = asyncio.Event(loop=self.loop)
@@ -405,11 +406,14 @@ class Client:
while not self.is_closed:
try:
- yield from self.ws.poll_event()
+ yield from ws.poll_event()
except (ReconnectWebSocket, ResumeWebSocket) as e:
resume = type(e) is ResumeWebSocket
log.info('Got ' + type(e).__name__)
- self.ws = yield from DiscordWebSocket.from_client(self, resume=resume)
+ self.ws = yield from DiscordWebSocket.from_client(self, shard_id=self.shard_id,
+ session=self.ws.session_id,
+ sequence=self.ws.sequence,
+ resume=resume)
except ConnectionClosed as e:
yield from self.close()
if e.code != 1000:
diff --git a/discord/errors.py b/discord/errors.py
index 5449b77e..46751b62 100644
--- a/discord/errors.py
+++ b/discord/errors.py
@@ -118,14 +118,17 @@ class ConnectionClosed(ClientException):
Attributes
-----------
- code : int
+ code: int
The close code of the websocket.
- reason : str
+ reason: str
The reason provided for the closure.
+ shard_id: Optional[int]
+ The shard ID that got closed if applicable.
"""
- def __init__(self, original):
+ def __init__(self, original, *, shard_id):
# This exception is just the same exception except
# reconfigured to subclass ClientException for users
self.code = original.code
self.reason = original.reason
+ self.shard_id = shard_id
super().__init__(str(original))
diff --git a/discord/gateway.py b/discord/gateway.py
index 2154cc98..fcba2dfc 100644
--- a/discord/gateway.py
+++ b/discord/gateway.py
@@ -47,11 +47,13 @@ __all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket',
class ReconnectWebSocket(Exception):
"""Signals to handle the RECONNECT opcode."""
- pass
+ def __init__(self, shard_id):
+ self.shard_id = shard_id
class ResumeWebSocket(Exception):
"""Signals to initialise via RESUME opcode instead of IDENTIFY."""
- pass
+ def __init__(self, shard_id):
+ self.shard_id = shard_id
EventListener = namedtuple('EventListener', 'predicate event result future')
@@ -81,7 +83,7 @@ class KeepAliveHandler(threading.Thread):
def get_payload(self):
return {
'op': self.ws.HEARTBEAT,
- 'd': self.ws._connection.sequence
+ 'd': self.ws.sequence
}
def stop(self):
@@ -165,9 +167,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# the keep alive
self._keep_alive = None
+ # ws related stuff
+ self.session_id = None
+ self.sequence = None
+
@classmethod
@asyncio.coroutine
- def from_client(cls, client, *, resume=False):
+ def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@@ -180,8 +186,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
ws._connection = client.connection
ws._dispatch = client.dispatch
ws.gateway = gateway
- ws.shard_id = client.shard_id
- ws.shard_count = client.shard_count
+ ws.shard_id = shard_id
+ ws.shard_count = client.connection.shard_count
+ ws.session_id = session
+ ws.sequence = sequence
client.connection._update_references(ws)
@@ -206,6 +214,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
else:
return ws
+ @classmethod
+ @asyncio.coroutine
+ def from_sharded_client(cls, client):
+ """Creates and identifies one websocket per shard for ``client``.
+
+ This is a coroutine. When ``client.shard_count`` is ``None`` both the
+ shard count and the gateway URL come from
+ ``client.http.get_bot_gateway()`` and the count is stored back on the
+ client; otherwise only the gateway URL is fetched through
+ ``client.http.get_gateway()``.
+
+ Returns a list of websockets, appended in ascending ``shard_id`` order
+ starting at 0, so the list index equals the shard ID.
+ """
+ if client.shard_count is None:
+ client.shard_count, gateway = yield from client.http.get_bot_gateway()
+ else:
+ gateway = yield from client.http.get_gateway()
+
+ ret = []
+ # the shared connection state must know the final shard count too
+ client.connection.shard_count = client.shard_count
+
+ for shard_id in range(client.shard_count):
+ ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
+ # every shard shares the client's token, state and dispatcher
+ ws.token = client.http.token
+ ws._connection = client.connection
+ ws._dispatch = client.dispatch
+ ws.gateway = gateway
+ ws.shard_id = shard_id
+ ws.shard_count = client.shard_count
+
+ # OP HELLO
+ yield from ws.poll_event()
+ yield from ws.identify()
+ ret.append(ws)
+ log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
+ # NOTE(review): presumably spaces out IDENTIFY payloads to respect a
+ # gateway rate limit -- confirm. The sleep also runs after the last shard.
+ yield from asyncio.sleep(5.0, loop=client.loop)
+
+ return ret
+
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.
@@ -262,12 +299,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
@asyncio.coroutine
def resume(self):
"""Sends the RESUME packet."""
- state = self._connection
payload = {
'op': self.RESUME,
'd': {
- 'seq': state.sequence,
- 'session_id': state.session_id,
+ 'seq': self.sequence,
+ 'session_id': self.session_id,
'token': self.token
}
}
@@ -283,16 +319,15 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
msg = msg.decode('utf-8')
msg = json.loads(msg)
- state = self._connection
- log.debug('WebSocket Event: {}'.format(msg))
+ log.debug('For Shard ID {}: WebSocket Event: {}'.format(self.shard_id, msg))
self._dispatch('socket_response', msg)
op = msg.get('op')
data = msg.get('d')
seq = msg.get('s')
if seq is not None:
- state.sequence = seq
+ self.sequence = seq
if op == self.RECONNECT:
# "reconnect" can only be handled by the Client
@@ -300,7 +335,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# internal exception signalling to reconnect.
log.info('Received RECONNECT opcode.')
yield from self.close()
- raise ReconnectWebSocket()
+ raise ReconnectWebSocket(self.shard_id)
if op == self.HEARTBEAT_ACK:
return # disable noisy logging for now
@@ -317,11 +352,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
return
if op == self.INVALIDATE_SESSION:
- state.sequence = None
- state.session_id = None
+ self.sequence = None
+ self.session_id = None
if data == True:
yield from self.close()
- raise ResumeWebSocket()
+ raise ResumeWebSocket(self.shard_id)
yield from self.identify()
return
@@ -334,9 +369,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
is_ready = event == 'READY'
if is_ready:
- state.clear()
- state.sequence = msg['s']
- state.session_id = data['session_id']
+ self.sequence = msg['s']
+ self.session_id = data['session_id']
parser = 'parse_' + event.lower()
@@ -389,9 +423,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
except websockets.exceptions.ConnectionClosed as e:
if self._can_handle_close(e.code):
log.info('Websocket closed with {0.code} ({0.reason}), attempting a reconnect.'.format(e))
- raise ResumeWebSocket() from e
+ raise ResumeWebSocket(self.shard_id) from e
else:
- raise ConnectionClosed(e) from e
+ raise ConnectionClosed(e, shard_id=self.shard_id) from e
@asyncio.coroutine
def send(self, data):
@@ -404,7 +438,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
yield from super().send(utils.to_json(data))
except websockets.exceptions.ConnectionClosed as e:
if not self._can_handle_close(e.code):
- raise ConnectionClosed(e) from e
+ raise ConnectionClosed(e, shard_id=self.shard_id) from e
@asyncio.coroutine
def change_presence(self, *, game=None, status=None, afk=False, since=0.0, idle=None):
@@ -615,7 +649,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
msg = yield from asyncio.wait_for(self.recv(), timeout=30.0, loop=self.loop)
yield from self.received_message(json.loads(msg))
except websockets.exceptions.ConnectionClosed as e:
- raise ConnectionClosed(e) from e
+ raise ConnectionClosed(e, shard_id=None) from e
@asyncio.coroutine
def close_connection(self, force=False):
diff --git a/discord/guild.py b/discord/guild.py
index 0f37a214..2255c297 100644
--- a/discord/guild.py
+++ b/discord/guild.py
@@ -325,6 +325,14 @@ class Guild(Hashable):
return self._member_count
@property
+ def shard_id(self):
+ """Returns the shard ID for this guild if applicable.
+
+ This is ``None`` when the connection state has no shard count.
+ Otherwise it is computed as ``(guild id >> 22) % shard count``.
+ """
+ count = self._state.shard_count
+ if count is None:
+ return None
+ return (self.id >> 22) % count
+
+ @property
def created_at(self):
"""Returns the guild's creation time in UTC."""
return utils.snowflake_time(self.id)
diff --git a/discord/http.py b/discord/http.py
index 2b885dec..4e5410d0 100644
--- a/discord/http.py
+++ b/discord/http.py
@@ -588,5 +588,14 @@ class HTTPClient:
raise GatewayNotFound() from e
return data.get('url') + '?encoding=json&v=6'
+ @asyncio.coroutine
+ def get_bot_gateway(self):
+ """Fetches the gateway information from ``self.GATEWAY + '/bot'``.
+
+ This is a coroutine. Returns a 2-tuple made of the ``shards`` value of
+ the response and the ``url`` value with ``?encoding=json&v=6`` appended.
+
+ Raises GatewayNotFound when the request fails with an HTTPException.
+ """
+ try:
+ data = yield from self.get(self.GATEWAY + '/bot', bucket=_func_())
+ except HTTPException as e:
+ raise GatewayNotFound() from e
+ else:
+ return data['shards'], data['url'] + '?encoding=json&v=6'
+
def get_user_info(self, user_id):
return self.get('{0.USERS}/{1}'.format(self, user_id), bucket=_func_())
diff --git a/discord/shard.py b/discord/shard.py
new file mode 100644
index 00000000..2be0ea12
--- /dev/null
+++ b/discord/shard.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+
+"""
+The MIT License (MIT)
+
+Copyright (c) 2015-2016 Rapptz
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+
+from .state import AutoShardedConnectionState
+from .client import Client
+from .gateway import *
+from .errors import ConnectionClosed
+from . import compat
+
+import asyncio
+import logging
+
+log = logging.getLogger(__name__)
+
+class Shard:
+ """Pairs the websocket of a single shard with the task that polls it.
+
+ ``ws`` is the shard's websocket and ``client`` is the client that owns
+ the shard; the event loop is taken from the client.
+ """
+ def __init__(self, ws, client):
+ self.ws = ws
+ self._client = client
+ self.loop = self._client.loop
+ # get_future() only starts a poll when the current future is done,
+ # so seed it with a future that is already finished
+ self._current = asyncio.Future(loop=self.loop)
+ self._current.set_result(None) # we just need an already done future
+
+ @property
+ def id(self):
+ """Returns the shard ID stored on the underlying websocket."""
+ return self.ws.shard_id
+
+ @asyncio.coroutine
+ def poll(self):
+ """Polls a single event from this shard's websocket.
+
+ On ReconnectWebSocket or ResumeWebSocket the websocket is replaced by
+ a new one created for the same shard ID, carrying over the old
+ session ID and sequence. On ConnectionClosed the whole client is
+ closed and the exception is re-raised unless the close code is 1000.
+ """
+ try:
+ yield from self.ws.poll_event()
+ except (ReconnectWebSocket, ResumeWebSocket) as e:
+ # only ResumeWebSocket asks for a resume
+ resume = type(e) is ResumeWebSocket
+ log.info('Got ' + type(e).__name__)
+ self.ws = yield from DiscordWebSocket.from_client(self._client, resume=resume,
+ shard_id=self.id,
+ session=self.ws.session_id,
+ sequence=self.ws.sequence)
+ except ConnectionClosed as e:
+ yield from self._client.close()
+ if e.code != 1000:
+ raise
+
+ def get_future(self):
+ """Returns the task polling this shard.
+
+ A new task running :meth:`poll` is created when the previous one has
+ finished; otherwise the task still in flight is returned.
+ """
+ if self._current.done():
+ self._current = compat.create_task(self.poll(), loop=self.loop)
+
+ return self._current
+
+class AutoShardedClient(Client):
+    """A client similar to :class:`Client` except it handles the complications
+    of sharding for the user into a more manageable and transparent single
+    process bot.
+
+    When using this client, you will be able to use it as-if it was a regular
+    :class:`Client` with a single shard when implementation wise internally it
+    is split up into multiple shards. This allows you to not have to deal with
+    IPC or other complicated infrastructure.
+
+    It is recommended to use this client only if you have surpassed at least
+    1000 guilds.
+
+    If no :attr:`shard_count` is provided, then the library will use the
+    Bot Gateway endpoint call to figure out how many shards to use.
+    """
+    def __init__(self, *args, loop=None, **kwargs):
+        # this client drives every shard itself, so a single shard_id
+        # passed by the user is discarded
+        kwargs.pop('shard_id', None)
+        super().__init__(*args, loop=loop, **kwargs)
+
+        # replace the state created by Client with the shard-aware variant
+        self.connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self.request_offline_members,
+                                                     syncer=self._syncer, http=self.http, loop=self.loop, **kwargs)
+
+        # instead of a single websocket, we have multiple
+        # the index is the shard_id
+        self.shards = []
+
+    @asyncio.coroutine
+    def request_offline_members(self, guild, *, shard_id=None):
+        """|coro|
+
+        Requests previously offline members from the guild to be filled up
+        into the :attr:`Guild.members` cache. This function is usually not
+        called.
+
+        When the client logs on and connects to the websocket, Discord does
+        not provide the library with offline members if the number of members
+        in the guild is larger than 250. You can check if a guild is large
+        if :attr:`Guild.large` is ``True``.
+
+        Parameters
+        -----------
+        guild: :class:`Guild` or list
+            The guild to request offline members for. If this parameter is a
+            list then it is interpreted as a list of guilds to request offline
+            members for.
+        shard_id: Optional[int]
+            The shard to send the request through. When omitted, the shard
+            is taken from :attr:`Guild.shard_id`; a list of guilds is then
+            split up so that every guild is requested on its own shard.
+        """
+
+        # maps the shard to send on -> the ``guild_id`` value of its payload
+        requests = {}
+
+        try:
+            guild_id = guild.id
+        except AttributeError:
+            # not a single guild, so treat it as an iterable of guilds
+            guilds = list(guild)
+            if shard_id is not None:
+                requests[shard_id] = [g.id for g in guilds]
+            else:
+                for g in guilds:
+                    requests.setdefault(g.shard_id, []).append(g.id)
+        else:
+            # compare against None so an explicit shard 0 is respected
+            if shard_id is None:
+                shard_id = guild.shard_id
+            requests[shard_id] = guild_id
+
+        for target, guild_id in requests.items():
+            payload = {
+                'op': 8,
+                'd': {
+                    'guild_id': guild_id,
+                    'query': '',
+                    'limit': 0
+                }
+            }
+
+            ws = self.shards[target].ws
+            yield from ws.send_as_json(payload)
+
+    @asyncio.coroutine
+    def connect(self):
+        """|coro|
+
+        Creates a websocket connection and lets the websocket listen
+        to messages from discord.
+
+        Raises
+        -------
+        GatewayNotFound
+            If the gateway to connect to discord is not found. Usually if this
+            is thrown then there is a discord API outage.
+        ConnectionClosed
+            The websocket connection has been terminated.
+        """
+        ret = yield from DiscordWebSocket.from_sharded_client(self)
+        self.shards = [Shard(ws, self) for ws in ret]
+
+        while not self.is_closed:
+            pollers = [shard.get_future() for shard in self.shards]
+            done, _ = yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                # asyncio.wait never raises on behalf of the tasks it waits
+                # for, so ask each finished poll for its result to re-raise
+                # whatever the shard raised (e.g. ConnectionClosed)
+                task.result()
+
+    @asyncio.coroutine
+    def close(self):
+        """|coro|
+
+        Closes the connection to discord.
+
+        Every shard websocket is closed, followed by the HTTP client.
+        Calling this on an already closed client does nothing.
+        """
+        if self.is_closed:
+            return
+
+        for shard in self.shards:
+            yield from shard.ws.close()
+
+        yield from self.http.close()
+        self._closed.set()
+        self._is_ready.clear()
diff --git a/discord/state.py b/discord/state.py
index 383b559f..bd7fbdbe 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -43,6 +43,7 @@ import datetime
import asyncio
import logging
import weakref
+import itertools
class ListenerType(enum.Enum):
chunk = 0
@@ -60,13 +61,12 @@ class ConnectionState:
self.chunker = chunker
self.syncer = syncer
self.is_bot = None
+ self.shard_count = None
self._listeners = []
self.clear()
def clear(self):
self.user = None
- self.sequence = None
- self.session_id = None
self._users = weakref.WeakValueDictionary()
self._calls = {}
self._emojis = {}
@@ -355,7 +355,8 @@ class ConnectionState:
# the reason we're doing this is so it's also removed from the
# private channel by user cache as well
channel = self._get_private_channel(channel_id)
- self._remove_private_channel(channel)
+ if channel is not None:
+ self._remove_private_channel(channel)
def parse_channel_update(self, data):
channel_type = try_enum(ChannelType, data.get('type'))
@@ -701,3 +702,76 @@ class ConnectionState:
listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id)
self._listeners.append(listener)
return future
+
+class AutoShardedConnectionState(ConnectionState):
+    """A :class:`ConnectionState` shared by every shard of a sharded client.
+
+    The guilds received in the READY payloads are collected in a single
+    ready state and a single ``ready`` event is dispatched once they have
+    been processed, preceded by one ``shard_ready`` event per shard ID that
+    has collected guilds.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
+        # the task running _delay_ready, or None when no ready is in flight
+        self._ready_task = None
+
+    @asyncio.coroutine
+    def _delay_ready(self):
+        """Waits for the guild stream to settle, chunks the collected guilds
+        per shard and then dispatches ``shard_ready`` and ``ready``.
+        """
+        launch = self._ready_state.launch
+        while not launch.is_set():
+            # set the flag and sleep for 2 seconds per shard; the loop only
+            # ends if the flag is still set once the sleep is over.
+            # NOTE(review): presumably the GUILD_CREATE handling clears the
+            # flag again whenever another guild arrives -- confirm.
+            launch.set()
+            yield from asyncio.sleep(2.0 * self.shard_count, loop=self.loop)
+
+        # itertools.groupby only groups adjacent items, hence the sort
+        guilds = sorted(self._ready_state.guilds, key=lambda g: g.shard_id)
+
+        # we only want to request ~75 guilds per chunk request.
+        # we also want to split the chunks per shard_id
+        for shard_id, sub_guilds in itertools.groupby(guilds, key=lambda g: g.shard_id):
+            sub_guilds = list(sub_guilds)
+
+            # split chunks by shard ID
+            chunks = []
+            for guild in sub_guilds:
+                chunks.extend(self.chunks_needed(guild))
+
+            splits = [sub_guilds[i:i + 75] for i in range(0, len(sub_guilds), 75)]
+            for split in splits:
+                yield from self.chunker(split, shard_id=shard_id)
+
+            # wait for the chunks
+            if chunks:
+                # asyncio.wait does not raise on timeout; it hands back the
+                # futures that are still pending instead
+                _, pending = yield from asyncio.wait(chunks, timeout=len(chunks) * 30.0, loop=self.loop)
+                if pending:
+                    log.info('Somehow timed out waiting for chunks for %s shard_id' % shard_id)
+
+            self.dispatch('shard_ready', shard_id)
+
+        # remove the state
+        try:
+            del self._ready_state
+        except AttributeError:
+            pass # already been deleted somehow
+
+        # regular users cannot shard so we won't worry about it here.
+
+        # clear the finished task so that the next READY (for instance after
+        # a shard had to re-identify) schedules a fresh ready dispatch
+        self._ready_task = None
+
+        # dispatch the event
+        self.dispatch('ready')
+
+    def parse_ready(self, data):
+        """Handles the READY payload of a shard.
+
+        Stores the user, the guilds and the private channels found in
+        ``data`` and schedules :meth:`_delay_ready` unless it is already
+        scheduled.
+        """
+        # a previous ready cycle deletes the state when it completes
+        if not hasattr(self, '_ready_state'):
+            self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
+
+        self.user = self.store_user(data['user'])
+
+        guilds = self._ready_state.guilds
+        for guild_data in data['guilds']:
+            guild = self._add_guild_from_data(guild_data)
+            # for bot accounts only large guilds are queued for chunking
+            if not self.is_bot or guild.large:
+                guilds.append(guild)
+
+        for pm in data.get('private_channels', []):
+            factory, _ = _channel_factory(pm['type'])
+            self._add_private_channel(factory(me=self.user, data=pm, state=self))
+
+        if self._ready_task is None:
+            self._ready_task = compat.create_task(self._delay_ready(), loop=self.loop)
diff --git a/docs/api.rst b/docs/api.rst
index bcbf1f47..e001d8bb 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -37,6 +37,9 @@ Client
.. autoclass:: Client
:members:
+.. autoclass:: AutoShardedClient
+ :members:
+
Voice
-----