aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/client.py24
-rw-r--r--discord/gateway.py8
-rw-r--r--discord/shard.py31
-rw-r--r--discord/state.py37
4 files changed, 42 insertions, 58 deletions
diff --git a/discord/client.py b/discord/client.py
index f292c9d2..03845fd1 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -223,13 +223,13 @@ class Client:
'ready': self._handle_ready
}
- self._connection = ConnectionState(dispatch=self.dispatch, chunker=self._chunker, handlers=self._handlers,
+ self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
syncer=self._syncer, http=self.http, loop=self.loop, **options)
self._connection.shard_count = self.shard_count
self._closed = False
self._ready = asyncio.Event()
- self._connection._get_websocket = lambda g: self.ws
+ self._connection._get_websocket = self._get_websocket
if VoiceClient.warn_nacl:
VoiceClient.warn_nacl = False
@@ -237,26 +237,12 @@ class Client:
# internals
+ def _get_websocket(self, guild_id=None, *, shard_id=None):
+ return self.ws
+
async def _syncer(self, guilds):
await self.ws.request_sync(guilds)
- async def _chunker(self, guild):
- try:
- guild_id = guild.id
- except AttributeError:
- guild_id = [s.id for s in guild]
-
- payload = {
- 'op': 8,
- 'd': {
- 'guild_id': guild_id,
- 'query': '',
- 'limit': 0
- }
- }
-
- await self.ws.send_as_json(payload)
-
def _handle_ready(self):
self._ready.set()
diff --git a/discord/gateway.py b/discord/gateway.py
index 9b0f1d81..15368d56 100644
--- a/discord/gateway.py
+++ b/discord/gateway.py
@@ -535,15 +535,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
}
await self.send_as_json(payload)
- async def request_chunks(self, guild_id, query, limit):
+ async def request_chunks(self, guild_id, query, limit, *, nonce=None):
payload = {
'op': self.REQUEST_MEMBERS,
'd': {
- 'guild_id': str(guild_id),
+ 'guild_id': guild_id,
'query': query,
'limit': limit
}
}
+
+ if nonce:
+ payload['d']['nonce'] = nonce
+
await self.send_as_json(payload)
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
diff --git a/discord/shard.py b/discord/shard.py
index e133cd0c..6d599dab 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -126,38 +126,19 @@ class AutoShardedClient(Client):
elif not isinstance(self.shard_ids, (list, tuple)):
raise ClientException('shard_ids parameter must be a list or a tuple.')
- self._connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self._chunker,
+ self._connection = AutoShardedConnectionState(dispatch=self.dispatch,
handlers=self._handlers, syncer=self._syncer,
http=self.http, loop=self.loop, **kwargs)
# instead of a single websocket, we have multiple
# the key is the shard_id
self.shards = {}
+ self._connection._get_websocket = self._get_websocket
- def _get_websocket(guild_id):
- i = (guild_id >> 22) % self.shard_count
- return self.shards[i].ws
-
- self._connection._get_websocket = _get_websocket
-
- async def _chunker(self, guild, *, shard_id=None):
- try:
- guild_id = guild.id
- shard_id = shard_id or guild.shard_id
- except AttributeError:
- guild_id = [s.id for s in guild]
-
- payload = {
- 'op': 8,
- 'd': {
- 'guild_id': guild_id,
- 'query': '',
- 'limit': 0
- }
- }
-
- ws = self.shards[shard_id].ws
- await ws.send_as_json(payload)
+ def _get_websocket(self, guild_id=None, *, shard_id=None):
+ if shard_id is None:
+ shard_id = (guild_id >> 22) % self.shard_count
+ return self.shards[shard_id].ws
@property
def latency(self):
diff --git a/discord/state.py b/discord/state.py
index e630724e..0e8764b6 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -35,6 +35,9 @@ import weakref
import inspect
import gc
+import os
+import binascii
+
from .guild import Guild
from .activity import BaseActivity
from .user import User, ClientUser
@@ -62,7 +65,7 @@ log = logging.getLogger(__name__)
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
class ConnectionState:
- def __init__(self, *, dispatch, chunker, handlers, syncer, http, loop, **options):
+ def __init__(self, *, dispatch, handlers, syncer, http, loop, **options):
self.loop = loop
self.http = http
self.max_messages = options.get('max_messages', 1000)
@@ -70,7 +73,6 @@ class ConnectionState:
self.max_messages = 1000
self.dispatch = dispatch
- self.chunker = chunker
self.syncer = syncer
self.is_bot = None
self.handlers = handlers
@@ -132,6 +134,9 @@ class ConnectionState:
# to reconnect loops which cause mass allocations and deallocations.
gc.collect()
+ def get_nonce(self):
+ return binascii.hexlify(os.urandom(16)).decode('ascii')
+
def process_listeners(self, listener_type, argument, result):
removed = []
for i, listener in enumerate(self._listeners):
@@ -298,6 +303,10 @@ class ConnectionState:
return channel or Object(id=channel_id), guild
+ async def chunker(self, guild_id, query='', limit=0, *, nonce=None):
+ ws = self._get_websocket(guild_id) # This is ignored upstream
+ await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
+
async def request_offline_members(self, guilds):
# get all the chunks
chunks = []
@@ -307,7 +316,7 @@ class ConnectionState:
# we only want to request ~75 guilds per chunk request.
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
for split in splits:
- await self.chunker(split)
+ await self.chunker([g.id for g in split])
# wait for the chunks
if chunks:
@@ -329,10 +338,11 @@ class ConnectionState:
# and they don't receive GUILD_MEMBER events which make computing
# member_count impossible. The only way to fix it is by limiting
# the limit parameter to 1 to 1000.
- future = self.receive_member_query(guild_id, query)
+ nonce = self.get_nonce()
+ future = self.receive_member_query(guild_id, nonce)
try:
# start the query operation
- await ws.request_chunks(guild_id, query, limit)
+ await ws.request_chunks(guild_id, query, limit, nonce=nonce)
members = await asyncio.wait_for(future, timeout=5.0)
if cache:
@@ -894,8 +904,7 @@ class ConnectionState:
guild._add_member(member)
self.process_listeners(ListenerType.chunk, guild, len(members))
- names = [x.name.lower() for x in members]
- self.process_listeners(ListenerType.query_members, (guild_id, names), members)
+ self.process_listeners(ListenerType.query_members, (guild_id, data.get('nonce')), members)
def parse_guild_integrations_update(self, data):
guild = self._get_guild(int(data['guild_id']))
@@ -1025,10 +1034,10 @@ class ConnectionState:
self._listeners.append(listener)
return future
- def receive_member_query(self, guild_id, query):
- def predicate(args, *, guild_id=guild_id, query=query.lower()):
- request_guild_id, names = args
- return request_guild_id == guild_id and all(n.startswith(query) for n in names)
+ def receive_member_query(self, guild_id, nonce):
+ def predicate(args, *, guild_id=guild_id, nonce=nonce):
+ return args == (guild_id, nonce)
+
future = self.loop.create_future()
listener = Listener(ListenerType.query_members, future, predicate)
self._listeners.append(listener)
@@ -1040,6 +1049,10 @@ class AutoShardedConnectionState(ConnectionState):
self._ready_task = None
self.shard_ids = ()
+ async def chunker(self, guild_id, query='', limit=0, *, shard_id, nonce=None):
+ ws = self._get_websocket(shard_id=shard_id)
+ await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
+
async def request_offline_members(self, guilds, *, shard_id):
# get all the chunks
chunks = []
@@ -1049,7 +1062,7 @@ class AutoShardedConnectionState(ConnectionState):
# we only want to request ~75 guilds per chunk request.
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
for split in splits:
- await self.chunker(split, shard_id=shard_id)
+ await self.chunker([g.id for g in split], shard_id=shard_id)
# wait for the chunks
if chunks: