aboutsummaryrefslogtreecommitdiff
path: root/discord/shard.py
diff options
context:
space:
mode:
authorRapptz <[email protected]>2020-04-06 21:34:55 -0400
committerRapptz <[email protected]>2020-07-25 09:59:37 -0400
commit09ecb16680fe878b92f621016484614ddd88c0a1 (patch)
tree3e2e7307091dc2fa24136351ae7569897d776917 /discord/shard.py
parentAdd revisions to check_once docs (diff)
downloaddiscord.py-09ecb16680fe878b92f621016484614ddd88c0a1.tar.xz
discord.py-09ecb16680fe878b92f621016484614ddd88c0a1.zip
Rewrite of AutoShardedClient to prevent overlapping identify
This is experimental and I'm unsure if it actually works
Diffstat (limited to 'discord/shard.py')
-rw-r--r--discord/shard.py103
1 files changed, 53 insertions, 50 deletions
diff --git a/discord/shard.py b/discord/shard.py
index 6d599dab..ad564bb2 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -33,61 +33,58 @@ import websockets
from .state import AutoShardedConnectionState
from .client import Client
from .gateway import *
-from .errors import ClientException, InvalidArgument
+from .errors import ClientException, InvalidArgument, ConnectionClosed
from . import utils
from .enums import Status
log = logging.getLogger(__name__)
+class EventType:
+ Close = 0
+ Resume = 1
+ Identify = 2
+
class Shard:
def __init__(self, ws, client):
self.ws = ws
self._client = client
self._dispatch = client.dispatch
+ self._queue = client._queue
self.loop = self._client.loop
- self._current = self.loop.create_future()
- self._current.set_result(None) # we just need an already done future
- self._pending = asyncio.Event()
- self._pending_task = None
+ self._task = None
@property
def id(self):
return self.ws.shard_id
- def is_pending(self):
- return not self._pending.is_set()
-
- def complete_pending_reads(self):
- self._pending.set()
-
- async def _pending_reads(self):
- try:
- while self.is_pending():
- await self.poll()
- except asyncio.CancelledError:
- pass
-
- def launch_pending_reads(self):
- self._pending_task = asyncio.ensure_future(self._pending_reads(), loop=self.loop)
-
- def wait(self):
- return self._pending_task
+ def launch(self):
+ self._task = self.loop.create_task(self.worker())
- async def poll(self):
- try:
- await self.ws.poll_event()
- except ResumeWebSocket:
- log.info('Got a request to RESUME the websocket at Shard ID %s.', self.id)
- coro = DiscordWebSocket.from_client(self._client, resume=True, shard_id=self.id,
- session=self.ws.session_id, sequence=self.ws.sequence)
- self._dispatch('disconnect')
- self.ws = await asyncio.wait_for(coro, timeout=180.0)
-
- def get_future(self):
- if self._current.done():
- self._current = asyncio.ensure_future(self.poll(), loop=self.loop)
+ async def worker(self):
+ while True:
+ try:
+ await self.ws.poll_event()
+ except ReconnectWebSocket as e:
+ etype = EventType.resume if e.resume else EventType.identify
+ self._queue.put_nowait((etype, self, e))
+ break
+ except ConnectionClosed as e:
+ self._queue.put_nowait((EventType.close, self, e))
+ break
+
+ async def reconnect(self, exc):
+ if self._task is not None and not self._task.done():
+ self._task.cancel()
+
+ log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
+ if not exc.resume:
+ await asyncio.sleep(5.0)
- return self._current
+ coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
+ session=self.ws.session_id, sequence=self.ws.sequence)
+ self._dispatch('disconnect')
+ self.ws = await asyncio.wait_for(coro, timeout=180.0)
+ self.launch()
class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
@@ -134,6 +131,7 @@ class AutoShardedClient(Client):
# the key is the shard_id
self.shards = {}
self._connection._get_websocket = self._get_websocket
+ self._queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id=None, *, shard_id=None):
if shard_id is None:
@@ -220,8 +218,10 @@ class AutoShardedClient(Client):
# keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self)
- ret.launch_pending_reads()
- await asyncio.sleep(5.0)
+ ret.launch()
+
+ if len(self.shards) == self.shard_count:
+ self._connection.shards_launched.set()
async def launch_shards(self):
if self.shard_count is None:
@@ -234,26 +234,29 @@ class AutoShardedClient(Client):
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
self._connection.shard_ids = shard_ids
+ last_shard_id = shard_ids[-1]
for shard_id in shard_ids:
await self.launch_shard(gateway, shard_id)
+ if shard_id != last_shard_id:
+ await asyncio.sleep(5.0)
- shards_to_wait_for = []
- for shard in self.shards.values():
- shard.complete_pending_reads()
- shards_to_wait_for.append(shard.wait())
+ # shards_to_wait_for = []
+ # for shard in self.shards.values():
+ # shard.complete_pending_reads()
+ # shards_to_wait_for.append(shard.wait())
- # wait for all pending tasks to finish
- await utils.sane_wait_for(shards_to_wait_for, timeout=300.0)
+ # # wait for all pending tasks to finish
+ # await utils.sane_wait_for(shards_to_wait_for, timeout=300.0)
async def _connect(self):
await self.launch_shards()
while True:
- pollers = [shard.get_future() for shard in self.shards.values()]
- done, _ = await asyncio.wait(pollers, return_when=asyncio.FIRST_COMPLETED)
- for f in done:
- # we wanna re-raise to the main Client.connect handler if applicable
- f.result()
+ etype, shard, exc = await self._queue.get()
+ if etype == EventType.close:
+ raise exc
+ elif etype in (EventType.identify, EventType.resume):
+ await shard.reconnect(exc)
async def close(self):
"""|coro|