aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2016-03-06 05:24:18 -0500
committerRapptz <[email protected]>2016-03-06 05:24:18 -0500
commit425bd2c0911cd90aa30ab18e79b60078ee1cb0f0 (patch)
treea4cabb90788405054ff37cde7e921a18501eeb5b
parentAdd created_at properties for Server and User. (diff)
downloaddiscord.py-425bd2c0911cd90aa30ab18e79b60078ee1cb0f0.tar.xz
discord.py-425bd2c0911cd90aa30ab18e79b60078ee1cb0f0.zip
Move chunking logic back into ConnectionState.
This allows for a nicer design when dealing with parsers that could end up being coroutines.
-rw-r--r--discord/client.py33
-rw-r--r--discord/state.py24
2 files changed, 26 insertions, 31 deletions
diff --git a/discord/client.py b/discord/client.py
index 9ec9adf0..73037aa8 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -51,7 +51,7 @@ import logging, traceback
import sys, time, re, json
import tempfile, os, hashlib
import itertools
-import zlib, math
+import zlib
from random import randint as random_integer
PY35 = sys.version_info >= (3, 5)
@@ -122,7 +122,7 @@ class Client:
if max_messages is None or max_messages < 100:
max_messages = 5000
- self.connection = ConnectionState(self.dispatch, max_messages, loop=self.loop)
+ self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop)
# Blame Jake for this
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
@@ -145,28 +145,6 @@ class Client:
# internals
- def _get_all_chunks(self):
- # a chunk has a maximum of 1000 members.
- # we need to find out how many futures we're actually waiting for
- large_servers = filter(lambda s: s.large, self.servers)
- futures = []
- for server in large_servers:
- chunks_needed = math.ceil(server._member_count / 1000)
- for chunk in range(chunks_needed):
- futures.append(self.connection.receive_chunk(server.id))
-
- return futures
-
- @asyncio.coroutine
- def _fill_offline(self):
- yield from self.request_offline_members(filter(lambda s: s.large, self.servers))
- chunks = self._get_all_chunks()
-
- if chunks:
- yield from asyncio.wait(chunks)
-
- self.dispatch('ready')
-
def _get_cache_filename(self, email):
filename = hashlib.md5(email.encode('utf-8')).hexdigest()
return os.path.join(tempfile.gettempdir(), 'discord_py', filename)
@@ -392,11 +370,10 @@ class Client:
func = getattr(self.connection, parser)
except AttributeError:
log.info('Unhandled event {}'.format(event))
- else:
- func(data)
- if is_ready:
- utils.create_task(self._fill_offline(), loop=self.loop)
+ result = func(data)
+ if asyncio.iscoroutine(result):
+ utils.create_task(result, loop=self.loop)
@asyncio.coroutine
def _make_websocket(self, initial=True):
diff --git a/discord/state.py b/discord/state.py
index 6db17218..fec2e111 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -36,10 +36,9 @@ from .enums import Status
from collections import deque, namedtuple
-import copy
+import copy, enum, math
import datetime
import asyncio
-import enum
import logging
class ListenerType(enum.Enum):
@@ -49,10 +48,11 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
log = logging.getLogger(__name__)
class ConnectionState:
- def __init__(self, dispatch, max_messages, *, loop):
+ def __init__(self, dispatch, chunker, max_messages, *, loop):
self.loop = loop
self.max_messages = max_messages
self.dispatch = dispatch
+ self.chunker = chunker
self._listeners = []
self.clear()
@@ -128,6 +128,7 @@ class ConnectionState:
self._add_server(server)
return server
+ @asyncio.coroutine
def parse_ready(self, data):
self.user = User(**data['user'])
guilds = data.get('guilds')
@@ -139,6 +140,23 @@ class ConnectionState:
self._add_private_channel(PrivateChannel(id=pm['id'],
user=User(**pm['recipient'])))
+ # a chunk has a maximum of 1000 members.
+ # we need to find out how many futures we're actually waiting for
+
+ large_servers = [s for s in self.servers if s.large]
+ yield from self.chunker(large_servers)
+
+ chunks = []
+ for server in large_servers:
+ chunks_needed = math.ceil(server._member_count / 1000)
+ for chunk in range(chunks_needed):
+ chunks.append(self.receive_chunk(server.id))
+
+ if chunks:
+ yield from asyncio.wait(chunks)
+
+ self.dispatch('ready')
+
def parse_message_create(self, data):
channel = self.get_channel(data.get('channel_id'))
message = Message(channel=channel, **data)