aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/shard.py140
1 files changed, 77 insertions, 63 deletions
diff --git a/discord/shard.py b/discord/shard.py
index 06e3f213..ef5ed119 100644
--- a/discord/shard.py
+++ b/discord/shard.py
@@ -22,8 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+from __future__ import annotations
+
import asyncio
-import itertools
import logging
import aiohttp
@@ -34,22 +35,29 @@ from .backoff import ExponentialBackoff
from .gateway import *
from .errors import (
ClientException,
- InvalidArgument,
HTTPException,
GatewayNotFound,
ConnectionClosed,
PrivilegedIntentsRequired,
)
-from . import utils
from .enums import Status
+from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar
+
+if TYPE_CHECKING:
+ from .gateway import DiscordWebSocket
+ from .activity import BaseActivity
+ from .enums import Status
+
+ EI = TypeVar('EI', bound='EventItem')
+
__all__ = (
'AutoShardedClient',
'ShardInfo',
)
-log = logging.getLogger(__name__)
+log: logging.Logger = logging.getLogger(__name__)
class EventType:
close = 0
@@ -62,36 +70,36 @@ class EventType:
class EventItem:
__slots__ = ('type', 'shard', 'error')
- def __init__(self, etype, shard, error):
- self.type = etype
- self.shard = shard
- self.error = error
+ def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None:
+ self.type: int = etype
+ self.shard: Optional['Shard'] = shard
+ self.error: Optional[Exception] = error
- def __lt__(self, other):
+ def __lt__(self: EI, other: EI) -> bool:
if not isinstance(other, EventItem):
return NotImplemented
return self.type < other.type
- def __eq__(self, other):
+ def __eq__(self: EI, other: EI) -> bool:
if not isinstance(other, EventItem):
return NotImplemented
return self.type == other.type
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(self.type)
class Shard:
- def __init__(self, ws, client, queue_put):
- self.ws = ws
- self._client = client
- self._dispatch = client.dispatch
- self._queue_put = queue_put
- self.loop = self._client.loop
- self._disconnect = False
+ def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None:
+ self.ws: DiscordWebSocket = ws
+ self._client: Client = client
+ self._dispatch: Callable[..., None] = client.dispatch
+ self._queue_put: Callable[[EventItem], None] = queue_put
+ self.loop: asyncio.AbstractEventLoop = self._client.loop
+ self._disconnect: bool = False
self._reconnect = client._reconnect
- self._backoff = ExponentialBackoff()
- self._task = None
- self._handled_exceptions = (
+ self._backoff: ExponentialBackoff = ExponentialBackoff()
+ self._task: Optional[asyncio.Task] = None
+ self._handled_exceptions: Tuple[Type[Exception], ...] = (
OSError,
HTTPException,
GatewayNotFound,
@@ -101,25 +109,26 @@ class Shard:
)
@property
- def id(self):
- return self.ws.shard_id
+ def id(self) -> int:
+ # DiscordWebSocket.shard_id is set in the from_client classmethod
+ return self.ws.shard_id # type: ignore
- def launch(self):
+ def launch(self) -> None:
self._task = self.loop.create_task(self.worker())
- def _cancel_task(self):
+ def _cancel_task(self) -> None:
if self._task is not None and not self._task.done():
self._task.cancel()
- async def close(self):
+ async def close(self) -> None:
self._cancel_task()
await self.ws.close(code=1000)
- async def disconnect(self):
+ async def disconnect(self) -> None:
await self.close()
self._dispatch('shard_disconnect', self.id)
- async def _handle_disconnect(self, e):
+ async def _handle_disconnect(self, e: Exception) -> None:
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
if not self._reconnect:
@@ -148,7 +157,7 @@ class Shard:
await asyncio.sleep(retry)
self._queue_put(EventItem(EventType.reconnect, self, e))
- async def worker(self):
+ async def worker(self) -> None:
while not self._client.is_closed():
try:
await self.ws.poll_event()
@@ -165,7 +174,7 @@ class Shard:
self._queue_put(EventItem(EventType.terminate, self, e))
break
- async def reidentify(self, exc):
+ async def reidentify(self, exc: ReconnectWebSocket) -> None:
self._cancel_task()
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
@@ -183,7 +192,7 @@ class Shard:
else:
self.launch()
- async def reconnect(self):
+ async def reconnect(self) -> None:
self._cancel_task()
try:
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
@@ -215,16 +224,16 @@ class ShardInfo:
__slots__ = ('_parent', 'id', 'shard_count')
- def __init__(self, parent, shard_count):
- self._parent = parent
- self.id = parent.id
- self.shard_count = shard_count
+ def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
+ self._parent: Shard = parent
+ self.id: int = parent.id
+ self.shard_count: Optional[int] = shard_count
- def is_closed(self):
+ def is_closed(self) -> bool:
""":class:`bool`: Whether the shard connection is currently closed."""
return not self._parent.ws.open
- async def disconnect(self):
+ async def disconnect(self) -> None:
"""|coro|
Disconnects a shard. When this is called, the shard connection will no
@@ -237,7 +246,7 @@ class ShardInfo:
await self._parent.disconnect()
- async def reconnect(self):
+ async def reconnect(self) -> None:
"""|coro|
Disconnects and then connects the shard again.
@@ -246,7 +255,7 @@ class ShardInfo:
await self._parent.disconnect()
await self._parent.reconnect()
- async def connect(self):
+ async def connect(self) -> None:
"""|coro|
Connects a shard. If the shard is already connected this does nothing.
@@ -257,11 +266,11 @@ class ShardInfo:
await self._parent.reconnect()
@property
- def latency(self):
+ def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard."""
return self._parent.ws.latency
- def is_ws_ratelimited(self):
+ def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members
@@ -297,9 +306,12 @@ class AutoShardedClient(Client):
shard_ids: Optional[List[:class:`int`]]
An optional list of shard_ids to launch the shards with.
"""
- def __init__(self, *args, loop=None, **kwargs):
+ if TYPE_CHECKING:
+ _connection: AutoShardedConnectionState
+
+ def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None)
- self.shard_ids = kwargs.pop('shard_ids', None)
+ self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
super().__init__(*args, loop=loop, **kwargs)
if self.shard_ids is not None:
@@ -315,18 +327,19 @@ class AutoShardedClient(Client):
self._connection._get_client = lambda: self
self.__queue = asyncio.PriorityQueue()
- def _get_websocket(self, guild_id=None, *, shard_id=None):
+ def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
if shard_id is None:
- shard_id = (guild_id >> 22) % self.shard_count
+ # guild_id won't be None if shard_id is None and shard_count won't be None here
+ shard_id = (guild_id >> 22) % self.shard_count # type: ignore
return self.__shards[shard_id].ws
- def _get_state(self, **options):
+ def _get_state(self, **options: Any) -> AutoShardedConnectionState:
return AutoShardedConnectionState(dispatch=self.dispatch,
handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options)
@property
- def latency(self):
+ def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This operates similarly to :meth:`Client.latency` except it uses the average
@@ -338,14 +351,14 @@ class AutoShardedClient(Client):
return sum(latency for _, latency in self.latencies) / len(self.__shards)
@property
- def latencies(self):
+ def latencies(self) -> List[Tuple[int, float]]:
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This returns a list of tuples with elements ``(shard_id, latency)``.
"""
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
- def get_shard(self, shard_id):
+ def get_shard(self, shard_id: int) -> Optional[ShardInfo]:
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
try:
parent = self.__shards[shard_id]
@@ -355,11 +368,11 @@ class AutoShardedClient(Client):
return ShardInfo(parent, self.shard_count)
@property
- def shards(self):
+ def shards(self) -> Dict[int, ShardInfo]:
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }
- async def launch_shard(self, gateway, shard_id, *, initial=False):
+ async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0)
@@ -372,7 +385,7 @@ class AutoShardedClient(Client):
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait)
ret.launch()
- async def launch_shards(self):
+ async def launch_shards(self) -> None:
if self.shard_count is None:
self.shard_count, gateway = await self.http.get_bot_gateway()
else:
@@ -389,7 +402,7 @@ class AutoShardedClient(Client):
self._connection.shards_launched.set()
- async def connect(self, *, reconnect=True):
+ async def connect(self, *, reconnect: bool = True) -> None:
self._reconnect = reconnect
await self.launch_shards()
@@ -413,7 +426,7 @@ class AutoShardedClient(Client):
elif item.type == EventType.clean_close:
return
- async def close(self):
+ async def close(self) -> None:
"""|coro|
Closes the connection to Discord.
@@ -425,7 +438,7 @@ class AutoShardedClient(Client):
for vc in self.voice_clients:
try:
- await vc.disconnect()
+ await vc.disconnect(force=True)
except Exception:
pass
@@ -436,7 +449,7 @@ class AutoShardedClient(Client):
await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
- async def change_presence(self, *, activity=None, status=None, shard_id=None):
+ async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[Status] = None, shard_id: int = None) -> None:
"""|coro|
Changes the client's presence.
@@ -468,23 +481,23 @@ class AutoShardedClient(Client):
"""
if status is None:
- status = 'online'
+ status_value = 'online'
status_enum = Status.online
elif status is Status.offline:
- status = 'invisible'
+ status_value = 'invisible'
status_enum = Status.offline
else:
status_enum = status
- status = str(status)
+ status_value = str(status)
if shard_id is None:
for shard in self.__shards.values():
- await shard.ws.change_presence(activity=activity, status=status)
+ await shard.ws.change_presence(activity=activity, status=status_value)
guilds = self._connection.guilds
else:
shard = self.__shards[shard_id]
- await shard.ws.change_presence(activity=activity, status=status)
+ await shard.ws.change_presence(activity=activity, status=status_value)
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]
activities = () if activity is None else (activity,)
@@ -493,10 +506,11 @@ class AutoShardedClient(Client):
if me is None:
continue
- me.activities = activities
+ # Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...]
+ me.activities = activities # type: ignore
me.status = status_enum
- def is_ws_ratelimited(self):
+ def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members