diff options
| author | Zeyla Hellyer <[email protected]> | 2017-06-16 20:29:57 -0700 |
|---|---|---|
| committer | Zeyla Hellyer <[email protected]> | 2017-06-16 20:29:57 -0700 |
| commit | 601704acb94601a134ae43e795474afe8574b2ae (patch) | |
| tree | 16194482225b4877ce70613962d277e81b13660b /src/gateway | |
| parent | Fix broken link from ModelError (diff) | |
| download | serenity-601704acb94601a134ae43e795474afe8574b2ae.tar.xz serenity-601704acb94601a134ae43e795474afe8574b2ae.zip | |
Rework shard logic and shard handling
Diffstat (limited to 'src/gateway')
| -rw-r--r-- | src/gateway/error.rs | 46 | ||||
| -rw-r--r-- | src/gateway/mod.rs | 32 | ||||
| -rw-r--r-- | src/gateway/prep.rs | 163 | ||||
| -rw-r--r-- | src/gateway/shard.rs | 577 |
4 files changed, 398 insertions, 420 deletions
diff --git a/src/gateway/error.rs b/src/gateway/error.rs index 374c4a3..2e96252 100644 --- a/src/gateway/error.rs +++ b/src/gateway/error.rs @@ -1,5 +1,5 @@ use std::error::Error as StdError; -use std::fmt::{self, Display}; +use std::fmt::{Display, Formatter, Result as FmtResult}; use websocket::message::CloseData; /// An error that occurred while attempting to deal with the gateway. @@ -16,33 +16,59 @@ pub enum Error { ExpectedHello, /// When there was an error sending a heartbeat. HeartbeatFailed, + /// When invalid authentication (a bad token) was sent in the IDENTIFY. + InvalidAuthentication, /// Expected a Ready or an InvalidateSession InvalidHandshake, /// An indicator that an unknown opcode was received from the gateway. InvalidOpCode, + /// When invalid sharding data was sent in the IDENTIFY. + /// + /// # Examples + /// + /// Sending a shard ID of 5 when sharding with 3 total is considered + /// invalid. + InvalidShardData, + /// When no authentication was sent in the IDENTIFY. + NoAuthentication, /// When a session Id was expected (for resuming), but was not present. NoSessionId, + /// When a shard would have too many guilds assigned to it. + /// + /// # Examples + /// + /// When sharding 5500 guilds on 2 shards, at least one of the shards will + /// have over the maximum number of allowed guilds per shard. + /// + /// This limit is currently 2500 guilds per shard. + OverloadedShard, /// Failed to reconnect after a number of attempts. ReconnectFailure, } impl Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> FmtResult { f.write_str(self.description()) } } impl StdError for Error { fn description(&self) -> &str { + use self::Error::*; + match *self { - Error::BuildingUrl => "Error building url", - Error::Closed(_) => "Connection closed", - Error::ExpectedHello => "Expected a Hello", - Error::HeartbeatFailed => "Failed sending a heartbeat", - Error::InvalidHandshake => "Expected a valid Handshake", - Error::InvalidOpCode => "Invalid OpCode", - Error::NoSessionId => "No Session Id present when required", - Error::ReconnectFailure => "Failed to Reconnect", + BuildingUrl => "Error building url", + Closed(_) => "Connection closed", + ExpectedHello => "Expected a Hello", + HeartbeatFailed => "Failed sending a heartbeat", + InvalidAuthentication => "Sent invalid authentication", + InvalidHandshake => "Expected a valid Handshake", + InvalidOpCode => "Invalid OpCode", + InvalidShardData => "Sent invalid shard data", + NoAuthentication => "Sent no authentication", + NoSessionId => "No Session Id present when required", + OverloadedShard => "Shard has too many guilds", + ReconnectFailure => "Failed to Reconnect", } } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index f45522a..6f839db 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -50,8 +50,38 @@ //! [docs]: https://discordapp.com/developers/docs/topics/gateway#sharding mod error; -mod prep; mod shard; pub use self::error::Error as GatewayError; pub use self::shard::Shard; + +/// Indicates the current connection stage of a [`Shard`]. +/// +/// This can be useful for knowing which shards are currently "down"/"up". +/// +/// [`Shard`]: struct.Shard.html +#[derive(Debug, Eq, PartialEq, PartialOrd, Ord)] +pub enum ConnectionStage { + /// Indicator that the [`Shard`] is normally connected and is not in, e.g., + /// a resume phase. + /// + /// [`Shard`]: struct.Shard.html + Connected, + /// Indicator that the [`Shard`] is connecting and is in, e.g., a resume + /// phase. + /// + /// [`Shard`]: struct.Shard.html + Connecting, + /// Indicator that the [`Shard`] is fully disconnected and is not in a + /// reconnecting phase. + /// + /// [`Shard`]: struct.Shard.html + Disconnected, + /// Indicator that the [`Shard`] is currently initiating a handshake. + /// + /// [`Shard`]: struct.Shard.html + Handshake, + /// Indicator that the [`Shard`] has sent an IDENTIFY packet and is awaiting + /// a READY packet. + Identifying, +} diff --git a/src/gateway/prep.rs b/src/gateway/prep.rs deleted file mode 100644 index 91b0a48..0000000 --- a/src/gateway/prep.rs +++ /dev/null @@ -1,163 +0,0 @@ -use serde_json::Value; -use std::env; -use super::GatewayError; -use websocket::client::Url; -use websocket::sync::stream::{TcpStream, TlsStream}; -use websocket::sync::Client as WsClient; -use ::constants::{self, LARGE_THRESHOLD, OpCode}; -use ::error::{Error, Result}; -use ::internal::ws_impl::{ReceiverExt, SenderExt}; -use ::model::event::{Event, GatewayEvent, ReadyEvent}; - -#[inline] -pub fn parse_ready(event: GatewayEvent, - client: &mut WsClient<TlsStream<TcpStream>>, - identification: &Value) - -> Result<(ReadyEvent, u64)> { - match event { - GatewayEvent::Dispatch(seq, Event::Ready(event)) => { - Ok((event, seq)) - }, - GatewayEvent::InvalidateSession => { - debug!("Session invalidation"); - - let _ = client.send_json(identification); - - match client.recv_json(GatewayEvent::decode)? { - GatewayEvent::Dispatch(seq, Event::Ready(event)) => { - Ok((event, seq)) - }, - other => { - debug!("Unexpected event: {:?}", other); - - Err(Error::Gateway(GatewayError::InvalidHandshake)) - }, - } - }, - other => { - debug!("Unexpected event: {:?}", other); - - Err(Error::Gateway(GatewayError::InvalidHandshake)) - }, - } -} - -pub fn identify(token: &str, shard_info: Option<[u64; 2]>) -> Value { - json!({ - "op": OpCode::Identify.num(), - "d": { - "compression": true, - "large_threshold": LARGE_THRESHOLD, - "shard": shard_info.unwrap_or([0, 1]), - "token": token, - "v": constants::GATEWAY_VERSION, - "properties": { - "$browser": "serenity", - "$device": "serenity", - "$os": env::consts::OS, - }, - }, - }) -} - -pub fn build_gateway_url(base: &str) -> Result<Url> { - Url::parse(&format!("{}?v={}", base, constants::GATEWAY_VERSION)) - .map_err(|_| Error::Gateway(GatewayError::BuildingUrl)) -} - -/* -pub fn keepalive(interval: u64, - heartbeat_sent: Arc<Mutex<Instant>>, - last_ack: Arc<Mutex<bool>>, - mut sender: Sender<WebSocketStream>, - channel: &MpscReceiver<GatewayStatus>) { - let mut base_interval = Duration::milliseconds(interval as i64); - let mut next_tick = UTC::now() + base_interval; - - let mut last_sequence = 0; - let mut last_successful = false; - - 'outer: loop { - thread::sleep(StdDuration::from_millis(100)); - - loop { - match channel.try_recv() { - Ok(GatewayStatus::Interval(interval)) => { - base_interval = Duration::milliseconds(interval as i64); - }, - Ok(GatewayStatus::Sender(new_sender)) => { - sender = new_sender; - }, - Ok(GatewayStatus::SendMessage(val)) => { - if let Err(why) = sender.send_json(&val) { - warn!("Error sending message: {:?}", why); - } - }, - Ok(GatewayStatus::Sequence(seq)) => { - last_sequence = seq; - }, - Err(TryRecvError::Empty) => break, - Err(TryRecvError::Disconnected) => break 'outer, - } - } - - if UTC::now() >= next_tick { - // If the last heartbeat didn't receive an acknowledgement, then - // shutdown and auto-reconnect. - if !*last_ack.lock().unwrap() { - debug!("Last heartbeat not acknowledged; re-connecting"); - - break; - } - - next_tick = next_tick + base_interval; - - let map = json!({ - "d": last_sequence, - "op": OpCode::Heartbeat.num(), - }); - - trace!("Sending heartbeat d: {}", last_sequence); - - match sender.send_json(&map) { - Ok(_) => { - let now = Instant::now(); - - *heartbeat_sent.lock().unwrap() = now; - *last_ack.lock().unwrap() = false; - }, - Err(why) => { - match why { - Error::WebSocket(WsError::IoError(err)) => { - if err.raw_os_error() != Some(32) { - debug!("Err w/ keepalive: {:?}", err); - } - }, - other => warn!("Other err w/ keepalive: {:?}", other), - } - - if last_successful { - debug!("If next keepalive fails, closing"); - } else { - break; - } - - last_successful = false; - }, - } - } - } - - debug!("Closing keepalive"); - - match sender.shutdown_all() { - Ok(_) => debug!("Successfully shutdown sender/receiver"), - Err(why) => { - // This can fail if the receiver already shutdown. - if why.raw_os_error() != Some(107) { - warn!("Failed to shutdown sender/receiver: {:?}", why); - } - }, - } -} -*/ diff --git a/src/gateway/shard.rs b/src/gateway/shard.rs index 1fc7c3f..38faa6b 100644 --- a/src/gateway/shard.rs +++ b/src/gateway/shard.rs @@ -1,21 +1,21 @@ use chrono::UTC; use serde_json::Value; +use std::env::consts; use std::io::Write; use std::net::Shutdown; -use std::thread; +use std::sync::{Arc, Mutex}; use std::time::{Duration as StdDuration, Instant}; -use std::mem; -use super::{GatewayError, prep}; +use super::{ConnectionStage, GatewayError}; +use websocket::client::Url; use websocket::message::{CloseData, OwnedMessage}; use websocket::stream::sync::AsTcpStream; use websocket::sync::client::{Client, ClientBuilder}; use websocket::sync::stream::{TcpStream, TlsStream}; use websocket::WebSocketError; -use ::constants::OpCode; -use ::http; +use ::constants::{self, OpCode}; use ::internal::prelude::*; -use ::internal::ws_impl::{ReceiverExt, SenderExt}; -use ::model::event::{Event, GatewayEvent, ReadyEvent}; +use ::internal::ws_impl::SenderExt; +use ::model::event::{Event, GatewayEvent}; use ::model::{Game, GuildId, OnlineStatus}; #[cfg(feature="voice")] @@ -24,6 +24,8 @@ use std::sync::mpsc::{self, Receiver as MpscReceiver}; use ::client::CACHE; #[cfg(feature="voice")] use ::ext::voice::Manager as VoiceManager; +#[cfg(feature="voice")] +use ::http; #[cfg(feature="cache")] use ::utils; @@ -76,8 +78,8 @@ pub struct Shard { /// This can be used to calculate [`latency`]. /// /// [`latency`]: fn.latency.html - heartbeat_instants: (Instant, Option<Instant>), - heartbeat_interval: u64, + heartbeat_instants: (Option<Instant>, Option<Instant>), + heartbeat_interval: Option<u64>, /// This is used by the heartbeater to determine whether the last /// heartbeat was sent without an acknowledgement, and whether to reconnect. // This _must_ be set to `true` in `Shard::handle_event`'s @@ -85,9 +87,10 @@ pub struct Shard { last_heartbeat_acknowledged: bool, seq: u64, session_id: Option<String>, - shard_info: Option<[u64; 2]>, - token: String, - ws_url: String, + shard_info: [u64; 2], + stage: ConnectionStage, + token: Arc<Mutex<String>>, + ws_url: Arc<Mutex<String>>, /// The voice connections that this Shard is responsible for. The Shard will /// update the voice connections' states. #[cfg(feature="voice")] @@ -120,62 +123,59 @@ impl Shard { /// // at this point, you can create a `loop`, and receive events and match /// // their variants /// ``` - pub fn new(base_url: &str, - token: &str, - shard_info: Option<[u64; 2]>) - -> Result<(Shard, ReadyEvent)> { - let mut client = connect(base_url)?; - - let identification = prep::identify(token, shard_info); - client.send_json(&identification)?; - - let heartbeat_interval = match client.recv_json(GatewayEvent::decode)? { - GatewayEvent::Hello(interval) => interval, - other => { - debug!("Unexpected event during shard start: {:?}", other); - - return Err(Error::Gateway(GatewayError::ExpectedHello)); - }, - }; - let heartbeat_sent = Instant::now(); - - // Parse READY - let event = client.recv_json(GatewayEvent::decode)?; - let (ready, sequence) = prep::parse_ready(event, &mut client, &identification)?; - - set_client_timeout(&mut client)?; - - Ok((feature_voice! {{ + pub fn new(ws_url: Arc<Mutex<String>>, + token: Arc<Mutex<String>>, + shard_info: [u64; 2]) + -> Result<Shard> { + let client = connect(&*ws_url.lock().unwrap())?; + + let current_presence = (None, OnlineStatus::Online, false); + let heartbeat_instants = (None, None); + let heartbeat_interval = None; + let last_heartbeat_acknowledged = true; + let seq = 0; + let stage = ConnectionStage::Handshake; + let session_id = None; + + let mut shard = feature_voice! {{ let (tx, rx) = mpsc::channel(); + let user = http::get_current_user()?; + Shard { - client: client, - current_presence: (None, OnlineStatus::Online, false), - heartbeat_instants: (heartbeat_sent, None), - heartbeat_interval: heartbeat_interval, - last_heartbeat_acknowledged: true, - seq: sequence, - token: token.to_owned(), - session_id: Some(ready.ready.session_id.clone()), - shard_info: shard_info, - ws_url: base_url.to_owned(), - manager: VoiceManager::new(tx, ready.ready.user.id), + client, + current_presence, + heartbeat_instants, + heartbeat_interval, + last_heartbeat_acknowledged, + seq, + stage, + token, + session_id, + shard_info, + ws_url, + manager: VoiceManager::new(tx, user.id), manager_rx: rx, } } else { Shard { - client: client, - current_presence: (None, OnlineStatus::Online, false), - heartbeat_instants: (heartbeat_sent, None), - heartbeat_interval: heartbeat_interval, - last_heartbeat_acknowledged: true, - seq: sequence, - token: token.to_owned(), - session_id: Some(ready.ready.session_id.clone()), - shard_info: shard_info, - ws_url: base_url.to_owned(), + client, + current_presence, + heartbeat_instants, + heartbeat_interval, + last_heartbeat_acknowledged, + seq, + stage, + token, + session_id, + shard_info, + ws_url, } - }}, ready)) + }}; + + shard.identify()?; + + Ok(shard) } /// Retrieves a copy of the current shard information. @@ -192,12 +192,15 @@ impl Shard { /// /// ```rust,no_run /// # use serenity::client::gateway::Shard; + /// # use std::sync::{Arc, Mutex}; /// # - /// # let (shard, _) = Shard::new("", "", Some([1, 2])).unwrap(); + /// # let mutex = Arc::new(Mutex::new("".to_owned())); /// # - /// assert_eq!(shard.shard_info(), Some([1, 2])); + /// # let shard = Shard::new(mutex.clone(), mutex, [1, 2]).unwrap(); + /// # + /// assert_eq!(shard.shard_info(), [1, 2]); /// ``` - pub fn shard_info(&self) -> Option<[u64; 2]> { + pub fn shard_info(&self) -> [u64; 2] { self.shard_info } @@ -221,8 +224,11 @@ impl Shard { /// /// ```rust,no_run /// # use serenity::client::gateway::Shard; + /// # use std::sync::{Arc, Mutex}; + /// # + /// # let mutex = Arc::new(Mutex::new("".to_owned())); /// # - /// # let (mut shard, _) = Shard::new("", "", Some([0, 1])).unwrap(); + /// # let mut shard = Shard::new(mutex.clone(), mutex, [0, 1]).unwrap(); /// # /// use serenity::model::Game; /// @@ -247,8 +253,11 @@ impl Shard { /// /// ```rust,no_run /// # use serenity::client::gateway::Shard; + /// # use std::sync::{Arc, Mutex}; /// # - /// # let (mut shard, _) = Shard::new("", "", Some([0, 1])).unwrap(); + /// # let mutex = Arc::new(Mutex::new("".to_owned())); + /// # + /// # let mut shard = Shard::new(mutex.clone(), mutex, [0, 1]).unwrap(); /// # /// use serenity::model::OnlineStatus; /// @@ -279,8 +288,11 @@ impl Shard { /// /// ```rust,no_run /// # use serenity::client::gateway::Shard; + /// # use std::sync::{Arc, Mutex}; + /// # + /// # let mutex = Arc::new(Mutex::new("".to_owned())); /// # - /// # let (mut shard, _) = Shard::new("", "", Some([0, 1])).unwrap(); + /// # let mut shard = Shard::new(mutex.clone(), mutex, [0, 1]).unwrap(); /// # /// use serenity::model::{Game, OnlineStatus}; /// @@ -309,32 +321,68 @@ impl Shard { /// - `Ok(Some((event, None)))`: an op0 dispatch was received, and the /// shard's voice state will be updated, _if_ the `voice` feature is /// enabled. + /// + /// # Errors + /// + /// Returns a `GatewayError::InvalidAuthentication` if invalid + /// authentication was sent in the IDENTIFY. + /// + /// Returns a `GatewayError::InvalidShardData` if invalid shard data was + /// sent in the IDENTIFY. + /// + /// Returns a `GatewayError::NoAuthentication` if no authentication was sent + /// in the IDENTIFY. + /// + /// Returns a `GatewayError::OverloadedShard` if the shard would have too + /// many guilds assigned to it. #[allow(cyclomatic_complexity)] pub(crate) fn handle_event(&mut self, event: Result<GatewayEvent>) -> Result<Option<Event>> { match event { Ok(GatewayEvent::Dispatch(seq, event)) => { - self.seq = seq; + match event { + Event::Ready(ref ready) => { + self.session_id = Some(ready.ready.session_id.clone()); + self.stage = ConnectionStage::Connected; + + set_client_timeout(&mut self.client)?; + }, + Event::Resumed(_) => { + info!("[Shard {:?}] Resumed", self.shard_info); + + self.stage = ConnectionStage::Connected; + }, + ref _other => { + #[cfg(feature="voice")] + { + self.voice_dispatch(_other); + } + }, + } - self.handle_dispatch(&event); + self.seq = seq; Ok(Some(event)) }, Ok(GatewayEvent::Heartbeat(s)) => { - info!("Received shard heartbeat"); + info!("[Shard {:?}] Received shard heartbeat", self.shard_info); // Received seq is off -- attempt to resume. if s > self.seq + 1 { - info!("Received off sequence (them: {}; us: {}); resuming", + info!("[Shard {:?}] Received off sequence (them: {}; us: {}); resuming", + self.shard_info, s, self.seq); - let _ = self.shutdown(); + if self.stage == ConnectionStage::Handshake { + self.stage = ConnectionStage::Identifying; - return if self.session_id.is_some() { - self.resume().map(Some) + self.identify()?; } else { - self.reconnect().map(Some) - }; + warn!("[Shard {:?}] Heartbeat during non-Handshake; auto-reconnecting", + self.shard_info); + + return self.autoreconnect().and(Ok(None)); + } } let map = json!({ @@ -353,31 +401,30 @@ impl Shard { }, Ok(GatewayEvent::Hello(interval)) => { if interval > 0 { - self.heartbeat_interval = interval; + self.heartbeat_interval = Some(interval); } - let _ = self.shutdown(); + if self.stage == ConnectionStage::Handshake { + self.stage = ConnectionStage::Identifying; - if self.session_id.is_some() { - self.resume().map(Some) + Ok(None) } else { - self.reconnect().map(Some) + self.autoreconnect().and(Ok(None)) } }, Ok(GatewayEvent::InvalidateSession) => { - info!("Received session invalidation; re-identifying"); + info!("[Shard {:?}] Received session invalidation; re-identifying", + self.shard_info); + self.seq = 0; self.session_id = None; - let identification = prep::identify(&self.token, self.shard_info); - let _ = self.client.send_json(&identification); + self.identify()?; Ok(None) }, Ok(GatewayEvent::Reconnect) => { - let _ = self.shutdown(); - - self.reconnect().map(Some) + self.reconnect().and(Ok(None)) }, Err(Error::Gateway(GatewayError::Closed(data))) => { let num = data.as_ref().map(|d| d.status_code); @@ -387,30 +434,53 @@ impl Shard { { let kind = if clean { "Cleanly" } else { "Uncleanly" }; - info!("{} closing with {:?}: {:?}", kind, num, reason); + info!("[Shard {:?}] {} closing with {:?}: {:?}", + self.shard_info, + kind, + num, + reason); } match num { Some(4001) => warn!("Sent invalid opcode"), Some(4002) => warn!("Sent invalid message"), - Some(4003) => warn!("Sent no authentication"), - Some(4004) => warn!("Sent invalid authentication"), + Some(4003) => { + warn!("Sent no authentication"); + + return Err(Error::Gateway(GatewayError::NoAuthentication)); + }, + Some(4004) => { + warn!("Sent invalid authentication"); + + return Err(Error::Gateway(GatewayError::InvalidAuthentication)); + }, Some(4005) => warn!("Already authenticated"), Some(4007) => { - warn!("Sent invalid seq: {}", self.seq); + warn!("[Shard {:?}] Sent invalid seq: {}", self.shard_info, self.seq); self.seq = 0; }, Some(4008) => warn!("Gateway ratelimited"), - Some(4010) => warn!("Sent invalid shard"), - Some(4011) => error!("Bot requires more shards"), + Some(4010) => { + warn!("Sent invalid shard data"); + + return Err(Error::Gateway(GatewayError::InvalidShardData)); + }, + Some(4011) => { + error!("Shard has too many guilds"); + + return Err(Error::Gateway(GatewayError::OverloadedShard)); + }, Some(4006) | Some(4009) => { - info!("Invalid session"); + info!("[Shard {:?}] Invalid session", self.shard_info); self.session_id = None; }, Some(other) if !clean => { - warn!("Unknown unclean close {}: {:?}", other, reason); + warn!("[Shard {:?}] Unknown unclean close {}: {:?}", + self.shard_info, + other, + reason); }, _ => {}, } @@ -420,30 +490,10 @@ impl Shard { }).unwrap_or(false); if resume { - info!("Attempting to resume"); - - if self.session_id.is_some() { - let _ = self.shutdown(); - - match self.resume() { - Ok(ev) => { - info!("Resumed"); - - return Ok(Some(ev)); - }, - Err(why) => { - warn!("Error resuming: {:?}", why); - info!("Falling back to reconnecting"); - }, - } - } + self.resume().or_else(|_| self.reconnect()).and(Ok(None)) + } else { + self.reconnect().and(Ok(None)) } - - info!("Reconnecting"); - - let _ = self.shutdown(); - - self.reconnect().map(Some) }, Err(Error::WebSocket(why)) => { if let WebSocketError::NoDataAvailable = why { @@ -452,37 +502,10 @@ impl Shard { } } - warn!("Websocket error: {:?}", why); - info!("Will attempt to reconnect or resume"); - - // Attempt to resume if the following was not received: - // - // - InvalidateSession. - // - // Otherwise, fallback to reconnecting. - if self.session_id.is_some() { - info!("Attempting to resume"); - - let _ = self.shutdown(); - - match self.resume() { - Ok(ev) => { - info!("Resumed"); - - return Ok(Some(ev)); - }, - Err(why) => { - warn!("Error resuming: {:?}", why); - info!("Falling back to reconnecting"); - }, - } - } - - info!("Reconnecting"); + warn!("[Shard {:?}] Websocket error: {:?}", self.shard_info, why); + info!("[Shard {:?}] Will attempt to auto-reconnect", self.shard_info); - let _ = self.shutdown(); - - self.reconnect().map(Some) + self.autoreconnect().and(Ok(None)) }, Err(error) => Err(error), } @@ -518,27 +541,31 @@ impl Shard { // Shamelessly stolen from brayzure's commit in eris: // <https://github.com/abalabahaha/eris/commit/0ce296ae9a542bcec0edf1c999ee2d9986bed5a6> pub fn latency(&self) -> Option<StdDuration> { - self.heartbeat_instants.1.map(|send| send - self.heartbeat_instants.0) + if let (Some(received), Some(sent)) = self.heartbeat_instants { + Some(sent - received) + } else { + None + } } /// Shuts down the receiver by attempting to cleanly close the /// connection. - pub fn shutdown_clean(client: &mut WsClient) -> Result<()> { + pub fn shutdown_clean(&mut self) -> Result<()> { { let message = OwnedMessage::Close(Some(CloseData { status_code: 1000, reason: String::new(), })); - client.send_message(&message)?; + self.client.send_message(&message)?; } - let mut stream = client.stream_ref().as_tcp(); + let mut stream = self.client.stream_ref().as_tcp(); stream.flush()?; stream.shutdown(Shutdown::Both)?; - debug!("Cleanly shutdown shard"); + debug!("[Shard {:?}] Cleanly shutdown shard", self.shard_info); Ok(()) } @@ -573,8 +600,11 @@ impl Shard { /// /// ```rust,no_run /// # use serenity::client::gateway::Shard; + /// # use std::sync::{Arc, Mutex}; /// # - /// # let (mut shard, _) = Shard::new("", "", Some([0, 1])).unwrap(); + /// # let mutex = Arc::new(Mutex::new("".to_owned())); + /// # + /// # let mut shard = Shard::new(mutex.clone(), mutex, [0, 1]).unwrap(); /// # /// use serenity::model::GuildId; /// @@ -588,8 +618,11 @@ impl Shard { /// /// ```rust,no_run /// # use serenity::client::gateway::Shard; + /// # use std::sync::{Arc, Mutex}; + /// # + /// # let mutex = Arc::new(Mutex::new("".to_owned())); /// # - /// # let (mut shard, _) = Shard::new("", "", Some([0, 1])).unwrap(); + /// # let mut shard = Shard::new(mutex.clone(), mutex, [0, 1]).unwrap(); /// # /// use serenity::model::GuildId; /// @@ -627,8 +660,11 @@ impl Shard { /// /// ```rust,no_run /// # use serenity::client::gateway::Shard; + /// # use std::sync::{Arc, Mutex}; + /// # + /// # let mutex = Arc::new(Mutex::new("will anyone read this".to_owned())); /// # - /// # let (shard, _) = Shard::new("will anyone read this", "", Some([0, 1])).unwrap(); + /// # let shard = Shard::new(mutex.clone(), mutex, [0, 1]).unwrap(); /// # /// let info = shard.shard_info(); /// let guilds = shard.guilds_handled(); @@ -642,33 +678,28 @@ impl Shard { pub fn guilds_handled(&self) -> u16 { let cache = CACHE.read().unwrap(); - if let Some((shard_id, shard_count)) = self.shard_info.map(|s| (s[0], s[1])) { - cache.guilds - .keys() - .filter(|guild_id| utils::shard_id(guild_id.0, shard_count) == shard_id) - .count() as u16 - } else { - cache.guilds.len() as u16 - } + let (shard_id, shard_count) = (self.shard_info[0], self.shard_info[1]); + + cache.guilds + .keys() + .filter(|guild_id| utils::shard_id(guild_id.0, shard_count) == shard_id) + .count() as u16 } - #[allow(unused_variables)] - fn handle_dispatch(&mut self, event: &Event) { - #[cfg(feature="voice")] - { - if let Event::VoiceStateUpdate(ref update) = *event { - if let Some(guild_id) = update.guild_id { - if let Some(handler) = self.manager.get(guild_id) { - handler.update_state(&update.voice_state); - } + #[cfg(feature="voice")] + fn voice_dispatch(&mut self, event: &Event) { + if let Event::VoiceStateUpdate(ref update) = *event { + if let Some(guild_id) = update.guild_id { + if let Some(handler) = self.manager.get(guild_id) { + handler.update_state(&update.voice_state); } } + } - if let Event::VoiceServerUpdate(ref update) = *event { - if let Some(guild_id) = update.guild_id { - if let Some(handler) = self.manager.get(guild_id) { - handler.update_server(&update.endpoint, &update.token); - } + if let Event::VoiceServerUpdate(ref update) = *event { + if let Some(guild_id) = update.guild_id { + if let Some(handler) = self.manager.get(guild_id) { + handler.update_server(&update.endpoint, &update.token); } } } @@ -678,7 +709,7 @@ impl Shard { pub(crate) fn cycle_voice_recv(&mut self) { if let Ok(v) = self.manager_rx.try_recv() { if let Err(why) = self.client.send_json(&v) { - warn!("Err sending voice msg: {:?}", why); + warn!("[Shard {:?}] Err sending voice msg: {:?}", self.shard_info, why); } } } @@ -689,11 +720,11 @@ impl Shard { "op": OpCode::Heartbeat.num(), }); - trace!("Sending heartbeat d: {}", self.seq); + trace!("[Shard {:?}] Sending heartbeat d: {}", self.shard_info, self.seq); match self.client.send_json(&map) { Ok(_) => { - self.heartbeat_instants.0 = Instant::now(); + self.heartbeat_instants.0 = Some(Instant::now()); self.last_heartbeat_acknowledged = false; Ok(()) @@ -702,10 +733,12 @@ impl Shard { match why { Error::WebSocket(WebSocketError::IoError(err)) => { if err.raw_os_error() != Some(32) { - debug!("Err w/ keepalive: {:?}", err); + debug!("[Shard {:?}] Err w/ heartbeating: {:?}", self.shard_info, err); } }, - other => warn!("Other err w/ keepalive: {:?}", other), + other => { + warn!("[Shard {:?}] Other err w/ keepalive: {:?}", self.shard_info, other); + }, } Err(Error::Gateway(GatewayError::HeartbeatFailed)) @@ -713,98 +746,143 @@ impl Shard { } } - pub(crate) fn heartbeat_interval(&self) -> i64 { - self.heartbeat_interval as i64 - } + pub(crate) fn check_heartbeat(&mut self) { + let heartbeat_interval = match self.heartbeat_interval { + Some(heartbeat_interval) => heartbeat_interval, + None => return, + }; - pub(crate) fn last_heartbeat_acknowledged(&self) -> bool { - self.last_heartbeat_acknowledged - } + let wait = StdDuration::from_secs(heartbeat_interval / 1000); - fn reconnect(&mut self) -> Result<Event> { - info!("Attempting to reconnect"); + // If a duration of time less than the heartbeat_interval has passed, + // then don't perform a keepalive or attempt to reconnect. + if let Some(last_sent) = self.heartbeat_instants.0 { + if last_sent.elapsed() <= wait { + return; + } + } - // Take a few attempts at reconnecting. - for i in 1u64..11u64 { - let gateway_url = http::get_gateway()?.url; + // If the last heartbeat didn't receive an acknowledgement, then + // auto-reconnect. + if !self.last_heartbeat_acknowledged { + debug!("[Shard {:?}] Last heartbeat not acknowledged; re-connecting", self.shard_info); - let shard = Shard::new(&gateway_url, - &self.token, - self.shard_info); + if let Err(why) = self.autoreconnect() { + warn!("[Shard {:?}] Err auto-reconnecting from heartbeat check: {:?}", + self.shard_info, + why); + } - if let Ok((shard, ready)) = shard { - let _ = self.shutdown(); + return; + } - mem::replace(self, shard); - self.session_id = Some(ready.ready.session_id.clone()); + // Otherwise, we're good to heartbeat. + if let Err(why) = self.heartbeat() { + warn!("[Shard {:?}] Err heartbeating: {:?}", self.shard_info, why); + } - return Ok(Event::Ready(ready)); - } + self.heartbeat_instants.0 = Some(Instant::now()); + } - let seconds = i.pow(2); + pub(crate) fn autoreconnect(&mut self) -> Result<()> { + if self.stage == ConnectionStage::Connecting { + return Ok(()); + } - debug!("Exponentially backing off for {} seconds", seconds); + if self.session_id.is_some() { + debug!("[Shard {:?}] Autoreconnector choosing to resume", self.shard_info); - // Exponentially back off. - thread::sleep(StdDuration::from_secs(seconds)); + self.resume() + } else { + debug!("[Shard {:?}] Autoreconnector choosing to reconnect", self.shard_info); + + self.reconnect() } + } + + /// Retrieves the `heartbeat_interval`. + #[inline] + pub(crate) fn heartbeat_interval(&self) -> Option<u64> { + self.heartbeat_interval + } + + /// Retrieves the value of when the last heartbeat ack was received. + #[inline] + pub(crate) fn last_heartbeat_ack(&self) -> Option<Instant> { + self.heartbeat_instants.1 + } + + fn reconnect(&mut self) -> Result<()> { + info!("[Shard {:?}] Attempting to reconnect", self.shard_info); + self.reset(); - // Reconnecting failed; just return an error instead. - Err(Error::Gateway(GatewayError::ReconnectFailure)) + self.initialize() } - pub(crate) fn resume(&mut self) -> Result<Event> { + // Attempts to send a RESUME message. + // + // # Examples + // + // Returns a `GatewayError::NoSessionId` is there is no `session_id`, + // indicating that the shard should instead [`reconnect`]. + // + // [`reconnect`]: #method.reconnect + fn resume(&mut self) -> Result<()> { + self.send_resume().or_else(|why| { + warn!("Err sending resume: {:?}", why); + + self.reconnect() + }) + } + + fn send_resume(&mut self) -> Result<()> { let session_id = match self.session_id.clone() { Some(session_id) => session_id, None => return Err(Error::Gateway(GatewayError::NoSessionId)), }; - self.client = connect(&self.ws_url)?; - self.client.send_json(&json!({ "op": OpCode::Resume.num(), "d": { "session_id": session_id, "seq": self.seq, - "token": self.token, + "token": &*self.token.lock().unwrap(), }, - }))?; - - // Note to self when this gets accepted in a decade: - // https://github.com/rust-lang/rfcs/issues/961 - let ev; - - loop { - match self.client.recv_json(GatewayEvent::decode)? { - GatewayEvent::Dispatch(seq, event) => { - match event { - Event::Ready(ref ready) => { - self.session_id = Some(ready.ready.session_id.clone()); - }, - Event::Resumed(_) => info!("Resumed"), - ref other => warn!("Unknown resume event: {:?}", other), - } + })) + } - self.seq = seq; - ev = event; + fn initialize(&mut self) -> Result<()> { + self.stage = ConnectionStage::Connecting; + self.client = connect(&self.ws_url.lock().unwrap())?; - break; - }, - GatewayEvent::Hello(i) => { - self.heartbeat_interval = i; - } - GatewayEvent::InvalidateSession => { - self.client.send_json(&prep::identify(&self.token, self.shard_info))?; - }, - other => { - debug!("Unexpected event: {:?}", other); + self.identify() + } - return Err(Error::Gateway(GatewayError::InvalidHandshake)); + fn identify(&mut self) -> Result<()> { + let identification = json!({ + "op": OpCode::Identify.num(), + "d": { + "compression": true, + "large_threshold": constants::LARGE_THRESHOLD, + "shard": self.shard_info, + "token": &*self.token.lock().unwrap(), + "v": constants::GATEWAY_VERSION, + "properties": { + "$browser": "serenity", + "$device": "serenity", + "$os": consts::OS, }, - } - } + }, + }); - Ok(ev) + self.client.send_json(&identification) + } + + fn reset(&mut self) { + self.heartbeat_instants = (Some(Instant::now()), None); + self.last_heartbeat_acknowledged = true; + self.stage = ConnectionStage::Disconnected; + self.seq = 0; } fn update_presence(&mut self) { @@ -823,7 +901,9 @@ impl Shard { }, }); - let _ = self.client.send_json(&msg); + if let Err(why) = self.client.send_json(&msg) { + warn!("[Shard {:?}] Err sending presence update: {:?}", self.shard_info, why); + } #[cfg(feature="cache")] { @@ -839,7 +919,7 @@ impl Shard { } fn connect(base_url: &str) -> Result<WsClient> { - let url = prep::build_gateway_url(base_url)?; + let url = build_gateway_url(base_url)?; let client = ClientBuilder::from_url(&url).connect_secure(None)?; Ok(client) @@ -853,3 +933,8 @@ fn set_client_timeout(client: &mut WsClient) -> Result<()> { Ok(()) } + +fn build_gateway_url(base: &str) -> Result<Url> { + Url::parse(&format!("{}?v={}", base, constants::GATEWAY_VERSION)) + .map_err(|_| Error::Gateway(GatewayError::BuildingUrl)) +} |