aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2016-10-17 01:10:22 -0400
committerRapptz <[email protected]>2017-01-03 09:51:52 -0500
commit53ab2631252bf0977446d762f07b3821edb151ee (patch)
treeabb8a2e7a966aadb22df8a3ca2220b646eae3765
parent[commands] Bot skip check now works with the new __eq__ changes. (diff)
downloaddiscord.py-53ab2631252bf0977446d762f07b3821edb151ee.tar.xz
discord.py-53ab2631252bf0977446d762f07b3821edb151ee.zip
Split channel types.
This splits them into the following: * DMChannel * GroupChannel * VoiceChannel * TextChannel This also makes the channels "stateful".
-rw-r--r--discord/__init__.py2
-rw-r--r--discord/abc.py277
-rw-r--r--discord/calls.py4
-rw-r--r--discord/channel.py468
-rw-r--r--discord/client.py8
-rw-r--r--discord/errors.py6
-rw-r--r--discord/iterators.py70
-rw-r--r--discord/message.py4
-rw-r--r--discord/server.py10
-rw-r--r--discord/state.py40
10 files changed, 715 insertions, 174 deletions
diff --git a/discord/__init__.py b/discord/__init__.py
index 1fd3d83c..55427cca 100644
--- a/discord/__init__.py
+++ b/discord/__init__.py
@@ -21,7 +21,7 @@ from .client import Client, AppInfo, ChannelPermissions
from .user import User
from .game import Game
from .emoji import Emoji
-from .channel import Channel, PrivateChannel
+from .channel import *
from .server import Server
from .member import Member, VoiceState
from .message import Message
diff --git a/discord/abc.py b/discord/abc.py
index 2bda266e..0b42b0e8 100644
--- a/discord/abc.py
+++ b/discord/abc.py
@@ -25,6 +25,12 @@ DEALINGS IN THE SOFTWARE.
"""
import abc
+import io
+import os
+import asyncio
+
+from .message import Message
+from .iterators import LogsFromIterator
class Snowflake(metaclass=abc.ABCMeta):
__slots__ = ()
@@ -75,3 +81,274 @@ class User(metaclass=abc.ABCMeta):
return NotImplemented
return True
return NotImplemented
+
+class GuildChannel(metaclass=abc.ABCMeta):
+ __slots__ = ()
+
+ @property
+ @abc.abstractmethod
+ def mention(self):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def overwrites_for(self, obj):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def permissions_for(self, user):
+ raise NotImplementedError
+
+ @classmethod
+ def __subclasshook__(cls, C):
+ if cls is GuildChannel:
+ if Snowflake.__subclasshook__(C) is NotImplemented:
+ return NotImplemented
+
+ mro = C.__mro__
+ for attr in ('name', 'server', 'overwrites_for', 'permissions_for', 'mention'):
+ for base in mro:
+ if attr in base.__dict__:
+ break
+ else:
+ return NotImplemented
+ return True
+ return NotImplemented
+
+class PrivateChannel(metaclass=abc.ABCMeta):
+ __slots__ = ()
+
+ @classmethod
+ def __subclasshook__(cls, C):
+ if cls is PrivateChannel:
+ if Snowflake.__subclasshook__(C) is NotImplemented:
+ return NotImplemented
+
+ mro = C.__mro__
+ for base in mro:
+ if 'me' in base.__dict__:
+ return True
+ return NotImplemented
+ return NotImplemented
+
+class MessageChannel(metaclass=abc.ABCMeta):
+ __slots__ = ()
+
+ @abc.abstractmethod
+ def _get_destination(self):
+ raise NotImplementedError
+
+ @asyncio.coroutine
+ def send_message(self, content, *, tts=False):
+ """|coro|
+
+ Sends a message to the channel with the content given.
+
+ The content must be a type that can convert to a string through ``str(content)``.
+
+ Parameters
+ ------------
+ content
+ The content of the message to send.
+ tts: bool
+ Indicates if the message should be sent using text-to-speech.
+
+ Raises
+ --------
+ HTTPException
+ Sending the message failed.
+ Forbidden
+ You do not have the proper permissions to send the message.
+
+ Returns
+ ---------
+ :class:`Message`
+ The message that was sent.
+ """
+
+ channel_id, guild_id = self._get_destination()
+ content = str(content)
+ data = yield from self._state.http.send_message(channel_id, content, guild_id=guild_id, tts=tts)
+ return Message(channel=self, state=self._state, data=data)
+
+ @asyncio.coroutine
+ def send_typing(self):
+ """|coro|
+
+ Send a *typing* status to the channel.
+
+ *Typing* status will go away after 10 seconds, or after a message is sent.
+ """
+
+ channel_id, _ = self._get_destination()
+ yield from self._state.http.send_typing(channel_id)
+
+ @asyncio.coroutine
+ def upload(self, fp, *, filename=None, content=None, tts=False):
+ """|coro|
+
+ Sends a message to the channel with the file given.
+
+ The ``fp`` parameter should be either a string denoting the location for a
+ file or a *file-like object*. The *file-like object* passed is **not closed**
+ at the end of execution. You are responsible for closing it yourself.
+
+ .. note::
+
+ If the file-like object passed is opened via ``open`` then the modes
+ 'rb' should be used.
+
+ The ``filename`` parameter is the filename of the file.
+ If this is not given then it defaults to ``fp.name`` or if ``fp`` is a string
+ then the ``filename`` will default to the string given. You can overwrite
+ this value by passing this in.
+
+ Parameters
+ ------------
+ fp
+ The *file-like object* or file path to send.
+ filename: str
+ The filename of the file. Defaults to ``fp.name`` if it's available.
+ content: str
+ The content of the message to send along with the file. This is
+ forced into a string by a ``str(content)`` call.
+ tts: bool
+ If the content of the message should be sent with TTS enabled.
+
+ Raises
+ -------
+ HTTPException
+ Sending the file failed.
+
+ Returns
+ --------
+ :class:`Message`
+ The message sent.
+ """
+
+ channel_id, guild_id = self._get_destination()
+
+ try:
+ with open(fp, 'rb') as f:
+ buffer = io.BytesIO(f.read())
+ if filename is None:
+ _, filename = os.path.split(fp)
+ except TypeError:
+ buffer = fp
+
+ state = self._state
+ data = yield from state.http.send_file(channel_id, buffer, guild_id=guild_id,
+ filename=filename, content=content, tts=tts)
+
+ return Message(channel=self, state=state, data=data)
+
+ @asyncio.coroutine
+ def get_message(self, id):
+ """|coro|
+
+ Retrieves a single :class:`Message` from a channel.
+
+ This can only be used by bot accounts.
+
+ Parameters
+ ------------
+ id: int
+ The message ID to look for.
+
+ Returns
+ --------
+ :class:`Message`
+ The message asked for.
+
+ Raises
+ --------
+ NotFound
+ The specified message was not found.
+ Forbidden
+ You do not have the permissions required to get a message.
+ HTTPException
+ Retrieving the message failed.
+ """
+
+ data = yield from self._state.http.get_message(self.id, id)
+ return Message(channel=self, state=self._state, data=data)
+
+ @asyncio.coroutine
+ def pins(self):
+ """|coro|
+
+ Returns a list of :class:`Message` that are currently pinned.
+
+ Raises
+ -------
+ HTTPException
+ Retrieving the pinned messages failed.
+ """
+
+ state = self._state
+ data = yield from state.http.pins_from(self.id)
+ return [Message(channel=self, state=state, data=m) for m in data]
+
+ def history(self, *, limit=100, before=None, after=None, around=None, reverse=None):
+ """Return an async iterator that enables receiving the channel's message history.
+
+ You must have Read Message History permissions to use this.
+
+ All parameters are optional.
+
+ Parameters
+ -----------
+ limit: int
+ The number of messages to retrieve.
+ before: :class:`Message` or `datetime`
+ Retrieve messages before this date or message.
+ If a date is provided it must be a timezone-naive datetime representing UTC time.
+ after: :class:`Message` or `datetime`
+ Retrieve messages after this date or message.
+ If a date is provided it must be a timezone-naive datetime representing UTC time.
+ around: :class:`Message` or `datetime`
+ Retrieve messages around this date or message.
+ If a date is provided it must be a timezone-naive datetime representing UTC time.
+ When using this argument, the maximum limit is 101. Note that if the limit is an
+ even number then this will return at most limit + 1 messages.
+ reverse: bool
+ If set to true, return messages in oldest->newest order. If unspecified,
+ this defaults to ``False`` for most cases. However if passing in a
+ ``after`` parameter then this is set to ``True``. This avoids getting messages
+ out of order in the ``after`` case.
+
+ Raises
+ ------
+ Forbidden
+ You do not have permissions to get channel message history.
+ HTTPException
+ The request to get message history failed.
+
+ Yields
+ -------
+ :class:`Message`
+ The message with the message data parsed.
+
+ Examples
+ ---------
+
+ Usage ::
+
+ counter = 0
+ async for message in channel.history(limit=200):
+ if message.author == client.user:
+ counter += 1
+
+ Python 3.4 Usage ::
+
+ count = 0
+ iterator = channel.history(limit=200)
+ while True:
+ try:
+ message = yield from iterator.get()
+ except discord.NoMoreMessages:
+ break
+ else:
+ if message.author == client.user:
+ counter += 1
+ """
+ return LogsFromIterator(self, limit=limit, before=before, after=after, around=around, reverse=reverse)
diff --git a/discord/calls.py b/discord/calls.py
index 94c55a14..0925f713 100644
--- a/discord/calls.py
+++ b/discord/calls.py
@@ -57,7 +57,7 @@ class CallMessage:
@property
def channel(self):
- """:class:`PrivateChannel`\: The private channel associated with this message."""
+ """:class:`GroupChannel`\: The private channel associated with this message."""
return self.message.channel
@property
@@ -131,7 +131,7 @@ class GroupCall:
@property
def channel(self):
- """:class:`PrivateChannel`\: Returns the channel the group call is in."""
+ """:class:`GroupChannel`\: Returns the channel the group call is in."""
return self.call.channel
def voice_state_for(self, user):
diff --git a/discord/channel.py b/discord/channel.py
index f79a2d5d..b1961dd4 100644
--- a/discord/channel.py
+++ b/discord/channel.py
@@ -23,8 +23,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
-import copy
-from . import utils
+from . import utils, abc
from .permissions import Permissions, PermissionOverwrite
from .enums import ChannelType, try_enum
from collections import namedtuple
@@ -33,82 +32,54 @@ from .role import Role
from .user import User
from .member import Member
+import copy
+import asyncio
+
+__all__ = ('TextChannel', 'VoiceChannel', 'DMChannel', 'GroupChannel', '_channel_factory')
+
Overwrites = namedtuple('Overwrites', 'id allow deny type')
-class Channel(Hashable):
- """Represents a Discord server channel.
+class CommonGuildChannel(Hashable):
+ __slots__ = ()
- Supported Operations:
+ def __str__(self):
+ return self.name
- +-----------+---------------------------------------+
- | Operation | Description |
- +===========+=======================================+
- | x == y | Checks if two channels are equal. |
- +-----------+---------------------------------------+
- | x != y | Checks if two channels are not equal. |
- +-----------+---------------------------------------+
- | hash(x) | Returns the channel's hash. |
- +-----------+---------------------------------------+
- | str(x) | Returns the channel's name. |
- +-----------+---------------------------------------+
+ @asyncio.coroutine
+ def _move(self, position):
+ if position < 0:
+ raise InvalidArgument('Channel position cannot be less than 0.')
- Attributes
- -----------
- name: str
- The channel name.
- server: :class:`Server`
- The server the channel belongs to.
- id: int
- The channel ID.
- topic: Optional[str]
- The channel's topic. None if it doesn't exist.
- is_private: bool
- ``True`` if the channel is a private channel (i.e. PM). ``False`` in this case.
- position: int
- The position in the channel list. This is a number that starts at 0. e.g. the
- top channel is position 0. The position varies depending on being a voice channel
- or a text channel, so a 0 position voice channel is on top of the voice channel
- list.
- type: :class:`ChannelType`
- The channel type. There is a chance that the type will be ``str`` if
- the channel type is not within the ones recognised by the enumerator.
- bitrate: int
- The channel's preferred audio bitrate in bits per second.
- voice_members
- A list of :class:`Members` that are currently inside this voice channel.
- If :attr:`type` is not :attr:`ChannelType.voice` then this is always an empty array.
- user_limit: int
- The channel's limit for number of members that can be in a voice channel.
- """
+ http = self._state.http
+ url = '{0}/{1.server.id}/channels'.format(http.GUILDS, self)
+ channels = [c for c in self.server.channels if isinstance(c, type(self))]
- __slots__ = ( 'voice_members', 'name', 'id', 'server', 'topic',
- 'type', 'bitrate', 'user_limit', '_state', 'position',
- '_permission_overwrites' )
+ if position >= len(channels):
+ raise InvalidArgument('Channel position cannot be greater than {}'.format(len(channels) - 1))
- def __init__(self, *, state, server, data):
- self._state = state
- self.id = int(data['id'])
- self._update(server, data)
- self.voice_members = []
+ channels.sort(key=lambda c: c.position)
- def __str__(self):
- return self.name
+ try:
+ # remove ourselves from the channel list
+ channels.remove(self)
+ except ValueError:
+ # not there somehow lol
+ return
+ else:
+ # add ourselves at our designated position
+ channels.insert(position, self)
- def _update(self, server, data):
- self.server = server
- self.name = data['name']
- self.topic = data.get('topic')
- self.position = data['position']
- self.bitrate = data.get('bitrate')
- self.type = data['type']
- self.user_limit = data.get('user_limit')
- self._permission_overwrites = []
+ payload = [{'id': c.id, 'position': index } for index, c in enumerate(channels)]
+ yield from http.patch(url, json=payload, bucket='move_channel')
+
+ def _fill_overwrites(self, data):
+ self._overwrites = []
everyone_index = 0
everyone_id = self.server.id
for index, overridden in enumerate(data.get('permission_overwrites', [])):
overridden_id = int(overridden.pop('id'))
- self._permission_overwrites.append(Overwrites(id=overridden_id, **overridden))
+ self._overwrites.append(Overwrites(id=overridden_id, **overridden))
if overridden['type'] == 'member':
continue
@@ -122,7 +93,7 @@ class Channel(Hashable):
everyone_index = index
# do the swap
- tmp = self._permission_overwrites
+ tmp = self._overwrites
if tmp:
tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index]
@@ -131,7 +102,7 @@ class Channel(Hashable):
"""Returns a list of :class:`Roles` that have been overridden from
their default values in the :attr:`Server.roles` attribute."""
ret = []
- for overwrite in filter(lambda o: o.type == 'role', self._permission_overwrites):
+ for overwrite in filter(lambda o: o.type == 'role', self._overwrites):
role = utils.get(self.server.roles, id=overwrite.id)
if role is None:
continue
@@ -147,10 +118,6 @@ class Channel(Hashable):
return self.server.id == self.id
@property
- def is_private(self):
- return False
-
- @property
def mention(self):
"""str : The string that allows you to mention the channel."""
return '<#{0.id}>'.format(self)
@@ -182,7 +149,7 @@ class Channel(Hashable):
else:
predicate = lambda p: True
- for overwrite in filter(predicate, self._permission_overwrites):
+ for overwrite in filter(predicate, self._overwrites):
if overwrite.id == obj.id:
allow = Permissions(overwrite.allow)
deny = Permissions(overwrite.deny)
@@ -276,7 +243,7 @@ class Channel(Hashable):
allows = 0
# Apply channel specific role permission overwrites
- for overwrite in self._permission_overwrites:
+ for overwrite in self._overwrites:
if overwrite.type == 'role' and overwrite.id in member_role_ids:
denies |= overwrite.deny
allows |= overwrite.allow
@@ -284,7 +251,7 @@ class Channel(Hashable):
base.handle_overwrite(allow=allows, deny=denies)
# Apply member specific permission overwrites
- for overwrite in self._permission_overwrites:
+ for overwrite in self._overwrites:
if overwrite.type == 'member' and overwrite.id == member.id:
base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny)
break
@@ -307,14 +274,286 @@ class Channel(Hashable):
base.value &= ~denied.value
# text channels do not have voice related permissions
- if self.type is ChannelType.text:
+ if isinstance(self, TextChannel):
denied = Permissions.voice()
base.value &= ~denied.value
return base
-class PrivateChannel(Hashable):
- """Represents a Discord private channel.
+ @asyncio.coroutine
+ def delete(self):
+ """|coro|
+
+ Deletes the channel.
+
+ You must have Manage Channel permission to use this.
+
+ Raises
+ -------
+ Forbidden
+ You do not have proper permissions to delete the channel.
+ NotFound
+ The channel was not found or was already deleted.
+ HTTPException
+ Deleting the channel failed.
+ """
+ yield from self._state.http.delete_channel(self.id)
+
+class TextChannel(abc.MessageChannel, CommonGuildChannel):
+ """Represents a Discord server text channel.
+
+ Supported Operations:
+
+ +-----------+---------------------------------------+
+ | Operation | Description |
+ +===========+=======================================+
+ | x == y | Checks if two channels are equal. |
+ +-----------+---------------------------------------+
+ | x != y | Checks if two channels are not equal. |
+ +-----------+---------------------------------------+
+ | hash(x) | Returns the channel's hash. |
+ +-----------+---------------------------------------+
+ | str(x) | Returns the channel's name. |
+ +-----------+---------------------------------------+
+
+ Attributes
+ -----------
+ name: str
+ The channel name.
+ server: :class:`Server`
+ The server the channel belongs to.
+ id: int
+ The channel ID.
+ topic: Optional[str]
+ The channel's topic. None if it doesn't exist.
+ position: int
+ The position in the channel list. This is a number that starts at 0. e.g. the
+ top channel is position 0.
+ """
+
+ __slots__ = ( 'name', 'id', 'server', 'topic', '_state',
+ 'position', '_overwrites' )
+
+ def __init__(self, *, state, server, data):
+ self._state = state
+ self.id = int(data['id'])
+ self._update(server, data)
+
+ def _update(self, server, data):
+ self.server = server
+ self.name = data['name']
+ self.topic = data.get('topic')
+ self.position = data['position']
+ self._fill_overwrites(data)
+
+ def _get_destination(self):
+ return self.id, self.server.id
+
+ @asyncio.coroutine
+ def edit(self, **options):
+ """|coro|
+
+ Edits the channel.
+
+ You must have the Manage Channel permission to use this.
+
+ Parameters
+ ----------
+ name: str
+ The new channel name.
+ topic: str
+ The new channel's topic.
+ position: int
+ The new channel's position.
+
+ Raises
+ ------
+ InvalidArgument
+ If position is less than 0 or greater than the number of channels.
+ Forbidden
+ You do not have permissions to edit the channel.
+ HTTPException
+ Editing the channel failed.
+ """
+ try:
+ position = options.pop('position')
+ except KeyError:
+ pass
+ else:
+ yield from self._move(position)
+ self.position = position
+
+ if options:
+ data = yield from self._state.http.edit_channel(self.id, **options)
+ self._update(self.server, data)
+
+class VoiceChannel(CommonGuildChannel):
+ """Represents a Discord server voice channel.
+
+ Supported Operations:
+
+ +-----------+---------------------------------------+
+ | Operation | Description |
+ +===========+=======================================+
+ | x == y | Checks if two channels are equal. |
+ +-----------+---------------------------------------+
+ | x != y | Checks if two channels are not equal. |
+ +-----------+---------------------------------------+
+ | hash(x) | Returns the channel's hash. |
+ +-----------+---------------------------------------+
+ | str(x) | Returns the channel's name. |
+ +-----------+---------------------------------------+
+
+ Attributes
+ -----------
+ name: str
+ The channel name.
+ server: :class:`Server`
+ The server the channel belongs to.
+ id: int
+ The channel ID.
+ position: int
+ The position in the channel list. This is a number that starts at 0. e.g. the
+ top channel is position 0.
+ bitrate: int
+ The channel's preferred audio bitrate in bits per second.
+ voice_members
+ A list of :class:`Members` that are currently inside this voice channel.
+ user_limit: int
+ The channel's limit for number of members that can be in a voice channel.
+ """
+
+ __slots__ = ( 'voice_members', 'name', 'id', 'server', 'bitrate',
+ 'user_limit', '_state', 'position', '_overwrites' )
+
+ def __init__(self, *, state, server, data):
+ self._state = state
+ self.id = int(data['id'])
+ self._update(server, data)
+ self.voice_members = []
+
+ def _update(self, server, data):
+ self.server = server
+ self.name = data['name']
+ self.position = data['position']
+ self.bitrate = data.get('bitrate')
+ self.user_limit = data.get('user_limit')
+ self._fill_overwrites(data)
+
+ @asyncio.coroutine
+ def edit(self, **options):
+ """|coro|
+
+ Edits the channel.
+
+ You must have the Manage Channel permission to use this.
+
+ Parameters
+ ----------
+ bitrate: int
+ The new channel's bitrate.
+ user_limit: int
+ The new channel's user limit.
+ position: int
+ The new channel's position.
+
+ Raises
+ ------
+ Forbidden
+ You do not have permissions to edit the channel.
+ HTTPException
+ Editing the channel failed.
+ """
+
+ try:
+ position = options.pop('position')
+ except KeyError:
+ pass
+ else:
+ yield from self._move(position)
+ self.position = position
+
+ if options:
+ data = yield from self._state.http.edit_channel(self.id, **options)
+ self._update(self.server, data)
+
+class DMChannel(abc.MessageChannel, Hashable):
+ """Represents a Discord direct message channel.
+
+ Supported Operations:
+
+ +-----------+-------------------------------------------------+
+ | Operation | Description |
+ +===========+=================================================+
+ | x == y | Checks if two channels are equal. |
+ +-----------+-------------------------------------------------+
+ | x != y | Checks if two channels are not equal. |
+ +-----------+-------------------------------------------------+
+ | hash(x) | Returns the channel's hash. |
+ +-----------+-------------------------------------------------+
+ | str(x) | Returns a string representation of the channel |
+ +-----------+-------------------------------------------------+
+
+ Attributes
+ ----------
+ recipient: :class:`User`
+ The user you are participating with in the direct message channel.
+ me: :class:`User`
+ The user presenting yourself.
+ id: int
+ The direct message channel ID.
+ """
+
+ __slots__ = ('id', 'recipient', 'me', '_state')
+
+ def __init__(self, *, me, state, data):
+ self._state = state
+ self.recipient = state.try_insert_user(data['recipients'][0])
+ self.me = me
+ self.id = int(data['id'])
+
+ def _get_destination(self):
+ return self.id, None
+
+ def __str__(self):
+ return 'Direct Message with %s' % self.recipient
+
+ @property
+ def created_at(self):
+ """Returns the direct message channel's creation time in UTC."""
+ return utils.snowflake_time(self.id)
+
+ def permissions_for(self, user=None):
+ """Handles permission resolution for a :class:`User`.
+
+ This function is there for compatibility with other channel types.
+
+ Actual direct messages do not really have the concept of permissions.
+
+ This returns all the Text related permissions set to true except:
+
+ - send_tts_messages: You cannot send TTS messages in a DM.
+ - manage_messages: You cannot delete others messages in a DM.
+
+ Parameters
+ -----------
+ user: :class:`User`
+ The user to check permissions for. This parameter is ignored
+ but kept for compatibility.
+
+ Returns
+ --------
+ :class:`Permissions`
+ The resolved permissions.
+ """
+
+ base = Permissions.text()
+ base.send_tts_messages = False
+ base.manage_messages = False
+ return base
+
+class GroupChannel(abc.MessageChannel, Hashable):
+ """Represents a Discord group channel.
Supported Operations:
@@ -333,50 +572,42 @@ class PrivateChannel(Hashable):
Attributes
----------
recipients: list of :class:`User`
- The users you are participating with in the private channel.
+ The users you are participating with in the group channel.
me: :class:`User`
The user presenting yourself.
id: int
- The private channel ID.
- is_private: bool
- ``True`` if the channel is a private channel (i.e. PM). ``True`` in this case.
- type: :class:`ChannelType`
- The type of private channel.
- owner: Optional[:class:`User`]
- The user that owns the private channel. If the channel type is not
- :attr:`ChannelType.group` then this is always ``None``.
+ The group channel ID.
+ owner: :class:`User`
+ The user that owns the group channel.
icon: Optional[str]
- The private channel's icon hash. If the channel type is not
- :attr:`ChannelType.group` then this is always ``None``.
+ The group channel's icon hash if provided.
name: Optional[str]
- The private channel's name. If the channel type is not
- :attr:`ChannelType.group` then this is always ``None``.
+ The group channel's name if provided.
"""
- __slots__ = ('id', 'recipients', 'type', 'owner', 'icon', 'name', 'me', '_state')
+ __slots__ = ('id', 'recipients', 'owner', 'icon', 'name', 'me', '_state')
def __init__(self, *, me, state, data):
self._state = state
self.recipients = [state.try_insert_user(u) for u in data['recipients']]
self.id = int(data['id'])
self.me = me
- self.type = try_enum(ChannelType, data['type'])
self._update_group(data)
def _update_group(self, data):
owner_id = utils._get_as_snowflake(data, 'owner_id')
self.icon = data.get('icon')
self.name = data.get('name')
- self.owner = utils.find(lambda u: u.id == owner_id, self.recipients)
- @property
- def is_private(self):
- return True
+ if owner_id == self.me.id:
+ self.owner = self.me
+ else:
+ self.owner = utils.find(lambda u: u.id == owner_id, self.recipients)
- def __str__(self):
- if self.type is ChannelType.private:
- return 'Direct Message with {0.name}'.format(self.user)
+ def _get_destination(self):
+ return self.id, None
+ def __str__(self):
if self.name:
return self.name
@@ -386,15 +617,6 @@ class PrivateChannel(Hashable):
return ', '.join(map(lambda x: x.name, self.recipients))
@property
- def user(self):
- """A property that returns the first recipient of the private channel.
-
- This is mainly for compatibility and ease of use with old style private
- channels that had a single recipient.
- """
- return self.recipients[0]
-
- @property
def icon_url(self):
"""Returns the channel's icon URL if available or an empty string otherwise."""
if self.icon is None:
@@ -404,27 +626,26 @@ class PrivateChannel(Hashable):
@property
def created_at(self):
- """Returns the private channel's creation time in UTC."""
+ """Returns the channel's creation time in UTC."""
return utils.snowflake_time(self.id)
def permissions_for(self, user):
"""Handles permission resolution for a :class:`User`.
- This function is there for compatibility with :class:`Channel`.
+ This function is there for compatibility with other channel types.
- Actual private messages do not really have the concept of permissions.
+ Actual direct messages do not really have the concept of permissions.
This returns all the Text related permissions set to true except:
- - send_tts_messages: You cannot send TTS messages in a PM.
- - manage_messages: You cannot delete others messages in a PM.
+ - send_tts_messages: You cannot send TTS messages in a DM.
+ - manage_messages: You cannot delete others messages in a DM.
- This also handles permissions for :attr:`ChannelType.group` channels
- such as kicking or mentioning everyone.
+ This also checks the kick_members permission if the user is the owner.
Parameters
-----------
- user : :class:`User`
+ user: :class:`User`
The user to check permissions for.
Returns
@@ -436,11 +657,22 @@ class PrivateChannel(Hashable):
base = Permissions.text()
base.send_tts_messages = False
base.manage_messages = False
- base.mention_everyone = self.type is ChannelType.group
+ base.mention_everyone = True
- if user == self.owner:
+ if user.id == self.owner.id:
base.kick_members = True
return base
-
+def _channel_factory(channel_type):
+ value = try_enum(ChannelType, channel_type)
+ if value is ChannelType.text:
+ return TextChannel, value
+ elif value is ChannelType.voice:
+ return VoiceChannel, value
+ elif value is ChannelType.private:
+ return DMChannel, value
+ elif value is ChannelType.group:
+ return GroupChannel, value
+ else:
+ return None, value
diff --git a/discord/client.py b/discord/client.py
index b1dd1c22..94aaa6c4 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -27,7 +27,7 @@ DEALINGS IN THE SOFTWARE.
from . import __version__ as library_version
from .user import User
from .member import Member
-from .channel import Channel, PrivateChannel
+from .channel import *
from .server import Server
from .message import Message
from .invite import Invite
@@ -261,9 +261,9 @@ class Client:
@asyncio.coroutine
def _resolve_destination(self, destination):
- if isinstance(destination, Channel):
+ if isinstance(destination, TextChannel):
return destination.id, destination.server.id
- elif isinstance(destination, PrivateChannel):
+ elif isinstance(destination, DMChannel):
return destination.id, None
elif isinstance(destination, Server):
return destination.id, destination.id
@@ -283,7 +283,7 @@ class Client:
# couldn't find it in cache so YOLO
return destination.id, destination.id
else:
- fmt = 'Destination must be Channel, PrivateChannel, User, or Object. Received {0.__class__.__name__}'
+ fmt = 'Destination must be TextChannel, DMChannel, User, or Object. Received {0.__class__.__name__}'
raise InvalidArgument(fmt.format(destination))
def __getattr__(self, name):
diff --git a/discord/errors.py b/discord/errors.py
index 46d5e940..5449b77e 100644
--- a/discord/errors.py
+++ b/discord/errors.py
@@ -38,6 +38,12 @@ class ClientException(DiscordException):
"""
pass
+class NoMoreMessages(DiscordException):
+ """Exception that is thrown when a ``history`` operation has no more
+ messages. This is only exposed for Python 3.4 only.
+ """
+ pass
+
class GatewayNotFound(DiscordException):
"""An exception that is usually thrown when the gateway hub
for the :class:`Client` websocket is not found."""
diff --git a/discord/iterators.py b/discord/iterators.py
index 63a8776d..91470d80 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -27,23 +27,26 @@ DEALINGS IN THE SOFTWARE.
import sys
import asyncio
import aiohttp
+import datetime
+
+from .errors import NoMoreMessages
+from .utils import time_snowflake
from .message import Message
from .object import Object
PY35 = sys.version_info >= (3, 5)
-
class LogsFromIterator:
- """Iterator for recieving logs.
+ """Iterator for receiving logs.
- The messages endpoint has two behaviors we care about here:
+ The messages endpoint has two behaviours we care about here:
If `before` is specified, the messages endpoint returns the `limit`
newest messages before `before`, sorted with newest first. For filling over
- 100 messages, update the `before` parameter to the oldest message recieved.
+ 100 messages, update the `before` parameter to the oldest message received.
Messages will be returned in order by time.
If `after` is specified, it returns the `limit` oldest messages after
`after`, sorted with newest first. For filling over 100 messages, update the
- `after` parameter to the newest message recieved. If messages are not
+ `after` parameter to the newest message received. If messages are not
reversed, they will be out of order (99-0, 199-100, so on)
A note that if both before and after are specified, before is ignored by the
@@ -51,8 +54,7 @@ class LogsFromIterator:
Parameters
-----------
- client : class:`Client`
- channel : class:`Channel`
+ channel: class:`Channel`
Channel from which to request logs
limit : int
Maximum number of messages to retrieve
@@ -63,24 +65,37 @@ class LogsFromIterator:
around : :class:`Message` or id-like
Message around which all messages must be. Limit max 101. Note that if
limit is an even number, this will return at most limit+1 messages.
- reverse : bool
+ reverse: bool
If set to true, return messages in oldest->newest order. Recommended
when using with "after" queries with limit over 100, otherwise messages
- will be out of order. Defaults to False for backwards compatability.
+ will be out of order.
"""
- def __init__(self, client, channel, limit,
- before=None, after=None, around=None, reverse=False):
- self.client = client
+ def __init__(self, channel, limit,
+ before=None, after=None, around=None, reverse=None):
+
+ if isinstance(before, datetime.datetime):
+ before = Object(id=time_snowflake(before, high=False))
+ if isinstance(after, datetime.datetime):
+ after = Object(id=time_snowflake(after, high=True))
+ if isinstance(around, datetime.datetime):
+ around = Object(id=time_snowflake(around))
+
self.channel = channel
+ self.ctx = channel._state
+ self.logs_from = channel._state.http.logs_from
self.limit = limit
self.before = before
self.after = after
self.around = around
- self.reverse = reverse
+
+ if reverse is None:
+ self.reverse = after is not None
+ else:
+ self.reverse = reverse
+
self._filter = None # message dict -> bool
self.messages = asyncio.Queue()
- self.ctx = client.connection.ctx
if self.around:
if self.limit > 101:
@@ -92,29 +107,32 @@ class LogsFromIterator:
self._retrieve_messages = self._retrieve_messages_around_strategy
if self.before and self.after:
- self._filter = lambda m: self.after.id < m['id'] < self.before.id
+ self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
elif self.before:
- self._filter = lambda m: m['id'] < self.before.id
+ self._filter = lambda m: int(m['id']) < self.before.id
elif self.after:
- self._filter = lambda m: self.after.id < m['id']
+ self._filter = lambda m: self.after.id < int(m['id'])
elif self.before and self.after:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy
- self._filter = lambda m: m['id'] < self.before.id
+ self._filter = lambda m: int(m['id']) < self.before.id
else:
self._retrieve_messages = self._retrieve_messages_before_strategy
- self._filter = lambda m: m['id'] > self.after.id
+ self._filter = lambda m: int(m['id']) > self.after.id
elif self.after:
self._retrieve_messages = self._retrieve_messages_after_strategy
else:
self._retrieve_messages = self._retrieve_messages_before_strategy
@asyncio.coroutine
- def iterate(self):
+ def get(self):
if self.messages.empty():
yield from self.fill_messages()
- return self.messages.get_nowait()
+ try:
+ return self.messages.get_nowait()
+ except asyncio.QueueEmpty:
+ raise NoMoreMessages()
@asyncio.coroutine
def fill_messages(self):
@@ -136,7 +154,7 @@ class LogsFromIterator:
@asyncio.coroutine
def _retrieve_messages_before_strategy(self, retrieve):
"""Retrieve messages using before parameter."""
- data = yield from self.client._logs_from(self.channel, retrieve, before=self.before)
+ data = yield from self.logs_from(self.channel.id, retrieve, before=getattr(self.before, 'id', None))
if len(data):
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
@@ -145,7 +163,7 @@ class LogsFromIterator:
@asyncio.coroutine
def _retrieve_messages_after_strategy(self, retrieve):
"""Retrieve messages using after parameter."""
- data = yield from self.client._logs_from(self.channel, retrieve, after=self.after)
+ data = yield from self.logs_from(self.channel.id, retrieve, after=getattr(self.after, 'id', None))
if len(data):
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
@@ -155,7 +173,7 @@ class LogsFromIterator:
def _retrieve_messages_around_strategy(self, retrieve):
"""Retrieve messages using around parameter."""
if self.around:
- data = yield from self.client._logs_from(self.channel, retrieve, around=self.around)
+ data = yield from self.logs_from(self.channel.id, retrieve, around=getattr(self.around, 'id', None))
self.around = None
return data
return []
@@ -168,9 +186,9 @@ class LogsFromIterator:
@asyncio.coroutine
def __anext__(self):
try:
- msg = yield from self.iterate()
+ msg = yield from self.get()
return msg
- except asyncio.QueueEmpty:
+ except NoMoreMessages:
# if we're still empty at this point...
# we didn't get any new messages so stop looping
raise StopAsyncIteration()
diff --git a/discord/message.py b/discord/message.py
index 28ab18d1..c2caaf9d 100644
--- a/discord/message.py
+++ b/discord/message.py
@@ -24,9 +24,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
-from . import utils
from .user import User
from .reaction import Reaction
+from . import utils, abc
from .object import Object
from .calls import CallMessage
import re
@@ -292,7 +292,7 @@ class Message:
self.channel.is_private = True
return
- if not self.channel.is_private:
+ if isinstance(self.channel, abc.GuildChannel):
self.server = self.channel.server
found = self.server.get_member(self.author.id)
if found is not None:
diff --git a/discord/server.py b/discord/server.py
index 7d4cb468..d1523d6f 100644
--- a/discord/server.py
+++ b/discord/server.py
@@ -29,8 +29,8 @@ from .role import Role
from .member import Member, VoiceState
from .emoji import Emoji
from .game import Game
-from .channel import Channel
-from .enums import ServerRegion, Status, try_enum, VerificationLevel
+from .channel import *
+from .enums import ServerRegion, Status, ChannelType, try_enum, VerificationLevel
from .mixins import Hashable
import copy
@@ -273,7 +273,11 @@ class Server(Hashable):
if 'channels' in data:
channels = data['channels']
for c in channels:
- channel = Channel(server=self, data=c, state=self._state)
+ if c['type'] == ChannelType.text.value:
+ channel = TextChannel(server=self, data=c, state=self._state)
+ else:
+ channel = VoiceChannel(server=self, data=c, state=self._state)
+
self._add_channel(channel)
@utils.cached_slot_property('_default_role')
diff --git a/discord/state.py b/discord/state.py
index ad9bb172..b33d11ae 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -30,7 +30,7 @@ from .game import Game
from .emoji import Emoji
from .reaction import Reaction
from .message import Message
-from .channel import Channel, PrivateChannel
+from .channel import *
from .member import Member
from .role import Role
from . import utils, compat
@@ -153,13 +153,13 @@ class ConnectionState:
def _add_private_channel(self, channel):
self._private_channels[channel.id] = channel
- if channel.type is ChannelType.private:
- self._private_channels_by_user[channel.user.id] = channel
+ if isinstance(channel, DMChannel):
+ self._private_channels_by_user[channel.recipient.id] = channel
def _remove_private_channel(self, channel):
self._private_channels.pop(channel.id, None)
- if channel.type is ChannelType.private:
- self._private_channels_by_user.pop(channel.user.id, None)
+ if isinstance(channel, DMChannel):
+ self._private_channels_by_user.pop(channel.recipient.id, None)
def _get_message(self, msg_id):
return utils.find(lambda m: m.id == msg_id, self.messages)
@@ -229,7 +229,8 @@ class ConnectionState:
servers.append(server)
for pm in data.get('private_channels'):
- self._add_private_channel(PrivateChannel(me=self.user, data=pm, state=self.ctx))
+ factory, _ = _channel_factory(pm['type'])
+ self._add_private_channel(factory(me=self.user, data=pm, state=self.ctx))
compat.create_task(self._delay_ready(), loop=self.loop)
@@ -348,13 +349,18 @@ class ConnectionState:
self.user = User(state=self.ctx, data=data)
def parse_channel_delete(self, data):
- server = self._get_server(int(data['guild_id']))
+ server = self._get_server(utils._get_as_snowflake(data, 'guild_id'))
+ channel_id = int(data['id'])
if server is not None:
- channel_id = data.get('id')
channel = server.get_channel(channel_id)
if channel is not None:
server._remove_channel(channel)
self.dispatch('channel_delete', channel)
+ else:
+ # the reason we're doing this is so it's also removed from the
+ # private channel by user cache as well
+ channel = self._get_private_channel(channel_id)
+ self._remove_private_channel(channel)
def parse_channel_update(self, data):
channel_type = try_enum(ChannelType, data.get('type'))
@@ -375,15 +381,15 @@ class ConnectionState:
self.dispatch('channel_update', old_channel, channel)
def parse_channel_create(self, data):
- ch_type = try_enum(ChannelType, data.get('type'))
+ factory, ch_type = _channel_factory(data['type'])
channel = None
if ch_type in (ChannelType.group, ChannelType.private):
- channel = PrivateChannel(me=self.user, data=data, state=self.ctx)
+ channel = factory(me=self.user, data=data, state=self.ctx)
self._add_private_channel(channel)
else:
server = self._get_server(utils._get_as_snowflake(data, 'guild_id'))
if server is not None:
- channel = Channel(server=server, state=self.ctx, data=data)
+ channel = factory(server=server, state=self.ctx, data=data)
server._add_channel(channel)
self.dispatch('channel_create', channel)
@@ -638,14 +644,12 @@ class ConnectionState:
if channel is not None:
member = None
user_id = utils._get_as_snowflake(data, 'user_id')
- is_private = getattr(channel, 'is_private', None)
- if is_private == None:
- return
-
- if is_private:
- member = channel.user
- else:
+ if isinstance(channel, DMChannel):
+ member = channel.recipient
+ elif isinstance(channel, TextChannel):
member = channel.server.get_member(user_id)
+ elif isinstance(channel, GroupChannel):
+ member = utils.find(lambda x: x.id == user_id, channel.recipients)
if member is not None:
timestamp = datetime.datetime.utcfromtimestamp(data.get('timestamp'))