diff options
Diffstat (limited to 'discord/state.py')
| -rw-r--r-- | discord/state.py | 60 |
1 files changed, 54 insertions, 6 deletions
diff --git a/discord/state.py b/discord/state.py index c929d8ff..6c41f9a7 100644 --- a/discord/state.py +++ b/discord/state.py @@ -34,14 +34,26 @@ from .role import Role from . import utils from .enums import Status -from collections import deque + +from collections import deque, namedtuple import copy import datetime +import asyncio +import enum +import logging + +class ListenerType(enum.Enum): + chunk = 0 + +Listener = namedtuple('Listener', ('type', 'future', 'predicate')) +log = logging.getLogger(__name__) class ConnectionState: - def __init__(self, dispatch, max_messages): + def __init__(self, dispatch, max_messages, *, loop): + self.loop = loop self.max_messages = max_messages self.dispatch = dispatch + self._listeners = [] self.clear() def clear(self): @@ -52,6 +64,30 @@ class ConnectionState: self._private_channels_by_user = {} self.messages = deque(maxlen=self.max_messages) + def process_listeners(self, listener_type, argument, result): + removed = [] + for i, listener in enumerate(self._listeners): + if listener.type != listener_type: + continue + + future = listener.future + if future.cancelled(): + removed.append(i) + continue + + try: + passed = listener.predicate(argument) + except Exception as e: + future.set_exception(e) + removed.append(i) + else: + if passed: + future.set_result(result) + removed.append(i) + + for index in reversed(removed): + del self._listeners[index] + @property def servers(self): return self._servers.values() @@ -103,9 +139,6 @@ class ConnectionState: self._add_private_channel(PrivateChannel(id=pm['id'], user=User(**pm['recipient']))) - # we're all ready - self.dispatch('ready') - def parse_message_create(self, data): channel = self.get_channel(data.get('channel_id')) message = Message(channel=channel, **data) @@ -213,7 +246,7 @@ class ConnectionState: def parse_guild_member_add(self, data): server = self._get_server(data.get('guild_id')) - self._add_member(server, data) + member = self._add_member(server, data) server._member_count += 1 self.dispatch('member_join', member) @@ -345,6 +378,15 @@ class ConnectionState: role._update(**data['role']) self.dispatch('server_role_update', old_role, role) + def parse_guild_members_chunk(self, data): + server = self._get_server(data.get('guild_id')) + members = data.get('members', []) + for member in members: + self._add_member(server, member) + + log.info('processed a chunk for {} members.'.format(len(members))) + self.process_listeners(ListenerType.chunk, server, len(members)) + def parse_voice_state_update(self, data): server = self._get_server(data.get('guild_id')) if server is not None: @@ -381,3 +423,9 @@ class ConnectionState: pm = self._get_private_channel(id) if pm is not None: return pm + + def receive_chunk(self, guild_id): + future = asyncio.Future(loop=self.loop) + listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id) + self._listeners.append(listener) + return future |