aboutsummaryrefslogtreecommitdiff
path: root/discord/iterators.py
diff options
context:
space:
mode:
Diffstat (limited to 'discord/iterators.py')
-rw-r--r--discord/iterators.py109
1 files changed, 109 insertions, 0 deletions
diff --git a/discord/iterators.py b/discord/iterators.py
index c3ac5643..86e21931 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -31,6 +31,7 @@ import datetime
from .errors import NoMoreItems
from .utils import time_snowflake, maybe_coroutine
from .object import Object
+from .audit_logs import AuditLogEntry
PY35 = sys.version_info >= (3, 5)
@@ -369,3 +370,111 @@ class HistoryIterator(_AsyncIterator):
self.around = None
return data
return []
+
+class AuditLogIterator(_AsyncIterator):
+ def __init__(self, guild, limit=None, before=None, after=None, reverse=None, user_id=None, action_type=None):
+ if isinstance(before, datetime.datetime):
+ before = Object(id=time_snowflake(before, high=False))
+ if isinstance(after, datetime.datetime):
+ after = Object(id=time_snowflake(after, high=True))
+
+
+ self.guild = guild
+ self.loop = guild._state.loop
+ self.request = guild._state.http.get_audit_logs
+ self.limit = limit
+ self.before = before
+ self.user_id = user_id
+ self.action_type = action_type
+ self.after = after
+ self._users = {}
+ self._state = guild._state
+
+ if reverse is None:
+ self.reverse = after is not None
+ else:
+ self.reverse = reverse
+
+ self._filter = None # entry dict -> bool
+
+ self.entries = asyncio.Queue(loop=self.loop)
+
+ if self.before and self.after:
+ if self.reverse:
+ self._strategy = self._after_strategy
+ self._filter = lambda m: int(m['id']) < self.before.id
+ else:
+ self._strategy = self._before_strategy
+ self._filter = lambda m: int(m['id']) > self.after.id
+ elif self.after:
+ self._strategy = self._after_strategy
+ else:
+ self._strategy = self._before_strategy
+
+ @asyncio.coroutine
+ def _before_strategy(self, retrieve):
+ before = self.before.id if self.before else None
+ data = yield from self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
+ action_type=self.action_type, before=before)
+ if len(data):
+ if self.limit is not None:
+ self.limit -= retrieve
+ self.before = Object(id=int(data['audit_log_entries'][-1]['id']))
+ return data
+
+ @asyncio.coroutine
+ def _after_strategy(self, retrieve):
+ after = self.after.id if self.after else None
+ data = yield from self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
+ action_type=self.action_type, after=after)
+ if len(data):
+ if self.limit is not None:
+ self.limit -= retrieve
+ self.after = Object(id=int(data['audit_log_entries'][0]['id']))
+ return data
+
+ @asyncio.coroutine
+ def get(self):
+ if self.entries.empty():
+ yield from self._fill()
+
+ try:
+ return self.entries.get_nowait()
+ except asyncio.QueueEmpty:
+ raise NoMoreItems()
+
+ def _get_retrieve(self):
+ l = self.limit
+ if l is None:
+ r = 100
+ elif l <= 100:
+ r = l
+ else:
+ r = 100
+
+ self.retrieve = r
+ return r > 0
+
+ @asyncio.coroutine
+ def _fill(self):
+ from .user import User
+
+ if self._get_retrieve():
+ data = yield from self._strategy(self.retrieve)
+ users = data.get('users', [])
+ data = data.get('audit_log_entries', [])
+
+ if self.limit is None and len(data) < 100:
+ self.limit = 0 # terminate the infinite loop
+
+ if self.reverse:
+ data = reversed(data)
+ if self._filter:
+ data = filter(self._filter, data)
+
+ for user in users:
+ u = User(data=user, state=self._state)
+ self._users[u.id] = u
+
+ for element in data:
+ yield from self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))