aboutsummaryrefslogtreecommitdiff
path: root/src/gateway/shard.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/gateway/shard.rs')
-rw-r--r--src/gateway/shard.rs289
1 files changed, 145 insertions, 144 deletions
diff --git a/src/gateway/shard.rs b/src/gateway/shard.rs
index 72652b9..1e298b9 100644
--- a/src/gateway/shard.rs
+++ b/src/gateway/shard.rs
@@ -1,17 +1,15 @@
use chrono::UTC;
use std::io::Write;
use std::net::Shutdown;
-use std::sync::mpsc::{self, Sender as MpscSender};
-use std::sync::{Arc, Mutex};
-use std::thread::{self, Builder as ThreadBuilder};
+use std::thread;
use std::time::{Duration as StdDuration, Instant};
use std::mem;
-use super::{GatewayError, GatewayStatus, prep};
-use websocket::client::{Client as WsClient, Sender, Receiver};
-use websocket::message::Message as WsMessage;
-use websocket::result::WebSocketError;
-use websocket::stream::WebSocketStream;
-use websocket::ws::sender::Sender as WsSender;
+use super::{GatewayError, prep};
+use websocket::message::{CloseData, OwnedMessage};
+use websocket::stream::sync::AsTcpStream;
+use websocket::sync::client::{Client, ClientBuilder};
+use websocket::sync::stream::{TcpStream, TlsStream};
+use websocket::WebSocketError;
use ::constants::OpCode;
use ::http;
use ::internal::prelude::*;
@@ -26,6 +24,8 @@ use ::ext::voice::Manager as VoiceManager;
#[cfg(feature="cache")]
use ::utils;
+pub type WsClient = Client<TlsStream<TcpStream>>;
+
type CurrentPresence = (Option<Game>, OnlineStatus, bool);
/// A Shard is a higher-level handler for a websocket connection to Discord's
@@ -62,8 +62,8 @@ type CurrentPresence = (Option<Game>, OnlineStatus, bool);
/// [`receive`]: #method.receive
/// [docs]: https://discordapp.com/developers/docs/topics/gateway#sharding
/// [module docs]: index.html#sharding
-#[derive(Debug)]
pub struct Shard {
+ pub client: WsClient,
current_presence: CurrentPresence,
/// A tuple of:
///
@@ -73,13 +73,13 @@ pub struct Shard {
/// This can be used to calculate [`latency`].
///
/// [`latency`]: fn.latency.html
- heartbeat_instants: (Arc<Mutex<Instant>>, Option<Instant>),
- keepalive_channel: MpscSender<GatewayStatus>,
- /// This is used by the keepalive thread to determine whether the last
+ heartbeat_instants: (Instant, Option<Instant>),
+ heartbeat_interval: u64,
+ /// This is used by the heartbeater to determine whether the last
/// heartbeat was sent without an acknowledgement, and whether to reconnect.
// This _must_ be set to `true` in `Shard::handle_event`'s
// `Ok(GatewayEvent::HeartbeatAck)` arm.
- last_heartbeat_acknowledged: Arc<Mutex<bool>>,
+ last_heartbeat_acknowledged: bool,
seq: u64,
session_id: Option<String>,
shard_info: Option<[u64; 2]>,
@@ -118,13 +118,13 @@ impl Shard {
pub fn new(base_url: &str,
token: &str,
shard_info: Option<[u64; 2]>)
- -> Result<(Shard, ReadyEvent, Receiver<WebSocketStream>)> {
- let (mut sender, mut receiver) = connect(base_url)?;
+ -> Result<(Shard, ReadyEvent)> {
+ let mut client = connect(base_url)?;
let identification = prep::identify(token, shard_info);
- sender.send_json(&identification)?;
+ client.send_json(&identification)?;
- let heartbeat_interval = match receiver.recv_json(GatewayEvent::decode)? {
+ let heartbeat_interval = match client.recv_json(GatewayEvent::decode)? {
GatewayEvent::Hello(interval) => interval,
other => {
debug!("Unexpected event during shard start: {:?}", other);
@@ -132,47 +132,19 @@ impl Shard {
return Err(Error::Gateway(GatewayError::ExpectedHello));
},
};
-
- let (tx, rx) = mpsc::channel();
- let thread_name = match shard_info {
- Some(info) => format!("serenity keepalive [shard {}/{}]",
- info[0],
- info[1] - 1),
- None => "serenity keepalive [unsharded]".to_owned(),
- };
-
- let heartbeat_sent = Arc::new(Mutex::new(Instant::now()));
- let heartbeat_clone = heartbeat_sent.clone();
-
- // Set this to true: when the keepalive thread sends a heartbeat, it
- // will check if the value is `false`.
- //
- // If it is, it will reconnect. This enters the bot into a reconnect
- // loop. Set this to `true` to give Discord the first heartbeat to
- // acknowledge first.
- let last_ack = Arc::new(Mutex::new(true));
- let last_ack_clone = last_ack.clone();
-
- ThreadBuilder::new()
- .name(thread_name)
- .spawn(move || {
- prep::keepalive(heartbeat_interval, heartbeat_clone, last_ack_clone, sender, &rx)
- })?;
+ let heartbeat_sent = Instant::now();
// Parse READY
- let event = receiver.recv_json(GatewayEvent::decode)?;
- let (ready, sequence) = prep::parse_ready(event,
- &tx,
- &mut receiver,
- identification)?;
-
+ let event = client.recv_json(GatewayEvent::decode)?;
+ let (ready, sequence) = prep::parse_ready(event, &mut client, &identification)?;
Ok((feature_voice! {{
Shard {
+ client: client,
current_presence: (None, OnlineStatus::Online, false),
heartbeat_instants: (heartbeat_sent, None),
- last_heartbeat_acknowledged: last_ack,
- keepalive_channel: tx.clone(),
+ heartbeat_interval: heartbeat_interval,
+ last_heartbeat_acknowledged: true,
seq: sequence,
token: token.to_owned(),
session_id: Some(ready.ready.session_id.clone()),
@@ -182,17 +154,18 @@ impl Shard {
}
} else {
Shard {
+ client: client,
current_presence: (None, OnlineStatus::Online, false),
heartbeat_instants: (heartbeat_sent, None),
- last_heartbeat_acknowledged: last_ack,
- keepalive_channel: tx.clone(),
+ heartbeat_interval: heartbeat_interval,
+ last_heartbeat_acknowledged: true,
seq: sequence,
token: token.to_owned(),
session_id: Some(ready.ready.session_id.clone()),
shard_info: shard_info,
ws_url: base_url.to_owned(),
}
- }}, ready, receiver))
+ }}, ready))
}
/// Retrieves a copy of the current shard information.
@@ -328,20 +301,14 @@ impl Shard {
/// enabled.
#[allow(cyclomatic_complexity)]
#[doc(hidden)]
- pub fn handle_event(&mut self,
- event: Result<GatewayEvent>,
- mut receiver: &mut Receiver<WebSocketStream>)
- -> Result<Option<(Event, Option<Receiver<WebSocketStream>>)>> {
+ pub fn handle_event(&mut self, event: Result<GatewayEvent>) -> Result<Option<Event>> {
match event {
Ok(GatewayEvent::Dispatch(seq, event)) => {
- let status = GatewayStatus::Sequence(seq);
- let _ = self.keepalive_channel.send(status);
-
self.seq = seq;
self.handle_dispatch(&event);
- Ok(Some((event, None)))
+ Ok(Some(event))
},
Ok(GatewayEvent::Heartbeat(s)) => {
info!("Received shard heartbeat");
@@ -352,10 +319,12 @@ impl Shard {
s,
self.seq);
+ let _ = self.shutdown();
+
return if self.session_id.is_some() {
- self.resume(receiver).map(|(ev, rec)| Some((ev, Some(rec))))
+ self.resume().map(Some)
} else {
- self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec))))
+ self.reconnect().map(Some)
};
}
@@ -363,27 +332,27 @@ impl Shard {
"d": Value::Null,
"op": OpCode::Heartbeat.num(),
});
- let status = GatewayStatus::SendMessage(map);
- let _ = self.keepalive_channel.send(status);
+ self.client.send_json(&map)?;
Ok(None)
},
Ok(GatewayEvent::HeartbeatAck) => {
self.heartbeat_instants.1 = Some(Instant::now());
- *self.last_heartbeat_acknowledged.lock().unwrap() = true;
+ self.last_heartbeat_acknowledged = true;
Ok(None)
},
Ok(GatewayEvent::Hello(interval)) => {
if interval > 0 {
- let status = GatewayStatus::Interval(interval);
- let _ = self.keepalive_channel.send(status);
+ self.heartbeat_interval = interval;
}
+ let _ = self.shutdown();
+
if self.session_id.is_some() {
- self.resume(receiver).map(|(ev, rec)| Some((ev, Some(rec))))
+ self.resume().map(Some)
} else {
- self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec))))
+ self.reconnect().map(Some)
}
},
Ok(GatewayEvent::InvalidateSession) => {
@@ -392,21 +361,24 @@ impl Shard {
self.session_id = None;
let identification = prep::identify(&self.token, self.shard_info);
- let status = GatewayStatus::SendMessage(identification);
- let _ = self.keepalive_channel.send(status);
+ let _ = self.client.send_json(&identification);
Ok(None)
},
Ok(GatewayEvent::Reconnect) => {
- self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec))))
+ let _ = self.shutdown();
+
+ self.reconnect().map(Some)
},
- Err(Error::Gateway(GatewayError::Closed(num, message))) => {
+ Err(Error::Gateway(GatewayError::Closed(data))) => {
+ let num = data.as_ref().map(|d| d.status_code);
+ let reason = data.map(|d| d.reason);
let clean = num == Some(1000);
{
let kind = if clean { "Cleanly" } else { "Uncleanly" };
- info!("{} closing with {:?}: {}", kind, num, message);
+ info!("{} closing with {:?}: {:?}", kind, num, reason);
}
match num {
@@ -429,23 +401,26 @@ impl Shard {
self.session_id = None;
},
Some(other) if !clean => {
- warn!("Unknown unclean close {}: {:?}", other, message);
+ warn!("Unknown unclean close {}: {:?}", other, reason);
},
_ => {},
}
- let resume = num.map(|x| x != 1000 && x != 4004 && self.session_id.is_some())
- .unwrap_or(false);
+ let resume = num.map(|num| {
+ num != 1000 && num != 4004 && self.session_id.is_some()
+ }).unwrap_or(false);
if resume {
info!("Attempting to resume");
if self.session_id.is_some() {
- match self.resume(receiver) {
- Ok((ev, rec)) => {
+ let _ = self.shutdown();
+
+ match self.resume() {
+ Ok(ev) => {
info!("Resumed");
- return Ok(Some((ev, Some(rec))));
+ return Ok(Some(ev));
},
Err(why) => {
warn!("Error resuming: {:?}", why);
@@ -457,7 +432,9 @@ impl Shard {
info!("Reconnecting");
- self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec))))
+ let _ = self.shutdown();
+
+ self.reconnect().map(Some)
},
Err(Error::WebSocket(why)) => {
if let WebSocketError::NoDataAvailable = why {
@@ -477,11 +454,13 @@ impl Shard {
if self.session_id.is_some() {
info!("Attempting to resume");
- match self.resume(&mut receiver) {
- Ok((ev, rec)) => {
+ let _ = self.shutdown();
+
+ match self.resume() {
+ Ok(ev) => {
info!("Resumed");
- return Ok(Some((ev, Some(rec))));
+ return Ok(Some(ev));
},
Err(why) => {
warn!("Error resuming: {:?}", why);
@@ -492,7 +471,9 @@ impl Shard {
info!("Reconnecting");
- self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec))))
+ let _ = self.shutdown();
+
+ self.reconnect().map(Some)
},
Err(error) => Err(error),
}
@@ -528,25 +509,26 @@ impl Shard {
// Shamelessly stolen from brayzure's commit in eris:
// <https://github.com/abalabahaha/eris/commit/0ce296ae9a542bcec0edf1c999ee2d9986bed5a6>
pub fn latency(&self) -> Option<StdDuration> {
- self.heartbeat_instants.1.map(|send| send - *self.heartbeat_instants.0.lock().unwrap())
+ self.heartbeat_instants.1.map(|send| send - self.heartbeat_instants.0)
}
/// Shuts down the receiver by attempting to cleanly close the
/// connection.
#[doc(hidden)]
- pub fn shutdown_clean(receiver: &mut Receiver<WebSocketStream>)
- -> Result<()> {
- let r = receiver.get_mut().get_mut();
-
+ pub fn shutdown_clean(client: &mut WsClient) -> Result<()> {
{
- let mut sender = Sender::new(r.by_ref(), true);
- let message = WsMessage::close_because(1000, "");
+ let message = OwnedMessage::Close(Some(CloseData {
+ status_code: 1000,
+ reason: String::new(),
+ }));
- sender.send_message(&message)?;
+ client.send_message(&message)?;
}
- r.flush()?;
- r.shutdown(Shutdown::Both)?;
+ let mut stream = client.stream_ref().as_tcp();
+
+ stream.flush()?;
+ stream.shutdown(Shutdown::Both)?;
debug!("Cleanly shutdown shard");
@@ -555,11 +537,11 @@ impl Shard {
/// Uncleanly shuts down the receiver by not sending a close code.
#[doc(hidden)]
- pub fn shutdown(receiver: &mut Receiver<WebSocketStream>) -> Result<()> {
- let r = receiver.get_mut().get_mut();
+ pub fn shutdown(&mut self) -> Result<()> {
+ let mut stream = self.client.stream_ref().as_tcp();
- r.flush()?;
- r.shutdown(Shutdown::Both)?;
+ stream.flush()?;
+ stream.shutdown(Shutdown::Both)?;
Ok(())
}
@@ -612,7 +594,7 @@ impl Shard {
/// [`Event::GuildMembersChunk`]: ../../model/event/enum.Event.html#variant.GuildMembersChunk
/// [`Guild`]: ../../model/struct.Guild.html
/// [`Member`]: ../../model/struct.Member.html
- pub fn chunk_guilds(&self, guild_ids: &[GuildId], limit: Option<u16>, query: Option<&str>) {
+ pub fn chunk_guilds(&mut self, guild_ids: &[GuildId], limit: Option<u16>, query: Option<&str>) {
let msg = json!({
"op": OpCode::GetGuildMembers.num(),
"d": {
@@ -622,7 +604,7 @@ impl Shard {
},
});
- let _ = self.keepalive_channel.send(GatewayStatus::SendMessage(msg));
+ let _ = self.client.send_json(&msg);
}
/// Calculates the number of guilds that the shard is responsible for.
@@ -685,8 +667,48 @@ impl Shard {
}
}
- fn reconnect(&mut self, mut receiver: &mut Receiver<WebSocketStream>)
- -> Result<(Event, Receiver<WebSocketStream>)> {
+ #[doc(hidden)]
+ pub fn heartbeat(&mut self) -> Result<()> {
+ let map = json!({
+ "d": self.seq,
+ "op": OpCode::Heartbeat.num(),
+ });
+
+ trace!("Sending heartbeat d: {}", self.seq);
+
+ match self.client.send_json(&map) {
+ Ok(_) => {
+ self.heartbeat_instants.0 = Instant::now();
+ self.last_heartbeat_acknowledged = false;
+
+ Ok(())
+ },
+ Err(why) => {
+ match why {
+ Error::WebSocket(WebSocketError::IoError(err)) => {
+ if err.raw_os_error() != Some(32) {
+ debug!("Err w/ keepalive: {:?}", err);
+ }
+ },
+ other => warn!("Other err w/ keepalive: {:?}", other),
+ }
+
+ Err(Error::Gateway(GatewayError::HeartbeatFailed))
+ },
+ }
+ }
+
+ #[doc(hidden)]
+ pub fn heartbeat_interval(&self) -> i64 {
+ self.heartbeat_interval as i64
+ }
+
+ #[doc(hidden)]
+ pub fn last_heartbeat_acknowledged(&self) -> bool {
+ self.last_heartbeat_acknowledged
+ }
+
+ fn reconnect(&mut self) -> Result<Event> {
info!("Attempting to reconnect");
// Take a few attempts at reconnecting.
@@ -697,13 +719,13 @@ impl Shard {
&self.token,
self.shard_info);
- if let Ok((shard, ready, receiver_new)) = shard {
- let _ = Shard::shutdown(&mut receiver);
+ if let Ok((shard, ready)) = shard {
+ let _ = self.shutdown();
mem::replace(self, shard);
self.session_id = Some(ready.ready.session_id.clone());
- return Ok((Event::Ready(ready), receiver_new));
+ return Ok(Event::Ready(ready));
}
let seconds = i.pow(2);
@@ -719,17 +741,15 @@ impl Shard {
}
#[doc(hidden)]
- pub fn resume(&mut self, receiver: &mut Receiver<WebSocketStream>)
- -> Result<(Event, Receiver<WebSocketStream>)> {
+ pub fn resume(&mut self) -> Result<Event> {
let session_id = match self.session_id.clone() {
Some(session_id) => session_id,
None => return Err(Error::Gateway(GatewayError::NoSessionId)),
};
- let _ = receiver.shutdown_all();
- let (mut sender, mut receiver) = connect(&self.ws_url)?;
+ self.client = connect(&self.ws_url)?;
- sender.send_json(&json!({
+ self.client.send_json(&json!({
"op": OpCode::Resume.num(),
"d": {
"session_id": session_id,
@@ -743,7 +763,7 @@ impl Shard {
let ev;
loop {
- match receiver.recv_json(GatewayEvent::decode)? {
+ match self.client.recv_json(GatewayEvent::decode)? {
GatewayEvent::Dispatch(seq, event) => {
match event {
Event::Ready(ref ready) => {
@@ -759,10 +779,10 @@ impl Shard {
break;
},
GatewayEvent::Hello(i) => {
- let _ = self.keepalive_channel.send(GatewayStatus::Interval(i));
+ self.heartbeat_interval = i;
}
GatewayEvent::InvalidateSession => {
- sender.send_json(&prep::identify(&self.token, self.shard_info))?;
+ self.client.send_json(&prep::identify(&self.token, self.shard_info))?;
},
other => {
debug!("Unexpected event: {:?}", other);
@@ -772,12 +792,10 @@ impl Shard {
}
}
- let _ = self.keepalive_channel.send(GatewayStatus::Sender(sender));
-
- Ok((ev, receiver))
+ Ok(ev)
}
- fn update_presence(&self) {
+ fn update_presence(&mut self) {
let (ref game, status, afk) = self.current_presence;
let now = UTC::now().timestamp() as u64;
@@ -793,7 +811,7 @@ impl Shard {
},
});
- let _ = self.keepalive_channel.send(GatewayStatus::SendMessage(msg));
+ let _ = self.client.send_json(&msg);
#[cfg(feature="cache")]
{
@@ -808,34 +826,17 @@ impl Shard {
}
}
-fn connect(base_url: &str) -> Result<(Sender<WebSocketStream>, Receiver<WebSocketStream>)> {
+fn connect(base_url: &str) -> Result<WsClient> {
let url = prep::build_gateway_url(base_url)?;
- let response = WsClient::connect(url)?.send()?;
- response.validate()?;
+ let client = ClientBuilder::from_url(&url).connect_secure(None)?;
- let (mut sender, mut receiver) = response.begin().split();
-
- let timeout = StdDuration::from_secs(90);
+ let timeout = StdDuration::from_secs(1);
{
- let mut ws_stream = receiver.get_mut().get_mut();
- let stream = match *ws_stream {
- WebSocketStream::Tcp(ref mut s) => s,
- WebSocketStream::Ssl(ref mut s) => s.get_mut(),
- };
-
+ let stream = client.stream_ref().as_tcp();
stream.set_read_timeout(Some(timeout))?;
- }
-
- {
- let mut ws_stream = sender.get_mut();
- let stream = match *ws_stream {
- WebSocketStream::Tcp(ref mut s) => s,
- WebSocketStream::Ssl(ref mut s) => s.get_mut(),
- };
-
stream.set_read_timeout(Some(timeout))?;
}
- Ok((sender, receiver))
+ Ok(client)
}