aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
Diffstat (limited to 'discord')
-rw-r--r--discord/gateway.py5
-rw-r--r--discord/http.py2
-rw-r--r--discord/utils.py23
3 files changed, 24 insertions, 6 deletions
diff --git a/discord/gateway.py b/discord/gateway.py
index 5fec8748..0a6b05d0 100644
--- a/discord/gateway.py
+++ b/discord/gateway.py
@@ -25,7 +25,6 @@ DEALINGS IN THE SOFTWARE.
import asyncio
from collections import namedtuple, deque
import concurrent.futures
-import json
import logging
import struct
import sys
@@ -421,7 +420,7 @@ class DiscordWebSocket:
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
- msg = json.loads(msg)
+ msg = utils.from_json(msg)
log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
self._dispatch('socket_response', msg)
@@ -882,7 +881,7 @@ class DiscordVoiceWebSocket:
# This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT:
- await self.received_message(json.loads(msg.data))
+ await self.received_message(utils.from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
diff --git a/discord/http.py b/discord/http.py
index cfc71e22..c9c340ae 100644
--- a/discord/http.py
+++ b/discord/http.py
@@ -99,7 +99,7 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any]
text = await response.text(encoding='utf-8')
try:
if response.headers['content-type'] == 'application/json':
- return json.loads(text)
+ return utils.from_json(text)
except KeyError:
# Thanks Cloudflare
pass
diff --git a/discord/utils.py b/discord/utils.py
index 6070882f..b5563a67 100644
--- a/discord/utils.py
+++ b/discord/utils.py
@@ -63,6 +63,14 @@ import warnings
from .errors import InvalidArgument
+try:
+ import orjson
+except ModuleNotFoundError:
+ HAS_ORJSON = False
+else:
+ HAS_ORJSON = True
+
+
__all__ = (
'oauth_url',
'snowflake_time',
@@ -468,8 +476,19 @@ def _bytes_to_base64_data(data: bytes) -> str:
return fmt.format(mime=mime, data=b64)
-def to_json(obj: Any) -> str:
- return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
+if HAS_ORJSON:
+
+ def to_json(obj: Any) -> str: # type: ignore
+ return orjson.dumps(obj).decode('utf-8')
+
+ from_json = orjson.loads # type: ignore
+
+else:
+
+ def to_json(obj: Any) -> str:
+ return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
+
+ from_json = json.loads
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: