diff options
| author | Rapptz <[email protected]> | 2017-01-07 21:55:47 -0500 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2017-01-07 23:19:39 -0500 |
| commit | 20041ea756305f20c86a621232639932c50f107c (patch) | |
| tree | fc9be7da66b1dffd274d96f85dd1cb7c605e56c2 /discord/state.py | |
| parent | Fix variable shadowing in READY parsing. (diff) | |
| download | discord.py-20041ea756305f20c86a621232639932c50f107c.tar.xz discord.py-20041ea756305f20c86a621232639932c50f107c.zip | |
Implement AutoShardedClient for transparent sharding.
This allows people to run their >2,500 guild bot in a single process
without the headaches of IPC/RPC or much difficulty.
Diffstat (limited to 'discord/state.py')
| -rw-r--r-- | discord/state.py | 80 |
1 files changed, 77 insertions, 3 deletions
diff --git a/discord/state.py b/discord/state.py index 383b559f..bd7fbdbe 100644 --- a/discord/state.py +++ b/discord/state.py @@ -43,6 +43,7 @@ import datetime import asyncio import logging import weakref +import itertools class ListenerType(enum.Enum): chunk = 0 @@ -60,13 +61,12 @@ class ConnectionState: self.chunker = chunker self.syncer = syncer self.is_bot = None + self.shard_count = None self._listeners = [] self.clear() def clear(self): self.user = None - self.sequence = None - self.session_id = None self._users = weakref.WeakValueDictionary() self._calls = {} self._emojis = {} @@ -355,7 +355,8 @@ class ConnectionState: # the reason we're doing this is so it's also removed from the # private channel by user cache as well channel = self._get_private_channel(channel_id) - self._remove_private_channel(channel) + if channel is not None: + self._remove_private_channel(channel) def parse_channel_update(self, data): channel_type = try_enum(ChannelType, data.get('type')) @@ -701,3 +702,76 @@ class ConnectionState: listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id) self._listeners.append(listener) return future + +class AutoShardedConnectionState(ConnectionState): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) + self._ready_task = None + + @asyncio.coroutine + def _delay_ready(self): + launch = self._ready_state.launch + while not launch.is_set(): + # this snippet of code is basically waiting 2 seconds + # until the last GUILD_CREATE was sent + launch.set() + yield from asyncio.sleep(2.0 * self.shard_count, loop=self.loop) + + guilds = sorted(self._ready_state.guilds, key=lambda g: g.shard_id) + + # we only want to request ~75 guilds per chunk request. + # we also want to split the chunks per shard_id + for shard_id, sub_guilds in itertools.groupby(guilds, key=lambda g: g.shard_id): + sub_guilds = list(sub_guilds) + + # split chunks by shard ID + chunks = [] + for guild in sub_guilds: + chunks.extend(self.chunks_needed(guild)) + + splits = [sub_guilds[i:i + 75] for i in range(0, len(sub_guilds), 75)] + for split in splits: + yield from self.chunker(split, shard_id=shard_id) + + # wait for the chunks + if chunks: + try: + yield from asyncio.wait(chunks, timeout=len(chunks) * 30.0, loop=self.loop) + except asyncio.TimeoutError: + log.info('Somehow timed out waiting for chunks for %s shard_id' % shard_id) + + self.dispatch('shard_ready', shard_id) + + # sleep a second for every shard ID. + # yield from asyncio.sleep(1.0, loop=self.loop) + + # remove the state + try: + del self._ready_state + except AttributeError: + pass # already been deleted somehow + + # regular users cannot shard so we won't worry about it here. + + # dispatch the event + self.dispatch('ready') + + def parse_ready(self, data): + if not hasattr(self, '_ready_state'): + self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) + + self.user = self.store_user(data['user']) + + guilds = self._ready_state.guilds + for guild_data in data['guilds']: + guild = self._add_guild_from_data(guild_data) + if not self.is_bot or guild.large: + guilds.append(guild) + + for pm in data.get('private_channels', []): + factory, _ = _channel_factory(pm['type']) + self._add_private_channel(factory(me=self.user, data=pm, state=self)) + + if self._ready_task is None: + self._ready_task = compat.create_task(self._delay_ready(), loop=self.loop) |