aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/shard.py84
1 files changed, 70 insertions, 14 deletions
diff --git a/discord/shard.py b/discord/shard.py
index f817fb9a..31777816 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -28,10 +28,13 @@ import asyncio
import itertools
import logging
+import aiohttp
+
from .state import AutoShardedConnectionState
from .client import Client
+from .backoff import ExponentialBackoff
from .gateway import *
-from .errors import ClientException, InvalidArgument, ConnectionClosed
+from .errors import ClientException, InvalidArgument, HTTPException, GatewayNotFound, ConnectionClosed
from . import utils
from .enums import Status
@@ -39,8 +42,9 @@ log = logging.getLogger(__name__)
class EventType:
close = 0
- resume = 1
- identify = 2
+ reconnect = 1
+ resume = 2
+ identify = 3
class EventItem:
__slots__ = ('type', 'shard', 'error')
@@ -70,7 +74,18 @@ class Shard:
self._dispatch = client.dispatch
self._queue = client._queue
self.loop = self._client.loop
+ self._disconnect = False
+ self._reconnect = client._reconnect
+ self._backoff = ExponentialBackoff()
self._task = None
+ self._handled_exceptions = (
+ OSError,
+ HTTPException,
+ GatewayNotFound,
+ ConnectionClosed,
+ aiohttp.ClientError,
+ asyncio.TimeoutError,
+ )
@property
def id(self):
@@ -79,6 +94,33 @@ class Shard:
def launch(self):
self._task = self.loop.create_task(self.worker())
+ def _cancel_task(self):
+ if self._task is not None and not self._task.done():
+ self._task.cancel()
+
+ async def close(self):
+ self._cancel_task()
+ await self.ws.close(code=1000)
+
+ async def _handle_disconnect(self, e):
+ self._dispatch('disconnect')
+ if not self._reconnect:
+ self._queue.put_nowait(EventItem(EventType.close, self, e))
+ return
+
+ if self._client.is_closed():
+ return
+
+ if isinstance(e, ConnectionClosed):
+ if e.code != 1000:
+ self._queue.put_nowait(EventItem(EventType.close, self, e))
+ return
+
+ retry = self._backoff.delay()
+ log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
+ await asyncio.sleep(retry)
+ self._queue.put_nowait(EventItem(EventType.reconnect, self, e))
+
async def worker(self):
while not self._client.is_closed():
try:
@@ -87,14 +129,12 @@ class Shard:
etype = EventType.resume if e.resume else EventType.identify
self._queue.put_nowait(EventItem(etype, self, e))
break
- except ConnectionClosed as e:
- self._queue.put_nowait(EventItem(EventType.close, self, e))
+ except self._handled_exceptions as e:
+ await self._handle_disconnect(e)
break
- async def reconnect(self, exc):
- if self._task is not None and not self._task.done():
- self._task.cancel()
-
+ async def reidentify(self, exc):
+ self._cancel_task()
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
session=self.ws.session_id, sequence=self.ws.sequence)
@@ -102,6 +142,16 @@ class Shard:
self.ws = await asyncio.wait_for(coro, timeout=180.0)
self.launch()
+ async def reconnect(self):
+ self._cancel_task()
+ try:
+ coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
+ self.ws = await asyncio.wait_for(coro, timeout=180.0)
+ except self._handled_exceptions as e:
+ await self._handle_disconnect(e)
+ else:
+ self.launch()
+
class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single
@@ -235,15 +285,21 @@ class AutoShardedClient(Client):
self._connection.shards_launched.set()
- async def _connect(self):
+ async def connect(self, *, reconnect=True):
+ self._reconnect = reconnect
await self.launch_shards()
- while True:
+ while not self.is_closed():
item = await self._queue.get()
if item.type == EventType.close:
- raise item.error
+ await self.close()
+ if isinstance(item.error, ConnectionClosed) and item.error.code != 1000:
+ raise item.error
+ return
elif item.type in (EventType.identify, EventType.resume):
- await item.shard.reconnect(item.error)
+ await item.shard.reidentify(item.error)
+ elif item.type == EventType.reconnect:
+ await item.shard.reconnect()
async def close(self):
"""|coro|
@@ -261,7 +317,7 @@ class AutoShardedClient(Client):
except Exception:
pass
- to_close = [asyncio.ensure_future(shard.ws.close(code=1000), loop=self.loop) for shard in self.shards.values()]
+ to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.shards.values()]
if to_close:
await asyncio.wait(to_close)