diff options
| author | Rapptz <[email protected]> | 2020-04-06 21:34:55 -0400 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2020-07-25 09:59:37 -0400 |
| commit | 09ecb16680fe878b92f621016484614ddd88c0a1 (patch) | |
| tree | 3e2e7307091dc2fa24136351ae7569897d776917 /discord/shard.py | |
| parent | Add revisions to check_once docs (diff) | |
| download | discord.py-09ecb16680fe878b92f621016484614ddd88c0a1.tar.xz discord.py-09ecb16680fe878b92f621016484614ddd88c0a1.zip | |
Rewrite of AutoShardedClient to prevent overlapping identify
This is experimental and I'm unsure if it actually works
Diffstat (limited to 'discord/shard.py')
| -rw-r--r-- | discord/shard.py | 103 |
1 files changed, 53 insertions, 50 deletions
diff --git a/discord/shard.py b/discord/shard.py index 6d599dab..ad564bb2 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -33,61 +33,58 @@ import websockets from .state import AutoShardedConnectionState from .client import Client from .gateway import * -from .errors import ClientException, InvalidArgument +from .errors import ClientException, InvalidArgument, ConnectionClosed from . import utils from .enums import Status log = logging.getLogger(__name__) +class EventType: + Close = 0 + Resume = 1 + Identify = 2 + class Shard: def __init__(self, ws, client): self.ws = ws self._client = client self._dispatch = client.dispatch + self._queue = client._queue self.loop = self._client.loop - self._current = self.loop.create_future() - self._current.set_result(None) # we just need an already done future - self._pending = asyncio.Event() - self._pending_task = None + self._task = None @property def id(self): return self.ws.shard_id - def is_pending(self): - return not self._pending.is_set() - - def complete_pending_reads(self): - self._pending.set() - - async def _pending_reads(self): - try: - while self.is_pending(): - await self.poll() - except asyncio.CancelledError: - pass - - def launch_pending_reads(self): - self._pending_task = asyncio.ensure_future(self._pending_reads(), loop=self.loop) - - def wait(self): - return self._pending_task + def launch(self): + self._task = self.loop.create_task(self.worker()) - async def poll(self): - try: - await self.ws.poll_event() - except ResumeWebSocket: - log.info('Got a request to RESUME the websocket at Shard ID %s.', self.id) - coro = DiscordWebSocket.from_client(self._client, resume=True, shard_id=self.id, - session=self.ws.session_id, sequence=self.ws.sequence) - self._dispatch('disconnect') - self.ws = await asyncio.wait_for(coro, timeout=180.0) - - def get_future(self): - if self._current.done(): - self._current = asyncio.ensure_future(self.poll(), loop=self.loop) + async def worker(self): + while True: + try: + await self.ws.poll_event() + except ReconnectWebSocket as e: + etype = EventType.resume if e.resume else EventType.identify + self._queue.put_nowait((etype, self, e)) + break + except ConnectionClosed as e: + self._queue.put_nowait((EventType.close, self, e)) + break + + async def reconnect(self, exc): + if self._task is not None and not self._task.done(): + self._task.cancel() + + log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) + if not exc.resume: + await asyncio.sleep(5.0) - return self._current + coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id, + session=self.ws.session_id, sequence=self.ws.sequence) + self._dispatch('disconnect') + self.ws = await asyncio.wait_for(coro, timeout=180.0) + self.launch() class AutoShardedClient(Client): """A client similar to :class:`Client` except it handles the complications @@ -134,6 +131,7 @@ class AutoShardedClient(Client): # the key is the shard_id self.shards = {} self._connection._get_websocket = self._get_websocket + self._queue = asyncio.PriorityQueue() def _get_websocket(self, guild_id=None, *, shard_id=None): if shard_id is None: @@ -220,8 +218,10 @@ class AutoShardedClient(Client): # keep reading the shard while others connect self.shards[shard_id] = ret = Shard(ws, self) - ret.launch_pending_reads() - await asyncio.sleep(5.0) + ret.launch() + + if len(self.shards) == self.shard_count: + self._connection.shards_launched.set() async def launch_shards(self): if self.shard_count is None: @@ -234,26 +234,29 @@ class AutoShardedClient(Client): shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count) self._connection.shard_ids = shard_ids + last_shard_id = shard_ids[-1] for shard_id in shard_ids: await self.launch_shard(gateway, shard_id) + if shard_id != last_shard_id: + await asyncio.sleep(5.0) - shards_to_wait_for = [] - for shard in self.shards.values(): - shard.complete_pending_reads() - shards_to_wait_for.append(shard.wait()) + # shards_to_wait_for = [] + # for shard in self.shards.values(): + # shard.complete_pending_reads() + # shards_to_wait_for.append(shard.wait()) - # wait for all pending tasks to finish - await utils.sane_wait_for(shards_to_wait_for, timeout=300.0) + # # wait for all pending tasks to finish + # await utils.sane_wait_for(shards_to_wait_for, timeout=300.0) async def _connect(self): await self.launch_shards() while True: - pollers = [shard.get_future() for shard in self.shards.values()] - done, _ = await asyncio.wait(pollers, return_when=asyncio.FIRST_COMPLETED) - for f in done: - # we wanna re-raise to the main Client.connect handler if applicable - f.result() + etype, shard, exc = await self._queue.get() + if etype == EventType.close: + raise exc + elif etype in (EventType.identify, EventType.resume): + await shard.reconnect(exc) async def close(self): """|coro| |