diff options
Diffstat (limited to 'discord/iterators.py')
| -rw-r--r-- | discord/iterators.py | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/discord/iterators.py b/discord/iterators.py index c3ac5643..86e21931 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -31,6 +31,7 @@ import datetime from .errors import NoMoreItems from .utils import time_snowflake, maybe_coroutine from .object import Object +from .audit_logs import AuditLogEntry PY35 = sys.version_info >= (3, 5) @@ -369,3 +370,111 @@ class HistoryIterator(_AsyncIterator): self.around = None return data return [] + +class AuditLogIterator(_AsyncIterator): + def __init__(self, guild, limit=None, before=None, after=None, reverse=None, user_id=None, action_type=None): + if isinstance(before, datetime.datetime): + before = Object(id=time_snowflake(before, high=False)) + if isinstance(after, datetime.datetime): + after = Object(id=time_snowflake(after, high=True)) + + + self.guild = guild + self.loop = guild._state.loop + self.request = guild._state.http.get_audit_logs + self.limit = limit + self.before = before + self.user_id = user_id + self.action_type = action_type + self.after = after + self._users = {} + self._state = guild._state + + if reverse is None: + self.reverse = after is not None + else: + self.reverse = reverse + + self._filter = None # entry dict -> bool + + self.entries = asyncio.Queue(loop=self.loop) + + if self.before and self.after: + if self.reverse: + self._strategy = self._after_strategy + self._filter = lambda m: int(m['id']) < self.before.id + else: + self._strategy = self._before_strategy + self._filter = lambda m: int(m['id']) > self.after.id + elif self.after: + self._strategy = self._after_strategy + else: + self._strategy = self._before_strategy + + @asyncio.coroutine + def _before_strategy(self, retrieve): + before = self.before.id if self.before else None + data = yield from self.request(self.guild.id, limit=retrieve, user_id=self.user_id, + action_type=self.action_type, before=before) + if len(data): + if self.limit is not None: + self.limit -= retrieve + self.before = Object(id=int(data['audit_log_entries'][-1]['id'])) + return data + + @asyncio.coroutine + def _after_strategy(self, retrieve): + after = self.after.id if self.after else None + data = yield from self.request(self.guild.id, limit=retrieve, user_id=self.user_id, + action_type=self.action_type, after=after) + if len(data): + if self.limit is not None: + self.limit -= retrieve + self.after = Object(id=int(data['audit_log_entries'][0]['id'])) + return data + + @asyncio.coroutine + def get(self): + if self.entries.empty(): + yield from self._fill() + + try: + return self.entries.get_nowait() + except asyncio.QueueEmpty: + raise NoMoreItems() + + def _get_retrieve(self): + l = self.limit + if l is None: + r = 100 + elif l <= 100: + r = l + else: + r = 100 + + self.retrieve = r + return r > 0 + + @asyncio.coroutine + def _fill(self): + from .user import User + + if self._get_retrieve(): + data = yield from self._strategy(self.retrieve) + users = data.get('users', []) + data = data.get('audit_log_entries', []) + + if self.limit is None and len(data) < 100: + self.limit = 0 # terminate the infinite loop + + if self.reverse: + data = reversed(data) + if self._filter: + data = filter(self._filter, data) + + for user in users: + u = User(data=user, state=self._state) + self._users[u.id] = u + + for element in data: + yield from self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) |