aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
Diffstat (limited to 'discord')
-rw-r--r--discord/client.py37
-rw-r--r--discord/gateway.py14
-rw-r--r--discord/http.py124
-rw-r--r--discord/shard.py10
4 files changed, 107 insertions, 78 deletions
diff --git a/discord/client.py b/discord/client.py
index 068e3d65..a0e27d8e 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -480,19 +480,6 @@ class Client:
"""
await self.close()
- async def _connect(self):
- coro = DiscordWebSocket.from_client(self, initial=True, shard_id=self.shard_id)
- self.ws = await asyncio.wait_for(coro, timeout=180.0)
- while True:
- try:
- await self.ws.poll_event()
- except ReconnectWebSocket as e:
- log.info('Got a request to %s the websocket.', e.op)
- self.dispatch('disconnect')
- coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id,
- sequence=self.ws.sequence, resume=e.resume)
- self.ws = await asyncio.wait_for(coro, timeout=180.0)
-
async def connect(self, *, reconnect=True):
"""|coro|
@@ -519,9 +506,22 @@ class Client:
"""
backoff = ExponentialBackoff()
+ ws_params = {
+ 'initial': True,
+ 'shard_id': self.shard_id,
+ }
while not self.is_closed():
try:
- await self._connect()
+ coro = DiscordWebSocket.from_client(self, **ws_params)
+ self.ws = await asyncio.wait_for(coro, timeout=60.0)
+ ws_params['initial'] = False
+ while True:
+ await self.ws.poll_event()
+ except ReconnectWebSocket as e:
+ log.info('Got a request to %s the websocket.', e.op)
+ self.dispatch('disconnect')
+ ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
+ continue
except (OSError,
HTTPException,
GatewayNotFound,
@@ -540,6 +540,11 @@ class Client:
if self.is_closed():
return
+ # If we get connection reset by peer then try to RESUME
+ if isinstance(exc, OSError) and exc.errno in (54, 10054):
+ ws_params.update(sequence=self.ws.sequence, initial=False, resume=True, session=self.ws.session_id)
+ continue
+
# We should only get this when an unhandled close code happens,
# such as a clean disconnect (1000) or a bad state (bad token, no sharding, etc)
# sometimes, discord sends us 1000 for unknown reasons so we should reconnect
@@ -552,6 +557,10 @@ class Client:
retry = backoff.delay()
log.exception("Attempting a reconnect in %.2fs", retry)
await asyncio.sleep(retry)
+ # Always try to RESUME the connection
+ # If the connection is not RESUME-able then the gateway will invalidate the session.
+ # This is apparently what the official Discord client does.
+ ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id)
async def close(self):
"""|coro|
diff --git a/discord/gateway.py b/discord/gateway.py
index 3f92ec1f..f262477f 100644
--- a/discord/gateway.py
+++ b/discord/gateway.py
@@ -508,16 +508,21 @@ class DiscordWebSocket:
elif msg.type is aiohttp.WSMsgType.ERROR:
log.debug('Received %s', msg)
raise msg.data
- elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
+ elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
log.debug('Received %s', msg)
raise WebSocketClosure
- except WebSocketClosure as e:
+ except WebSocketClosure:
+ # Ensure the keep alive handler is closed
+ if self._keep_alive:
+ self._keep_alive.stop()
+ self._keep_alive = None
+
if self._can_handle_close():
log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code)
- raise ReconnectWebSocket(self.shard_id) from e
+ raise ReconnectWebSocket(self.shard_id) from None
elif self.socket.close_code is not None:
log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code)
- raise ConnectionClosed(self.socket, shard_id=self.shard_id) from e
+ raise ConnectionClosed(self.socket, shard_id=self.shard_id) from None
async def send(self, data):
self._dispatch('socket_raw_send', data)
@@ -598,6 +603,7 @@ class DiscordWebSocket:
async def close(self, code=4000):
if self._keep_alive:
self._keep_alive.stop()
+ self._keep_alive = None
await self.socket.close(code=code)
diff --git a/discord/http.py b/discord/http.py
index 00e66ac2..ceb6137a 100644
--- a/discord/http.py
+++ b/discord/http.py
@@ -180,68 +180,76 @@ class HTTPClient:
if files:
for f in files:
f.reset(seek=tries)
-
- async with self.__session.request(method, url, **kwargs) as r:
- log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status)
-
- # even errors have text involved in them so this is safe to call
- data = await json_or_text(r)
-
- # check if we have rate limit header information
- remaining = r.headers.get('X-Ratelimit-Remaining')
- if remaining == '0' and r.status != 429:
- # we've depleted our current bucket
- delta = utils._parse_ratelimit_header(r, use_clock=self.use_clock)
- log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
- maybe_lock.defer()
- self.loop.call_later(delta, lock.release)
-
- # the request was successful so just return the text/json
- if 300 > r.status >= 200:
- log.debug('%s %s has received %s', method, url, data)
- return data
-
- # we are being rate limited
- if r.status == 429:
- if not r.headers.get('Via'):
- # Banned by Cloudflare more than likely.
+ try:
+ async with self.__session.request(method, url, **kwargs) as r:
+ log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status)
+
+ # even errors have text involved in them so this is safe to call
+ data = await json_or_text(r)
+
+ # check if we have rate limit header information
+ remaining = r.headers.get('X-Ratelimit-Remaining')
+ if remaining == '0' and r.status != 429:
+ # we've depleted our current bucket
+ delta = utils._parse_ratelimit_header(r, use_clock=self.use_clock)
+ log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
+ maybe_lock.defer()
+ self.loop.call_later(delta, lock.release)
+
+ # the request was successful so just return the text/json
+ if 300 > r.status >= 200:
+ log.debug('%s %s has received %s', method, url, data)
+ return data
+
+ # we are being rate limited
+ if r.status == 429:
+ if not r.headers.get('Via'):
+ # Banned by Cloudflare more than likely.
+ raise HTTPException(r, data)
+
+ fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"'
+
+ # sleep a bit
+ retry_after = data['retry_after'] / 1000.0
+ log.warning(fmt, retry_after, bucket)
+
+ # check if it's a global rate limit
+ is_global = data.get('global', False)
+ if is_global:
+ log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after)
+ self._global_over.clear()
+
+ await asyncio.sleep(retry_after)
+ log.debug('Done sleeping for the rate limit. Retrying...')
+
+ # release the global lock now that the
+ # global rate limit has passed
+ if is_global:
+ self._global_over.set()
+ log.debug('Global rate limit is now over.')
+
+ continue
+
+ # we've received a 500 or 502, unconditional retry
+ if r.status in {500, 502}:
+ await asyncio.sleep(1 + tries * 2)
+ continue
+
+ # the usual error cases
+ if r.status == 403:
+ raise Forbidden(r, data)
+ elif r.status == 404:
+ raise NotFound(r, data)
+ else:
raise HTTPException(r, data)
- fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"'
-
- # sleep a bit
- retry_after = data['retry_after'] / 1000.0
- log.warning(fmt, retry_after, bucket)
-
- # check if it's a global rate limit
- is_global = data.get('global', False)
- if is_global:
- log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after)
- self._global_over.clear()
-
- await asyncio.sleep(retry_after)
- log.debug('Done sleeping for the rate limit. Retrying...')
-
- # release the global lock now that the
- # global rate limit has passed
- if is_global:
- self._global_over.set()
- log.debug('Global rate limit is now over.')
-
- continue
-
- # we've received a 500 or 502, unconditional retry
- if r.status in {500, 502}:
- await asyncio.sleep(1 + tries * 2)
+ # This is handling exceptions from the request
+ except OSError as e:
+ # Connection reset by peer
+ if e.errno in (54, 10054):
+ # Just re-do the request
continue
- # the usual error cases
- if r.status == 403:
- raise Forbidden(r, data)
- elif r.status == 404:
- raise NotFound(r, data)
- else:
- raise HTTPException(r, data)
# We've run out of retries, raise.
raise HTTPException(r, data)
diff --git a/discord/shard.py b/discord/shard.py
index dfa3849c..7659e5ec 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -112,6 +112,12 @@ class Shard:
if self._client.is_closed():
return
+ if isinstance(e, OSError) and e.errno in (54, 10054):
+ # If we get Connection reset by peer then always try to RESUME the connection.
+ exc = ReconnectWebSocket(self.id, resume=True)
+ self._queue.put_nowait(EventItem(EventType.resume, self, exc))
+ return
+
if isinstance(e, ConnectionClosed):
if e.code != 1000:
self._queue.put_nowait(EventItem(EventType.close, self, e))
@@ -142,7 +148,7 @@ class Shard:
try:
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
session=self.ws.session_id, sequence=self.ws.sequence)
- self.ws = await asyncio.wait_for(coro, timeout=180.0)
+ self.ws = await asyncio.wait_for(coro, timeout=60.0)
except self._handled_exceptions as e:
await self._handle_disconnect(e)
else:
@@ -152,7 +158,7 @@ class Shard:
self._cancel_task()
try:
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
- self.ws = await asyncio.wait_for(coro, timeout=180.0)
+ self.ws = await asyncio.wait_for(coro, timeout=60.0)
except self._handled_exceptions as e:
await self._handle_disconnect(e)
else: