aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2020-04-06 21:34:55 -0400
committerRapptz <[email protected]>2020-07-25 09:59:37 -0400
commit09ecb16680fe878b92f621016484614ddd88c0a1 (patch)
tree3e2e7307091dc2fa24136351ae7569897d776917
parentAdd revisions to check_once docs (diff)
downloaddiscord.py-09ecb16680fe878b92f621016484614ddd88c0a1.tar.xz
discord.py-09ecb16680fe878b92f621016484614ddd88c0a1.zip
Rewrite of AutoShardedClient to prevent overlapping identify
This is experimental and I'm unsure if it actually works
-rw-r--r--discord/client.py9
-rw-r--r--discord/gateway.py22
-rw-r--r--discord/shard.py103
-rw-r--r--discord/state.py2
4 files changed, 72 insertions, 64 deletions
diff --git a/discord/client.py b/discord/client.py
index 86232155..0fcdcd48 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -453,11 +453,14 @@ class Client:
while True:
try:
await self.ws.poll_event()
- except ResumeWebSocket:
- log.info('Got a request to RESUME the websocket.')
+ except ReconnectWebSocket as e:
+ log.info('Got a request to %s the websocket.', e.op)
self.dispatch('disconnect')
+ if not e.resume:
+ await asyncio.sleep(5.0)
+
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id,
- sequence=self.ws.sequence, resume=True)
+ sequence=self.ws.sequence, resume=e.resume)
self.ws = await asyncio.wait_for(coro, timeout=180.0)
async def connect(self, *, reconnect=True):
diff --git a/discord/gateway.py b/discord/gateway.py
index 2ddd2a23..c2a432f8 100644
--- a/discord/gateway.py
+++ b/discord/gateway.py
@@ -50,13 +50,15 @@ __all__ = (
'KeepAliveHandler',
'VoiceKeepAliveHandler',
'DiscordVoiceWebSocket',
- 'ResumeWebSocket',
+ 'ReconnectWebSocket',
)
-class ResumeWebSocket(Exception):
- """Signals to initialise via RESUME opcode instead of IDENTIFY."""
- def __init__(self, shard_id):
+class ReconnectWebSocket(Exception):
+ """Signals to safely reconnect the websocket."""
+ def __init__(self, shard_id, *, resume=True):
self.shard_id = shard_id
+ self.resume = resume
+ self.op = 'RESUME' if resume else 'IDENTIFY'
EventListener = namedtuple('EventListener', 'predicate event result future')
@@ -385,7 +387,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# internal exception signalling to reconnect.
log.debug('Received RECONNECT opcode.')
await self.close()
- raise ResumeWebSocket(self.shard_id)
+ raise ReconnectWebSocket(self.shard_id)
if op == self.HEARTBEAT_ACK:
self._keep_alive.ack()
@@ -406,16 +408,14 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
if op == self.INVALIDATE_SESSION:
if data is True:
- await asyncio.sleep(5.0)
await self.close()
- raise ResumeWebSocket(self.shard_id)
+ raise ReconnectWebSocket(self.shard_id)
self.sequence = None
self.session_id = None
log.info('Shard ID %s session has been invalidated.', self.shard_id)
- await asyncio.sleep(5.0)
- await self.identify()
- return
+ await self.close(code=1000)
+ raise ReconnectWebSocket(self.shard_id, resume=False)
log.warning('Unknown OP code %s.', op)
return
@@ -489,7 +489,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
except websockets.exceptions.ConnectionClosed as exc:
if self._can_handle_close(exc.code):
log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason)
- raise ResumeWebSocket(self.shard_id) from exc
+ raise ReconnectWebSocket(self.shard_id) from exc
else:
log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason)
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc
diff --git a/discord/shard.py b/discord/shard.py
index 6d599dab..ad564bb2 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -33,61 +33,58 @@ import websockets
from .state import AutoShardedConnectionState
from .client import Client
from .gateway import *
-from .errors import ClientException, InvalidArgument
+from .errors import ClientException, InvalidArgument, ConnectionClosed
from . import utils
from .enums import Status
log = logging.getLogger(__name__)
+class EventType:
+ Close = 0
+ Resume = 1
+ Identify = 2
+
class Shard:
def __init__(self, ws, client):
self.ws = ws
self._client = client
self._dispatch = client.dispatch
+ self._queue = client._queue
self.loop = self._client.loop
- self._current = self.loop.create_future()
- self._current.set_result(None) # we just need an already done future
- self._pending = asyncio.Event()
- self._pending_task = None
+ self._task = None
@property
def id(self):
return self.ws.shard_id
- def is_pending(self):
- return not self._pending.is_set()
-
- def complete_pending_reads(self):
- self._pending.set()
-
- async def _pending_reads(self):
- try:
- while self.is_pending():
- await self.poll()
- except asyncio.CancelledError:
- pass
-
- def launch_pending_reads(self):
- self._pending_task = asyncio.ensure_future(self._pending_reads(), loop=self.loop)
-
- def wait(self):
- return self._pending_task
+ def launch(self):
+ self._task = self.loop.create_task(self.worker())
- async def poll(self):
- try:
- await self.ws.poll_event()
- except ResumeWebSocket:
- log.info('Got a request to RESUME the websocket at Shard ID %s.', self.id)
- coro = DiscordWebSocket.from_client(self._client, resume=True, shard_id=self.id,
- session=self.ws.session_id, sequence=self.ws.sequence)
- self._dispatch('disconnect')
- self.ws = await asyncio.wait_for(coro, timeout=180.0)
-
- def get_future(self):
- if self._current.done():
- self._current = asyncio.ensure_future(self.poll(), loop=self.loop)
+ async def worker(self):
+ while True:
+ try:
+ await self.ws.poll_event()
+ except ReconnectWebSocket as e:
+ etype = EventType.resume if e.resume else EventType.identify
+ self._queue.put_nowait((etype, self, e))
+ break
+ except ConnectionClosed as e:
+ self._queue.put_nowait((EventType.close, self, e))
+ break
+
+ async def reconnect(self, exc):
+ if self._task is not None and not self._task.done():
+ self._task.cancel()
+
+ log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
+ if not exc.resume:
+ await asyncio.sleep(5.0)
- return self._current
+ coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
+ session=self.ws.session_id, sequence=self.ws.sequence)
+ self._dispatch('disconnect')
+ self.ws = await asyncio.wait_for(coro, timeout=180.0)
+ self.launch()
class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
@@ -134,6 +131,7 @@ class AutoShardedClient(Client):
# the key is the shard_id
self.shards = {}
self._connection._get_websocket = self._get_websocket
+ self._queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id=None, *, shard_id=None):
if shard_id is None:
@@ -220,8 +218,10 @@ class AutoShardedClient(Client):
# keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self)
- ret.launch_pending_reads()
- await asyncio.sleep(5.0)
+ ret.launch()
+
+ if len(self.shards) == self.shard_count:
+ self._connection.shards_launched.set()
async def launch_shards(self):
if self.shard_count is None:
@@ -234,26 +234,29 @@ class AutoShardedClient(Client):
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
self._connection.shard_ids = shard_ids
+ last_shard_id = shard_ids[-1]
for shard_id in shard_ids:
await self.launch_shard(gateway, shard_id)
+ if shard_id != last_shard_id:
+ await asyncio.sleep(5.0)
- shards_to_wait_for = []
- for shard in self.shards.values():
- shard.complete_pending_reads()
- shards_to_wait_for.append(shard.wait())
+ # shards_to_wait_for = []
+ # for shard in self.shards.values():
+ # shard.complete_pending_reads()
+ # shards_to_wait_for.append(shard.wait())
- # wait for all pending tasks to finish
- await utils.sane_wait_for(shards_to_wait_for, timeout=300.0)
+ # # wait for all pending tasks to finish
+ # await utils.sane_wait_for(shards_to_wait_for, timeout=300.0)
async def _connect(self):
await self.launch_shards()
while True:
- pollers = [shard.get_future() for shard in self.shards.values()]
- done, _ = await asyncio.wait(pollers, return_when=asyncio.FIRST_COMPLETED)
- for f in done:
- # we wanna re-raise to the main Client.connect handler if applicable
- f.result()
+ etype, shard, exc = await self._queue.get()
+ if etype == EventType.close:
+ raise exc
+ elif etype in (EventType.identify, EventType.resume):
+ await shard.reconnect(exc)
async def close(self):
"""|coro|
diff --git a/discord/state.py b/discord/state.py
index f84d85ba..6148889d 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -1047,6 +1047,7 @@ class AutoShardedConnectionState(ConnectionState):
super().__init__(*args, **kwargs)
self._ready_task = None
self.shard_ids = ()
+ self.shards_launched = asyncio.Event()
async def chunker(self, guild_id, query='', limit=0, *, shard_id=None, nonce=None):
ws = self._get_websocket(guild_id, shard_id=shard_id)
@@ -1073,6 +1074,7 @@ class AutoShardedConnectionState(ConnectionState):
log.info('Finished requesting guild member chunks for %d guilds.', len(guilds))
async def _delay_ready(self):
+ await self.shards_launched.wait()
launch = self._ready_state.launch
while True:
# this snippet of code is basically waiting 2 * shard_ids seconds