aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2015-12-13 01:42:15 -0500
committerRapptz <[email protected]>2015-12-13 02:12:06 -0500
commit29ea58d0080e0e6f4931fa9cb7fc4c04a2b248df (patch)
treeed8d4bb42df6ac50aa489428a00a1b3497726077
parentClient.login no longer calls resp.json() (diff)
downloaddiscord.py-29ea58d0080e0e6f4931fa9cb7fc4c04a2b248df.tar.xz
discord.py-29ea58d0080e0e6f4931fa9cb7fc4c04a2b248df.zip
Implement cache of login credentials.
Also add endpoints.ME to easily access the @me endpoint.
-rw-r--r--discord/client.py52
-rw-r--r--discord/endpoints.py1
2 files changed, 50 insertions, 3 deletions
diff --git a/discord/client.py b/discord/client.py
index 42657f69..75b3c9a3 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -45,6 +45,7 @@ import websockets
import logging, traceback
import sys, time, re, json
+import tempfile, os, hashlib
log = logging.getLogger(__name__)
request_logging_format = '{method} {response.url} has returned {response.status}'
@@ -68,6 +69,10 @@ class Client:
loop : Optional[event loop].
The `event loop`_ to use for asynchronous operations. Defaults to ``None``,
in which case the default event loop is used via ``asyncio.get_event_loop()``.
+ cache_auth : Optional[bool]
+ Indicates if :meth:`login` should cache the authentication tokens. Defaults
+ to ``True``. The method in which the cache is written is done by writing to
+ disk to a temporary directory.
Attributes
-----------
@@ -101,6 +106,7 @@ class Client:
self.voice = None
self.loop = asyncio.get_event_loop() if loop is None else loop
self._listeners = []
+ self.cache_auth = options.get('cache_auth', True)
max_messages = options.get('max_messages')
if max_messages is None or max_messages < 100:
@@ -131,6 +137,10 @@ class Client:
# internals
+ def _get_cache_filename(self, email):
+ filename = hashlib.md5(email.encode('utf-8')).hexdigest()
+ return os.path.join(tempfile.gettempdir(), 'discord_py', filename)
+
def handle_message(self, message):
removed = []
for i, (condition, future) in enumerate(self._listeners):
@@ -510,6 +520,31 @@ class Client:
usually when it isn't 200 or the known incorrect credentials
passing status code.
"""
+
+ # attempt to read the token from cache
+ if self.cache_auth:
+ try:
+ log.info('attempting to login via cache')
+ cache_file = self._get_cache_filename(email)
+ with open(cache_file, 'r') as f:
+ log.info('login cache file found')
+ self.token = f.read()
+ self.headers['authorization'] = self.token
+
+ check = yield from self.session.get(endpoints.ME, headers=self.headers)
+ if check.status == 200:
+ log.info('login cache token check succeeded')
+ yield from check.release()
+ self._is_logged_in = True
+ return
+
+ # at this point our check failed
+ # so we have to login and get the proper token and then
+ # redo the cache
+ except OSError as e:
+ log.info('a problem occurred while opening login cache')
+ pass # file not found et al
+
payload = {
'email': email,
'password': password
@@ -531,6 +566,18 @@ class Client:
self.headers['authorization'] = self.token
self._is_logged_in = True
+ # since we went through all this trouble
+ # let's make sure we don't have to do it again
+ if self.cache_auth:
+ try:
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ with open(cache_file, 'w') as f:
+ log.info('updating login cache')
+ f.write(self.token)
+ except OSError:
+ log.info('a problem occurred while updating the login cache')
+ pass
+
@asyncio.coroutine
def logout(self):
"""|coro|
@@ -683,7 +730,7 @@ class Client:
'recipient_id': user.id
}
- url = '{}/@me/channels'.format(endpoints.USERS)
+ url = '{}/channels'.format(endpoints.ME)
r = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=r))
yield from utils._verify_successful_response(r)
@@ -1216,8 +1263,7 @@ class Client:
'avatar': avatar
}
- url = '{0}/@me'.format(endpoints.USERS)
- r = yield from self.session.patch(url, headers=self.headers, data=utils.to_json(payload))
+ r = yield from self.session.patch(endpoints.ME, headers=self.headers, data=utils.to_json(payload))
log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r)
diff --git a/discord/endpoints.py b/discord/endpoints.py
index 7266b775..d3e7f197 100644
--- a/discord/endpoints.py
+++ b/discord/endpoints.py
@@ -28,6 +28,7 @@ BASE = 'https://discordapp.com'
API_BASE = BASE + '/api'
GATEWAY = API_BASE + '/gateway'
USERS = API_BASE + '/users'
+ME = USERS + '/@me'
REGISTER = API_BASE + '/auth/register'
LOGIN = API_BASE + '/auth/login'
LOGOUT = API_BASE + '/auth/logout'