aboutsummaryrefslogtreecommitdiff
path: root/discord/state.py
diff options
context:
space:
mode:
authorRapptz <[email protected]>2017-01-07 21:55:47 -0500
committerRapptz <[email protected]>2017-01-07 23:19:39 -0500
commit20041ea756305f20c86a621232639932c50f107c (patch)
treefc9be7da66b1dffd274d96f85dd1cb7c605e56c2 /discord/state.py
parentFix variable shadowing in READY parsing. (diff)
downloaddiscord.py-20041ea756305f20c86a621232639932c50f107c.tar.xz
discord.py-20041ea756305f20c86a621232639932c50f107c.zip
Implement AutoShardedClient for transparent sharding.
This allows people to run their >2,500 guild bot in a single process without the headaches of IPC/RPC or much difficulty.
Diffstat (limited to 'discord/state.py')
-rw-r--r--discord/state.py80
1 files changed, 77 insertions, 3 deletions
diff --git a/discord/state.py b/discord/state.py
index 383b559f..bd7fbdbe 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -43,6 +43,7 @@ import datetime
import asyncio
import logging
import weakref
+import itertools
class ListenerType(enum.Enum):
chunk = 0
@@ -60,13 +61,12 @@ class ConnectionState:
self.chunker = chunker
self.syncer = syncer
self.is_bot = None
+ self.shard_count = None
self._listeners = []
self.clear()
def clear(self):
self.user = None
- self.sequence = None
- self.session_id = None
self._users = weakref.WeakValueDictionary()
self._calls = {}
self._emojis = {}
@@ -355,7 +355,8 @@ class ConnectionState:
# the reason we're doing this is so it's also removed from the
# private channel by user cache as well
channel = self._get_private_channel(channel_id)
- self._remove_private_channel(channel)
+ if channel is not None:
+ self._remove_private_channel(channel)
def parse_channel_update(self, data):
channel_type = try_enum(ChannelType, data.get('type'))
@@ -701,3 +702,76 @@ class ConnectionState:
listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id)
self._listeners.append(listener)
return future
+
+class AutoShardedConnectionState(ConnectionState):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
+ self._ready_task = None
+
+ @asyncio.coroutine
+ def _delay_ready(self):
+ launch = self._ready_state.launch
+ while not launch.is_set():
+ # this snippet of code is basically waiting 2 seconds
+ # until the last GUILD_CREATE was sent
+ launch.set()
+ yield from asyncio.sleep(2.0 * self.shard_count, loop=self.loop)
+
+ guilds = sorted(self._ready_state.guilds, key=lambda g: g.shard_id)
+
+ # we only want to request ~75 guilds per chunk request.
+ # we also want to split the chunks per shard_id
+ for shard_id, sub_guilds in itertools.groupby(guilds, key=lambda g: g.shard_id):
+ sub_guilds = list(sub_guilds)
+
+ # split chunks by shard ID
+ chunks = []
+ for guild in sub_guilds:
+ chunks.extend(self.chunks_needed(guild))
+
+ splits = [sub_guilds[i:i + 75] for i in range(0, len(sub_guilds), 75)]
+ for split in splits:
+ yield from self.chunker(split, shard_id=shard_id)
+
+ # wait for the chunks
+ if chunks:
+ try:
+ yield from asyncio.wait(chunks, timeout=len(chunks) * 30.0, loop=self.loop)
+ except asyncio.TimeoutError:
+ log.info('Somehow timed out waiting for chunks for %s shard_id' % shard_id)
+
+ self.dispatch('shard_ready', shard_id)
+
+ # sleep a second for every shard ID.
+ # yield from asyncio.sleep(1.0, loop=self.loop)
+
+ # remove the state
+ try:
+ del self._ready_state
+ except AttributeError:
+ pass # already been deleted somehow
+
+ # regular users cannot shard so we won't worry about it here.
+
+ # dispatch the event
+ self.dispatch('ready')
+
+ def parse_ready(self, data):
+ if not hasattr(self, '_ready_state'):
+ self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
+
+ self.user = self.store_user(data['user'])
+
+ guilds = self._ready_state.guilds
+ for guild_data in data['guilds']:
+ guild = self._add_guild_from_data(guild_data)
+ if not self.is_bot or guild.large:
+ guilds.append(guild)
+
+ for pm in data.get('private_channels', []):
+ factory, _ = _channel_factory(pm['type'])
+ self._add_private_channel(factory(me=self.user, data=pm, state=self))
+
+ if self._ready_task is None:
+ self._ready_task = compat.create_task(self._delay_ready(), loop=self.loop)