aboutsummaryrefslogtreecommitdiff
path: root/discord/http.py
diff options
context:
space:
mode:
authorRapptz <[email protected]>2016-06-12 20:32:59 -0400
committerRapptz <[email protected]>2016-06-12 20:33:09 -0400
commit1fba1b06faca31d07c9296b2badabfe22f173001 (patch)
tree858fc299aa7299f1013da117599992f6049fc42d /discord/http.py
parentChange HTTPException to only take a single parameter. (diff)
downloaddiscord.py-1fba1b06faca31d07c9296b2badabfe22f173001.tar.xz
discord.py-1fba1b06faca31d07c9296b2badabfe22f173001.zip
Rewrite HTTP handling significantly.
This should have a more uniform approach to rate limit handling. Instead of queueing every request, wait until we receive a 429 and then block the requesting bucket until we're done being rate limited. This should reduce the number of 429s done by the API significantly (about 66% avg). This also consistently checks for 502 retries across all requests.
Diffstat (limited to 'discord/http.py')
-rw-r--r--discord/http.py484
1 files changed, 484 insertions, 0 deletions
diff --git a/discord/http.py b/discord/http.py
new file mode 100644
index 00000000..15cd08ef
--- /dev/null
+++ b/discord/http.py
@@ -0,0 +1,484 @@
+# -*- coding: utf-8 -*-
+
+"""
+The MIT License (MIT)
+
+Copyright (c) 2015-2016 Rapptz
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+
+import aiohttp
+import asyncio
+import json
+import sys
+import logging
+import io
+import inspect
+import weakref
+from random import randint as random_integer
+
+log = logging.getLogger(__name__)
+
+from .errors import HTTPException, Forbidden, NotFound, LoginFailure, GatewayNotFound
+from . import utils, __version__
+
+def json_or_text(response):
+ text = yield from response.text(encoding='utf-8')
+ if response.headers['content-type'] == 'application/json':
+ return json.loads(text)
+ return text
+
+def _func_():
+ # emulate __func__ from C++
+ return inspect.currentframe().f_back.f_code.co_name
+
+class HTTPClient:
+ """Represents an HTTP client sending HTTP requests to the Discord API."""
+
+ BASE = 'https://discordapp.com'
+ API_BASE = BASE + '/api'
+ GATEWAY = API_BASE + '/gateway'
+ USERS = API_BASE + '/users'
+ ME = USERS + '/@me'
+ REGISTER = API_BASE + '/auth/register'
+ LOGIN = API_BASE + '/auth/login'
+ LOGOUT = API_BASE + '/auth/logout'
+ GUILDS = API_BASE + '/guilds'
+ CHANNELS = API_BASE + '/channels'
+ APPLICATIONS = API_BASE + '/oauth2/applications'
+
+ SUCCESS_LOG = '{method} {url} with {json} has received {text}'
+ REQUEST_LOG = '{method} {url} has returned {status}'
+
+ def __init__(self, connector=None, *, loop=None):
+ self.loop = asyncio.get_event_loop() if loop is None else loop
+ self.connector = connector
+ self.session = aiohttp.ClientSession(connector=connector, loop=self.loop)
+ self._locks = weakref.WeakValueDictionary()
+ self.token = None
+ self.bot_token = False
+
+ user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
+ self.user_agent = user_agent.format(__version__, sys.version_info, aiohttp.__version__)
+
+ @asyncio.coroutine
+ def request(self, method, url, *, bucket=None, **kwargs):
+ lock = self._locks.get(bucket)
+ if lock is None:
+ lock = asyncio.Lock(loop=self.loop)
+ if bucket is not None:
+ self._locks[bucket] = lock
+
+ # header creation
+ headers = {
+ 'User-Agent': self.user_agent,
+ }
+
+ if self.token is not None:
+ headers['Authorization'] = 'Bot ' + self.token if self.bot_token else self.token
+
+ # some checking if it's a JSON request
+ if 'json' in kwargs:
+ headers['Content-Type'] = 'application/json'
+ kwargs['data'] = utils.to_json(kwargs.pop('json'))
+
+ kwargs['headers'] = headers
+ with (yield from lock):
+ for tries in range(5):
+ r = yield from self.session.request(method, url, **kwargs)
+ log.debug(self.REQUEST_LOG.format(method=method, url=url, status=r.status))
+ try:
+ # even errors have text involved in them so this is safe to call
+ data = yield from json_or_text(r)
+
+ # the request was successful so just return the text/json
+ if 300 > r.status >= 200:
+ log.debug(self.SUCCESS_LOG.format(method=method, url=url,
+ json=kwargs.get('data'), text=data))
+ return data
+
+ # we are being rate limited
+ if r.status == 429:
+ fmt = 'We are being rate limited. Retrying in {:.2} seconds. Handled under the bucket "{}"'
+
+ # sleep a bit
+ retry_after = data['retry_after'] / 1000.0
+ log.info(fmt.format(retry_after, bucket))
+ yield from asyncio.sleep(retry_after)
+ continue
+
+ # we've received a 502, unconditional retry
+ if r.status == 502 and tries <= 5:
+ yield from asyncio.sleep(1 + tries * 2)
+ continue
+
+ # the usual error cases
+ if r.status == 403:
+ raise Forbidden(r, data)
+ elif r.status == 404:
+ raise NotFound(r, data)
+ else:
+ raise HTTPException(r, data)
+ finally:
+ # clean-up just in case
+ yield from r.release()
+
+ def get(self, *args, **kwargs):
+ return self.request('GET', *args, **kwargs)
+
+ def put(self, *args, **kwargs):
+ return self.request('PUT', *args, **kwargs)
+
+ def patch(self, *args, **kwargs):
+ return self.request('PATCH', *args, **kwargs)
+
+ def delete(self, *args, **kwargs):
+ return self.request('DELETE', *args, **kwargs)
+
+ def post(self, *args, **kwargs):
+ return self.request('POST', *args, **kwargs)
+
+ # state management
+
+ @asyncio.coroutine
+ def close(self):
+ yield from self.session.close()
+
+ def recreate(self):
+ self.session = aiohttp.ClientSession(self.connector, loop=self.loop)
+
+ def _token(self, token, *, bot=True):
+ self.token = token
+ self.bot_token = bot
+
+ # login management
+
+ @asyncio.coroutine
+ def email_login(self, email, password):
+ payload = {
+ 'email': email,
+ 'password': password
+ }
+
+ try:
+ data = yield from self.post(self.LOGIN, json=payload, bucket=_func_())
+ except HTTPException as e:
+ if e.response.status == 400:
+ raise LoginFailure('Improper credentials have been passed.') from e
+ raise
+
+ self._token(data['token'], bot=False)
+ return data
+
+ @asyncio.coroutine
+ def static_login(self, token, *, bot):
+ old_state = (self.token, self.bot_token)
+ self._token(token, bot=bot)
+
+ try:
+ data = yield from self.get(self.ME)
+ except HTTPException as e:
+ self._token(*old_state)
+ if e.response.status == 401:
+ raise LoginFailure('Improper token has been passed.') from e
+ raise e
+
+ return data
+
+ def logout(self):
+ return self.post(self.LOGOUT, bucket=_func_())
+
+ # Message management
+
+ def start_private_message(self, user_id):
+ payload = {
+ 'recipient_id': user_id
+ }
+
+ return self.post(self.ME + '/channels', json=payload, bucket=_func_())
+
+ def send_message(self, channel_id, content, *, guild_id=None, tts=False):
+ url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
+ payload = {
+ 'content': str(content),
+ 'nonce': random_integer(-2**63, 2**63 - 1)
+ }
+
+ if tts:
+ payload['tts'] = True
+
+ return self.post(url, json=payload, bucket='messages:' + str(guild_id))
+
+ def send_typing(self, channel_id):
+ url = '{0.CHANNELS}/{1}/typing'.format(self, channel_id)
+ return self.post(url, bucket=_func_())
+
+ def send_file(self, channel_id, buffer, *, guild_id=None, filename=None, content=None, tts=False):
+ url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
+ form = aiohttp.FormData()
+
+ if content is not None:
+ form.add_field('content', str(content))
+
+ form.add_field('tts', 'true' if tts else 'false')
+ form.add_field('file', io.BytesIO(buffer), filename=filename, content_type='application/octet-stream')
+
+ return self.post(url, data=form, bucket='messages:' + str(guild_id))
+
+ def delete_message(self, channel_id, message_id, guild_id=None):
+ url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id)
+ bucket = '{}:{}'.format(_func_(), guild_id)
+ return self.delete(url, bucket=bucket)
+
+ def delete_messages(self, channel_id, message_ids, guild_id=None):
+ url = '{0.CHANNELS}/{1}/messages/bulk_delete'.format(self, channel_id)
+ payload = {
+ 'messages': message_ids
+ }
+ bucket = '{}:{}'.format(_func_(), guild_id)
+ return self.post(url, json=payload, bucket=bucket)
+
+ def edit_message(self, message_id, channel_id, content, *, guild_id=None):
+ url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id)
+ payload = {
+ 'content': str(content)
+ }
+ return self.patch(url, json=payload, bucket='messages:' + str(guild_id))
+
+
+ def logs_from(self, channel_id, limit, before=None, after=None):
+ url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
+ params = {
+ 'limit': limit
+ }
+
+ if before:
+ params['before'] = before
+ if after:
+ params['after'] = after
+
+ return self.get(url, params=params, bucket=_func_())
+
+ # Member management
+
+ def kick(self, user_id, guild_id):
+ url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
+ return self.delete(url, bucket=_func_())
+
+ def ban(self, user_id, guild_id, delete_message_days=1):
+ url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id)
+ params = {
+ 'delete-message-days': delete_message_days
+ }
+ return self.put(url, params=params, bucket=_func_())
+
+ def unban(self, user_id, guild_id):
+ url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id)
+ return self.delete(url, bucket=_func_())
+
+ def server_voice_state(self, user_id, guild_id, *, mute=False, deafen=False):
+ url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
+ payload = {
+ 'mute': mute,
+ 'deafen': deafen
+ }
+ return self.patch(url, json=payload, bucket='members:' + str(guild_id))
+
+ def edit_profile(self, password, username, avatar, **fields):
+ payload = {
+ 'password': password,
+ 'username': username,
+ 'avatar': avatar
+ }
+
+ if 'email' in fields:
+ payload['email'] = fields['email']
+
+ if 'new_password' in fields:
+ payload['new_password'] = fields['new_password']
+
+ return self.patch(self.ME, json=payload, bucket=_func_())
+
+ def change_my_nickname(self, guild_id, nickname):
+ url = '{0.GUILDS}/{1}/members/@me/nick'.format(self, guild_id)
+ payload = {
+ 'nick': nickname
+ }
+ bucket = '{}:{}'.format(_func_(), guild_id)
+ return self.patch(url, json=payload, bucket=bucket)
+
+ def change_nickname(self, guild_id, user_id, nickname):
+ url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
+ payload = {
+ 'nick': nickname
+ }
+ bucket = '{}:{}'.format(_func_(), guild_id)
+ return self.patch(url, json=payload, bucket=bucket)
+
+ # Channel management
+
+ def edit_channel(self, channel_id, **options):
+ url = '{0.CHANNELS}/{1}'.format(self, channel_id)
+
+ valid_keys = ('name', 'topic', 'bitrate', 'user_limit')
+ payload = {
+ k: v for k, v in options.items() if k in valid_keys
+ }
+
+ return self.patch(url, json=payload, bucket=_func_())
+
+ def create_channel(self, guild_id, name, channe_type):
+ url = '{0.GUILDS}/{1}/channels'.format(self, guild_id)
+ payload = {
+ 'name': name,
+ 'type': channe_type
+ }
+
+ return self.post(url, json=payload, bucket=_func_())
+
+ def delete_channel(self, channel_id):
+ url = '{0.CHANNELS}/{1}'.format(self, channel_id)
+ return self.delete(url, bucket=_func_())
+
+ # Server management
+
+ def leave_server(self, guild_id):
+ url = '{0.USERS}/@me/guilds/{1}'.format(self, guild_id)
+ return self.delete(url, bucket=_func_())
+
+ def delete_server(self, guild_id):
+ url = '{0.GUILDS}/{1}'.format(self, guild_id)
+ return self.delete(url, bucket=_func_())
+
+ def create_server(self, name, region, icon):
+ payload = {
+ 'name': name,
+ 'icon': icon,
+ 'region': region
+ }
+
+ return self.post(self.GUILDS, json=payload, bucket=_func_())
+
+ def edit_server(self, guild_id, **fields):
+ valid_keys = ('name', 'region', 'icon', 'afk_timeout', 'owner_id',
+ 'afk_channel_id', 'splash', 'verification_level')
+
+ payload = {
+ k: v for k, v in fields.items() if k in valid_keys
+ }
+
+ url = '{0.GUILDS}/{1}'.format(self, guild_id)
+ return self.patch(url, json=payload, bucket=_func_())
+
+ def get_bans(self, guild_id):
+ url = '{0.GUILDS}/{1}/bans'.format(self, guild_id)
+ return self.get(url, bucket=_func_())
+
+ # Invite management
+
+ def create_invite(self, channel_id, **options):
+ url = '{0.CHANNELS}/{1}/invites'.format(self, channel_id)
+ payload = {
+ 'max_age': options.get('max_age', 0),
+ 'max_uses': options.get('max_uses', 0),
+ 'temporary': options.get('temporary', False),
+ 'xkcdpass': options.get('xkcd', False)
+ }
+
+ return self.post(url, json=payload, bucket=_func_())
+
+ def get_invite(self, invite_id):
+ url = '{0.API_BASE}/invite/{1}'.format(self, invite_id)
+ return self.get(url, bucket=_func_())
+
+ def invites_from(self, guild_id):
+ url = '{0.GUILDS}/{1}/invites'.format(self, guild_id)
+ return self.get(url, bucket=_func_())
+
+ def accept_invite(self, invite_id):
+ url = '{0.API_BASE}/invite/{1}'.format(self, invite_id)
+ return self.post(url, bucket=_func_())
+
+ def delete_invite(self, invite_id):
+ url = '{0.API_BASE}/invite/{1}'.format(self, invite_id)
+ return self.delete(url, bucket=_func_())
+
+ # Role management
+
+ def edit_role(self, guild_id, role_id, **fields):
+ url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id)
+ valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable')
+ payload = {
+ k: v for k, v in fields.items() if k in valid_keys
+ }
+ return self.patch(url, json=payload, bucket='roles:' + str(guild_id))
+
+ def delete_role(self, guild_id, role_id):
+ url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id)
+ return self.delete(url, bucket=_func_())
+
+ def replace_roles(self, user_id, guild_id, role_ids):
+ url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
+ payload = {
+ 'roles': role_ids
+ }
+ return self.patch(url, json=payload, bucket='members:' + str(guild_id))
+
+ def create_role(self, guild_id):
+ url = '{0.GUILDS}/{1}/roles'.format(self, guild_id)
+ return self.post(url, bucket=_func_())
+
+ def edit_channel_permissions(self, channel_id, target, allow, deny, type):
+ url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target)
+ payload = {
+ 'id': target,
+ 'allow': allow,
+ 'deny': deny,
+ 'type': type
+ }
+ return self.put(url, json=payload, bucket=_func_())
+
+ def delete_channel_permissions(self, channel_id, target):
+ url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target)
+ return self.delete(url, bucket=_func_())
+
+ # Voice management
+
+ def move_member(self, user_id, guild_id, channel_id):
+ url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
+ payload = {
+ 'channel_id': channel_id
+ }
+ return self.patch(url, json=payload, bucket='members:' + str(guild_id))
+
+ # Misc
+
+ def application_info(self):
+ url = '{0.APPLICATIONS}/@me'.format(self)
+ return self.get(url, bucket=_func_())
+
+ @asyncio.coroutine
+ def get_gateway(self):
+ try:
+ data = yield from self.get(self.GATEWAY, bucket=_func_())
+ except HTTPException as e:
+ raise GatewayNotFound() from e
+ return data.get('url') + '?encoding=json&v=4'