diff options
Diffstat (limited to 'discord/http.py')
| -rw-r--r-- | discord/http.py | 63 |
1 files changed, 25 insertions, 38 deletions
diff --git a/discord/http.py b/discord/http.py index 39c24fdb..de8b6b0d 100644 --- a/discord/http.py +++ b/discord/http.py @@ -38,9 +38,8 @@ log = logging.getLogger(__name__) from .errors import HTTPException, Forbidden, NotFound, LoginFailure, GatewayNotFound from . import __version__, utils -def json_or_text(response): - text = yield from response.text(encoding='utf-8') +async def json_or_text(response): + text = await response.text(encoding='utf-8') if response.headers['content-type'] == 'application/json': return json.loads(text) return text @@ -106,8 +105,7 @@ class HTTPClient: if self._session.closed: self._session = aiohttp.ClientSession(connector=self.connector, loop=self.loop) - @asyncio.coroutine - def request(self, route, *, header_bypass_delay=None, **kwargs): + async def request(self, route, *, header_bypass_delay=None, **kwargs): bucket = route.bucket method = route.method url = route.url @@ -148,16 +146,16 @@ class HTTPClient: if not self._global_over.is_set(): # wait until the global lock is complete - yield from self._global_over.wait() + await self._global_over.wait() - yield from lock + await lock with MaybeUnlock(lock) as maybe_lock: for tries in range(5): - r = yield from self._session.request(method, url, **kwargs) - log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status) - try: + async with self._session.request(method, url, **kwargs) as r: + log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status) + # even errors have text involved in them so this is safe to call - data = yield from json_or_text(r) + data = await json_or_text(r) # check if we have rate limit header information remaining = r.headers.get('X-Ratelimit-Remaining') @@ -191,7 +189,7 @@ class HTTPClient: log.info('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after) self._global_over.clear() - yield from asyncio.sleep(retry_after, loop=self.loop) + await asyncio.sleep(retry_after, loop=self.loop) log.debug('Done sleeping for the rate limit. Retrying...') # release the global lock now that the @@ -204,7 +202,7 @@ class HTTPClient: # we've received a 500 or 502, unconditional retry if r.status in {500, 502}: - yield from asyncio.sleep(1 + tries * 2, loop=self.loop) + await asyncio.sleep(1 + tries * 2, loop=self.loop) continue # the usual error cases @@ -214,32 +212,25 @@ class HTTPClient: raise NotFound(r, data) else: raise HTTPException(r, data) - finally: - # clean-up just in case - yield from r.release() + # We've run out of retries, raise. raise HTTPException(r, data) - @asyncio.coroutine - def get_attachment(self, url): - resp = yield from self._session.get(url) - try: + async def get_attachment(self, url): + async with self._session.get(url) as resp: if resp.status == 200: - return (yield from resp.read()) + return (await resp.read()) elif resp.status == 404: raise NotFound(resp, 'attachment not found') elif resp.status == 403: raise Forbidden(resp, 'cannot retrieve attachment') else: raise HTTPException(resp, 'failed to get attachment') - finally: - yield from resp.release() # state management - @asyncio.coroutine - def close(self): - yield from self._session.close() + async def close(self): + await self._session.close() def _token(self, token, *, bot=True): self.token = token @@ -248,13 +239,12 @@ class HTTPClient: # login management - @asyncio.coroutine - def static_login(self, token, *, bot): + async def static_login(self, token, *, bot): old_token, old_bot = self.token, self.bot_token self._token(token, bot=bot) try: - data = yield from self.request(Route('GET', '/users/@me')) + data = await self.request(Route('GET', '/users/@me')) except HTTPException as e: self._token(old_token, bot=old_bot) if e.response.status == 401: @@ -349,11 +339,10 @@ class HTTPClient: return self.request(r, data=form) - @asyncio.coroutine - def ack_message(self, channel_id, message_id): + async def ack_message(self, channel_id, message_id): r = Route('POST', '/channels/{channel_id}/messages/{message_id}/ack', channel_id=channel_id, message_id=message_id) - data = yield from self.request(r, json={'token': self._ack_token}) + data = await self.request(r, json={'token': self._ack_token}) self._ack_token = data['token'] def ack_guild(self, guild_id): @@ -751,10 +740,9 @@ class HTTPClient: def application_info(self): return self.request(Route('GET', '/oauth2/applications/@me')) - @asyncio.coroutine - def get_gateway(self, *, encoding='json', v=6, zlib=True): + async def get_gateway(self, *, encoding='json', v=6, zlib=True): try: - data = yield from self.request(Route('GET', '/gateway')) + data = await self.request(Route('GET', '/gateway')) except HTTPException as e: raise GatewayNotFound() from e if zlib: @@ -763,10 +751,9 @@ class HTTPClient: value = '{0}?encoding={1}&v={2}' return value.format(data['url'], encoding, v) - @asyncio.coroutine - def get_bot_gateway(self, *, encoding='json', v=6, zlib=True): + async def get_bot_gateway(self, *, encoding='json', v=6, zlib=True): try: - data = yield from self.request(Route('GET', '/gateway/bot')) + data = await self.request(Route('GET', '/gateway/bot')) except HTTPException as e: raise GatewayNotFound() from e |