diff options
Diffstat (limited to 'discord/state.py')
| -rw-r--r-- | discord/state.py | 37 |
1 files changed, 25 insertions, 12 deletions
diff --git a/discord/state.py b/discord/state.py index e630724e..0e8764b6 100644 --- a/discord/state.py +++ b/discord/state.py @@ -35,6 +35,9 @@ import weakref import inspect import gc +import os +import binascii + from .guild import Guild from .activity import BaseActivity from .user import User, ClientUser @@ -62,7 +65,7 @@ log = logging.getLogger(__name__) ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) class ConnectionState: - def __init__(self, *, dispatch, chunker, handlers, syncer, http, loop, **options): + def __init__(self, *, dispatch, handlers, syncer, http, loop, **options): self.loop = loop self.http = http self.max_messages = options.get('max_messages', 1000) @@ -70,7 +73,6 @@ class ConnectionState: self.max_messages = 1000 self.dispatch = dispatch - self.chunker = chunker self.syncer = syncer self.is_bot = None self.handlers = handlers @@ -132,6 +134,9 @@ class ConnectionState: # to reconnect loops which cause mass allocations and deallocations. gc.collect() + def get_nonce(self): + return binascii.hexlify(os.urandom(16)).decode('ascii') + def process_listeners(self, listener_type, argument, result): removed = [] for i, listener in enumerate(self._listeners): @@ -298,6 +303,10 @@ class ConnectionState: return channel or Object(id=channel_id), guild + async def chunker(self, guild_id, query='', limit=0, *, nonce=None): + ws = self._get_websocket(guild_id) # This is ignored upstream + await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) + async def request_offline_members(self, guilds): # get all the chunks chunks = [] @@ -307,7 +316,7 @@ class ConnectionState: # we only want to request ~75 guilds per chunk request. splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] for split in splits: - await self.chunker(split) + await self.chunker([g.id for g in split]) # wait for the chunks if chunks: @@ -329,10 +338,11 @@ class ConnectionState: # and they don't receive GUILD_MEMBER events which make computing # member_count impossible. The only way to fix it is by limiting # the limit parameter to 1 to 1000. - future = self.receive_member_query(guild_id, query) + nonce = self.get_nonce() + future = self.receive_member_query(guild_id, nonce) try: # start the query operation - await ws.request_chunks(guild_id, query, limit) + await ws.request_chunks(guild_id, query, limit, nonce=nonce) members = await asyncio.wait_for(future, timeout=5.0) if cache: @@ -894,8 +904,7 @@ class ConnectionState: guild._add_member(member) self.process_listeners(ListenerType.chunk, guild, len(members)) - names = [x.name.lower() for x in members] - self.process_listeners(ListenerType.query_members, (guild_id, names), members) + self.process_listeners(ListenerType.query_members, (guild_id, data.get('nonce')), members) def parse_guild_integrations_update(self, data): guild = self._get_guild(int(data['guild_id'])) @@ -1025,10 +1034,10 @@ class ConnectionState: self._listeners.append(listener) return future - def receive_member_query(self, guild_id, query): - def predicate(args, *, guild_id=guild_id, query=query.lower()): - request_guild_id, names = args - return request_guild_id == guild_id and all(n.startswith(query) for n in names) + def receive_member_query(self, guild_id, nonce): + def predicate(args, *, guild_id=guild_id, nonce=nonce): + return args == (guild_id, nonce) + future = self.loop.create_future() listener = Listener(ListenerType.query_members, future, predicate) self._listeners.append(listener) @@ -1040,6 +1049,10 @@ class AutoShardedConnectionState(ConnectionState): self._ready_task = None self.shard_ids = () + async def chunker(self, guild_id, query='', limit=0, *, shard_id, nonce=None): + ws = self._get_websocket(shard_id=shard_id) + await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) + async def request_offline_members(self, guilds, *, shard_id): # get all the chunks chunks = [] @@ -1049,7 +1062,7 @@ class AutoShardedConnectionState(ConnectionState): # we only want to request ~75 guilds per chunk request. splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] for split in splits: - await self.chunker(split, shard_id=shard_id) + await self.chunker([g.id for g in split], shard_id=shard_id) # wait for the chunks if chunks: |