diff options
| author | Rapptz <[email protected]> | 2016-06-12 20:32:59 -0400 |
|---|---|---|
| committer | Rapptz <[email protected]> | 2016-06-12 20:33:09 -0400 |
| commit | 1fba1b06faca31d07c9296b2badabfe22f173001 (patch) | |
| tree | 858fc299aa7299f1013da117599992f6049fc42d /discord/http.py | |
| parent | Change HTTPException to only take a single parameter. (diff) | |
| download | discord.py-1fba1b06faca31d07c9296b2badabfe22f173001.tar.xz discord.py-1fba1b06faca31d07c9296b2badabfe22f173001.zip | |
Rewrite HTTP handling significantly.
This should have a more uniform approach to rate limit handling. Instead
of queueing every request, wait until we receive a 429 and then block
the requesting bucket until we're done being rate limited. This should
reduce the number of 429s done by the API significantly (about 66% avg).
This also consistently checks for 502 retries across all requests.
Diffstat (limited to 'discord/http.py')
| -rw-r--r-- | discord/http.py | 484 |
1 files changed, 484 insertions, 0 deletions
diff --git a/discord/http.py b/discord/http.py new file mode 100644 index 00000000..15cd08ef --- /dev/null +++ b/discord/http.py @@ -0,0 +1,484 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2016 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import aiohttp +import asyncio +import json +import sys +import logging +import io +import inspect +import weakref +from random import randint as random_integer + +log = logging.getLogger(__name__) + +from .errors import HTTPException, Forbidden, NotFound, LoginFailure, GatewayNotFound +from . import utils, __version__ + +def json_or_text(response): + text = yield from response.text(encoding='utf-8') + if response.headers['content-type'] == 'application/json': + return json.loads(text) + return text + +def _func_(): + # emulate __func__ from C++ + return inspect.currentframe().f_back.f_code.co_name + +class HTTPClient: + """Represents an HTTP client sending HTTP requests to the Discord API.""" + + BASE = 'https://discordapp.com' + API_BASE = BASE + '/api' + GATEWAY = API_BASE + '/gateway' + USERS = API_BASE + '/users' + ME = USERS + '/@me' + REGISTER = API_BASE + '/auth/register' + LOGIN = API_BASE + '/auth/login' + LOGOUT = API_BASE + '/auth/logout' + GUILDS = API_BASE + '/guilds' + CHANNELS = API_BASE + '/channels' + APPLICATIONS = API_BASE + '/oauth2/applications' + + SUCCESS_LOG = '{method} {url} with {json} has received {text}' + REQUEST_LOG = '{method} {url} has returned {status}' + + def __init__(self, connector=None, *, loop=None): + self.loop = asyncio.get_event_loop() if loop is None else loop + self.connector = connector + self.session = aiohttp.ClientSession(connector=connector, loop=self.loop) + self._locks = weakref.WeakValueDictionary() + self.token = None + self.bot_token = False + + user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' + self.user_agent = user_agent.format(__version__, sys.version_info, aiohttp.__version__) + + @asyncio.coroutine + def request(self, method, url, *, bucket=None, **kwargs): + lock = self._locks.get(bucket) + if lock is None: + lock = asyncio.Lock(loop=self.loop) + if bucket is not None: + self._locks[bucket] = lock + + # header creation + headers = { + 'User-Agent': self.user_agent, + } + + if self.token is not None: + headers['Authorization'] = 'Bot ' + self.token if self.bot_token else self.token + + # some checking if it's a JSON request + if 'json' in kwargs: + headers['Content-Type'] = 'application/json' + kwargs['data'] = utils.to_json(kwargs.pop('json')) + + kwargs['headers'] = headers + with (yield from lock): + for tries in range(5): + r = yield from self.session.request(method, url, **kwargs) + log.debug(self.REQUEST_LOG.format(method=method, url=url, status=r.status)) + try: + # even errors have text involved in them so this is safe to call + data = yield from json_or_text(r) + + # the request was successful so just return the text/json + if 300 > r.status >= 200: + log.debug(self.SUCCESS_LOG.format(method=method, url=url, + json=kwargs.get('data'), text=data)) + return data + + # we are being rate limited + if r.status == 429: + fmt = 'We are being rate limited. Retrying in {:.2} seconds. Handled under the bucket "{}"' + + # sleep a bit + retry_after = data['retry_after'] / 1000.0 + log.info(fmt.format(retry_after, bucket)) + yield from asyncio.sleep(retry_after) + continue + + # we've received a 502, unconditional retry + if r.status == 502 and tries <= 5: + yield from asyncio.sleep(1 + tries * 2) + continue + + # the usual error cases + if r.status == 403: + raise Forbidden(r, data) + elif r.status == 404: + raise NotFound(r, data) + else: + raise HTTPException(r, data) + finally: + # clean-up just in case + yield from r.release() + + def get(self, *args, **kwargs): + return self.request('GET', *args, **kwargs) + + def put(self, *args, **kwargs): + return self.request('PUT', *args, **kwargs) + + def patch(self, *args, **kwargs): + return self.request('PATCH', *args, **kwargs) + + def delete(self, *args, **kwargs): + return self.request('DELETE', *args, **kwargs) + + def post(self, *args, **kwargs): + return self.request('POST', *args, **kwargs) + + # state management + + @asyncio.coroutine + def close(self): + yield from self.session.close() + + def recreate(self): + self.session = aiohttp.ClientSession(self.connector, loop=self.loop) + + def _token(self, token, *, bot=True): + self.token = token + self.bot_token = bot + + # login management + + @asyncio.coroutine + def email_login(self, email, password): + payload = { + 'email': email, + 'password': password + } + + try: + data = yield from self.post(self.LOGIN, json=payload, bucket=_func_()) + except HTTPException as e: + if e.response.status == 400: + raise LoginFailure('Improper credentials have been passed.') from e + raise + + self._token(data['token'], bot=False) + return data + + @asyncio.coroutine + def static_login(self, token, *, bot): + old_state = (self.token, self.bot_token) + self._token(token, bot=bot) + + try: + data = yield from self.get(self.ME) + except HTTPException as e: + self._token(*old_state) + if e.response.status == 401: + raise LoginFailure('Improper token has been passed.') from e + raise e + + return data + + def logout(self): + return self.post(self.LOGOUT, bucket=_func_()) + + # Message management + + def start_private_message(self, user_id): + payload = { + 'recipient_id': user_id + } + + return self.post(self.ME + '/channels', json=payload, bucket=_func_()) + + def send_message(self, channel_id, content, *, guild_id=None, tts=False): + url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) + payload = { + 'content': str(content), + 'nonce': random_integer(-2**63, 2**63 - 1) + } + + if tts: + payload['tts'] = True + + return self.post(url, json=payload, bucket='messages:' + str(guild_id)) + + def send_typing(self, channel_id): + url = '{0.CHANNELS}/{1}/typing'.format(self, channel_id) + return self.post(url, bucket=_func_()) + + def send_file(self, channel_id, buffer, *, guild_id=None, filename=None, content=None, tts=False): + url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) + form = aiohttp.FormData() + + if content is not None: + form.add_field('content', str(content)) + + form.add_field('tts', 'true' if tts else 'false') + form.add_field('file', io.BytesIO(buffer), filename=filename, content_type='application/octet-stream') + + return self.post(url, data=form, bucket='messages:' + str(guild_id)) + + def delete_message(self, channel_id, message_id, guild_id=None): + url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id) + bucket = '{}:{}'.format(_func_(), guild_id) + return self.delete(url, bucket=bucket) + + def delete_messages(self, channel_id, message_ids, guild_id=None): + url = '{0.CHANNELS}/{1}/messages/bulk_delete'.format(self, channel_id) + payload = { + 'messages': message_ids + } + bucket = '{}:{}'.format(_func_(), guild_id) + return self.post(url, json=payload, bucket=bucket) + + def edit_message(self, message_id, channel_id, content, *, guild_id=None): + url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id) + payload = { + 'content': str(content) + } + return self.patch(url, json=payload, bucket='messages:' + str(guild_id)) + + + def logs_from(self, channel_id, limit, before=None, after=None): + url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) + params = { + 'limit': limit + } + + if before: + params['before'] = before + if after: + params['after'] = after + + return self.get(url, params=params, bucket=_func_()) + + # Member management + + def kick(self, user_id, guild_id): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + return self.delete(url, bucket=_func_()) + + def ban(self, user_id, guild_id, delete_message_days=1): + url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id) + params = { + 'delete-message-days': delete_message_days + } + return self.put(url, params=params, bucket=_func_()) + + def unban(self, user_id, guild_id): + url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id) + return self.delete(url, bucket=_func_()) + + def server_voice_state(self, user_id, guild_id, *, mute=False, deafen=False): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + payload = { + 'mute': mute, + 'deafen': deafen + } + return self.patch(url, json=payload, bucket='members:' + str(guild_id)) + + def edit_profile(self, password, username, avatar, **fields): + payload = { + 'password': password, + 'username': username, + 'avatar': avatar + } + + if 'email' in fields: + payload['email'] = fields['email'] + + if 'new_password' in fields: + payload['new_password'] = fields['new_password'] + + return self.patch(self.ME, json=payload, bucket=_func_()) + + def change_my_nickname(self, guild_id, nickname): + url = '{0.GUILDS}/{1}/members/@me/nick'.format(self, guild_id) + payload = { + 'nick': nickname + } + bucket = '{}:{}'.format(_func_(), guild_id) + return self.patch(url, json=payload, bucket=bucket) + + def change_nickname(self, guild_id, user_id, nickname): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + payload = { + 'nick': nickname + } + bucket = '{}:{}'.format(_func_(), guild_id) + return self.patch(url, json=payload, bucket=bucket) + + # Channel management + + def edit_channel(self, channel_id, **options): + url = '{0.CHANNELS}/{1}'.format(self, channel_id) + + valid_keys = ('name', 'topic', 'bitrate', 'user_limit') + payload = { + k: v for k, v in options.items() if k in valid_keys + } + + return self.patch(url, json=payload, bucket=_func_()) + + def create_channel(self, guild_id, name, channe_type): + url = '{0.GUILDS}/{1}/channels'.format(self, guild_id) + payload = { + 'name': name, + 'type': channe_type + } + + return self.post(url, json=payload, bucket=_func_()) + + def delete_channel(self, channel_id): + url = '{0.CHANNELS}/{1}'.format(self, channel_id) + return self.delete(url, bucket=_func_()) + + # Server management + + def leave_server(self, guild_id): + url = '{0.USERS}/@me/guilds/{1}'.format(self, guild_id) + return self.delete(url, bucket=_func_()) + + def delete_server(self, guild_id): + url = '{0.GUILDS}/{1}'.format(self, guild_id) + return self.delete(url, bucket=_func_()) + + def create_server(self, name, region, icon): + payload = { + 'name': name, + 'icon': icon, + 'region': region + } + + return self.post(self.GUILDS, json=payload, bucket=_func_()) + + def edit_server(self, guild_id, **fields): + valid_keys = ('name', 'region', 'icon', 'afk_timeout', 'owner_id', + 'afk_channel_id', 'splash', 'verification_level') + + payload = { + k: v for k, v in fields.items() if k in valid_keys + } + + url = '{0.GUILDS}/{1}'.format(self, guild_id) + return self.patch(url, json=payload, bucket=_func_()) + + def get_bans(self, guild_id): + url = '{0.GUILDS}/{1}/bans'.format(self, guild_id) + return self.get(url, bucket=_func_()) + + # Invite management + + def create_invite(self, channel_id, **options): + url = '{0.CHANNELS}/{1}/invites'.format(self, channel_id) + payload = { + 'max_age': options.get('max_age', 0), + 'max_uses': options.get('max_uses', 0), + 'temporary': options.get('temporary', False), + 'xkcdpass': options.get('xkcd', False) + } + + return self.post(url, json=payload, bucket=_func_()) + + def get_invite(self, invite_id): + url = '{0.API_BASE}/invite/{1}'.format(self, invite_id) + return self.get(url, bucket=_func_()) + + def invites_from(self, guild_id): + url = '{0.GUILDS}/{1}/invites'.format(self, guild_id) + return self.get(url, bucket=_func_()) + + def accept_invite(self, invite_id): + url = '{0.API_BASE}/invite/{1}'.format(self, invite_id) + return self.post(url, bucket=_func_()) + + def delete_invite(self, invite_id): + url = '{0.API_BASE}/invite/{1}'.format(self, invite_id) + return self.delete(url, bucket=_func_()) + + # Role management + + def edit_role(self, guild_id, role_id, **fields): + url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id) + valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable') + payload = { + k: v for k, v in fields.items() if k in valid_keys + } + return self.patch(url, json=payload, bucket='roles:' + str(guild_id)) + + def delete_role(self, guild_id, role_id): + url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id) + return self.delete(url, bucket=_func_()) + + def replace_roles(self, user_id, guild_id, role_ids): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + payload = { + 'roles': role_ids + } + return self.patch(url, json=payload, bucket='members:' + str(guild_id)) + + def create_role(self, guild_id): + url = '{0.GUILDS}/{1}/roles'.format(self, guild_id) + return self.post(url, bucket=_func_()) + + def edit_channel_permissions(self, channel_id, target, allow, deny, type): + url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target) + payload = { + 'id': target, + 'allow': allow, + 'deny': deny, + 'type': type + } + return self.put(url, json=payload, bucket=_func_()) + + def delete_channel_permissions(self, channel_id, target): + url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target) + return self.delete(url, bucket=_func_()) + + # Voice management + + def move_member(self, user_id, guild_id, channel_id): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + payload = { + 'channel_id': channel_id + } + return self.patch(url, json=payload, bucket='members:' + str(guild_id)) + + # Misc + + def application_info(self): + url = '{0.APPLICATIONS}/@me'.format(self) + return self.get(url, bucket=_func_()) + + @asyncio.coroutine + def get_gateway(self): + try: + data = yield from self.get(self.GATEWAY, bucket=_func_()) + except HTTPException as e: + raise GatewayNotFound() from e + return data.get('url') + '?encoding=json&v=4' |