diff options
Diffstat (limited to 'src/gateway/shard.rs')
| -rw-r--r-- | src/gateway/shard.rs | 289 |
1 files changed, 145 insertions, 144 deletions
diff --git a/src/gateway/shard.rs b/src/gateway/shard.rs index 72652b9..1e298b9 100644 --- a/src/gateway/shard.rs +++ b/src/gateway/shard.rs @@ -1,17 +1,15 @@ use chrono::UTC; use std::io::Write; use std::net::Shutdown; -use std::sync::mpsc::{self, Sender as MpscSender}; -use std::sync::{Arc, Mutex}; -use std::thread::{self, Builder as ThreadBuilder}; +use std::thread; use std::time::{Duration as StdDuration, Instant}; use std::mem; -use super::{GatewayError, GatewayStatus, prep}; -use websocket::client::{Client as WsClient, Sender, Receiver}; -use websocket::message::Message as WsMessage; -use websocket::result::WebSocketError; -use websocket::stream::WebSocketStream; -use websocket::ws::sender::Sender as WsSender; +use super::{GatewayError, prep}; +use websocket::message::{CloseData, OwnedMessage}; +use websocket::stream::sync::AsTcpStream; +use websocket::sync::client::{Client, ClientBuilder}; +use websocket::sync::stream::{TcpStream, TlsStream}; +use websocket::WebSocketError; use ::constants::OpCode; use ::http; use ::internal::prelude::*; @@ -26,6 +24,8 @@ use ::ext::voice::Manager as VoiceManager; #[cfg(feature="cache")] use ::utils; +pub type WsClient = Client<TlsStream<TcpStream>>; + type CurrentPresence = (Option<Game>, OnlineStatus, bool); /// A Shard is a higher-level handler for a websocket connection to Discord's @@ -62,8 +62,8 @@ type CurrentPresence = (Option<Game>, OnlineStatus, bool); /// [`receive`]: #method.receive /// [docs]: https://discordapp.com/developers/docs/topics/gateway#sharding /// [module docs]: index.html#sharding -#[derive(Debug)] pub struct Shard { + pub client: WsClient, current_presence: CurrentPresence, /// A tuple of: /// @@ -73,13 +73,13 @@ pub struct Shard { /// This can be used to calculate [`latency`]. /// /// [`latency`]: fn.latency.html - heartbeat_instants: (Arc<Mutex<Instant>>, Option<Instant>), - keepalive_channel: MpscSender<GatewayStatus>, - /// This is used by the keepalive thread to determine whether the last + heartbeat_instants: (Instant, Option<Instant>), + heartbeat_interval: u64, + /// This is used by the heartbeater to determine whether the last /// heartbeat was sent without an acknowledgement, and whether to reconnect. // This _must_ be set to `true` in `Shard::handle_event`'s // `Ok(GatewayEvent::HeartbeatAck)` arm. - last_heartbeat_acknowledged: Arc<Mutex<bool>>, + last_heartbeat_acknowledged: bool, seq: u64, session_id: Option<String>, shard_info: Option<[u64; 2]>, @@ -118,13 +118,13 @@ impl Shard { pub fn new(base_url: &str, token: &str, shard_info: Option<[u64; 2]>) - -> Result<(Shard, ReadyEvent, Receiver<WebSocketStream>)> { - let (mut sender, mut receiver) = connect(base_url)?; + -> Result<(Shard, ReadyEvent)> { + let mut client = connect(base_url)?; let identification = prep::identify(token, shard_info); - sender.send_json(&identification)?; + client.send_json(&identification)?; - let heartbeat_interval = match receiver.recv_json(GatewayEvent::decode)? { + let heartbeat_interval = match client.recv_json(GatewayEvent::decode)? { GatewayEvent::Hello(interval) => interval, other => { debug!("Unexpected event during shard start: {:?}", other); @@ -132,47 +132,19 @@ impl Shard { return Err(Error::Gateway(GatewayError::ExpectedHello)); }, }; - - let (tx, rx) = mpsc::channel(); - let thread_name = match shard_info { - Some(info) => format!("serenity keepalive [shard {}/{}]", - info[0], - info[1] - 1), - None => "serenity keepalive [unsharded]".to_owned(), - }; - - let heartbeat_sent = Arc::new(Mutex::new(Instant::now())); - let heartbeat_clone = heartbeat_sent.clone(); - - // Set this to true: when the keepalive thread sends a heartbeat, it - // will check if the value is `false`. - // - // If it is, it will reconnect. This enters the bot into a reconnect - // loop. Set this to `true` to give Discord the first heartbeat to - // acknowledge first. - let last_ack = Arc::new(Mutex::new(true)); - let last_ack_clone = last_ack.clone(); - - ThreadBuilder::new() - .name(thread_name) - .spawn(move || { - prep::keepalive(heartbeat_interval, heartbeat_clone, last_ack_clone, sender, &rx) - })?; + let heartbeat_sent = Instant::now(); // Parse READY - let event = receiver.recv_json(GatewayEvent::decode)?; - let (ready, sequence) = prep::parse_ready(event, - &tx, - &mut receiver, - identification)?; - + let event = client.recv_json(GatewayEvent::decode)?; + let (ready, sequence) = prep::parse_ready(event, &mut client, &identification)?; Ok((feature_voice! {{ Shard { + client: client, current_presence: (None, OnlineStatus::Online, false), heartbeat_instants: (heartbeat_sent, None), - last_heartbeat_acknowledged: last_ack, - keepalive_channel: tx.clone(), + heartbeat_interval: heartbeat_interval, + last_heartbeat_acknowledged: true, seq: sequence, token: token.to_owned(), session_id: Some(ready.ready.session_id.clone()), @@ -182,17 +154,18 @@ impl Shard { } } else { Shard { + client: client, current_presence: (None, OnlineStatus::Online, false), heartbeat_instants: (heartbeat_sent, None), - last_heartbeat_acknowledged: last_ack, - keepalive_channel: tx.clone(), + heartbeat_interval: heartbeat_interval, + last_heartbeat_acknowledged: true, seq: sequence, token: token.to_owned(), session_id: Some(ready.ready.session_id.clone()), shard_info: shard_info, ws_url: base_url.to_owned(), } - }}, ready, receiver)) + }}, ready)) } /// Retrieves a copy of the current shard information. @@ -328,20 +301,14 @@ impl Shard { /// enabled. #[allow(cyclomatic_complexity)] #[doc(hidden)] - pub fn handle_event(&mut self, - event: Result<GatewayEvent>, - mut receiver: &mut Receiver<WebSocketStream>) - -> Result<Option<(Event, Option<Receiver<WebSocketStream>>)>> { + pub fn handle_event(&mut self, event: Result<GatewayEvent>) -> Result<Option<Event>> { match event { Ok(GatewayEvent::Dispatch(seq, event)) => { - let status = GatewayStatus::Sequence(seq); - let _ = self.keepalive_channel.send(status); - self.seq = seq; self.handle_dispatch(&event); - Ok(Some((event, None))) + Ok(Some(event)) }, Ok(GatewayEvent::Heartbeat(s)) => { info!("Received shard heartbeat"); @@ -352,10 +319,12 @@ impl Shard { s, self.seq); + let _ = self.shutdown(); + return if self.session_id.is_some() { - self.resume(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + self.resume().map(Some) } else { - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + self.reconnect().map(Some) }; } @@ -363,27 +332,27 @@ impl Shard { "d": Value::Null, "op": OpCode::Heartbeat.num(), }); - let status = GatewayStatus::SendMessage(map); - let _ = self.keepalive_channel.send(status); + self.client.send_json(&map)?; Ok(None) }, Ok(GatewayEvent::HeartbeatAck) => { self.heartbeat_instants.1 = Some(Instant::now()); - *self.last_heartbeat_acknowledged.lock().unwrap() = true; + self.last_heartbeat_acknowledged = true; Ok(None) }, Ok(GatewayEvent::Hello(interval)) => { if interval > 0 { - let status = GatewayStatus::Interval(interval); - let _ = self.keepalive_channel.send(status); + self.heartbeat_interval = interval; } + let _ = self.shutdown(); + if self.session_id.is_some() { - self.resume(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + self.resume().map(Some) } else { - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + self.reconnect().map(Some) } }, Ok(GatewayEvent::InvalidateSession) => { @@ -392,21 +361,24 @@ impl Shard { self.session_id = None; let identification = prep::identify(&self.token, self.shard_info); - let status = GatewayStatus::SendMessage(identification); - let _ = self.keepalive_channel.send(status); + let _ = self.client.send_json(&identification); Ok(None) }, Ok(GatewayEvent::Reconnect) => { - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + let _ = self.shutdown(); + + self.reconnect().map(Some) }, - Err(Error::Gateway(GatewayError::Closed(num, message))) => { + Err(Error::Gateway(GatewayError::Closed(data))) => { + let num = data.as_ref().map(|d| d.status_code); + let reason = data.map(|d| d.reason); let clean = num == Some(1000); { let kind = if clean { "Cleanly" } else { "Uncleanly" }; - info!("{} closing with {:?}: {}", kind, num, message); + info!("{} closing with {:?}: {:?}", kind, num, reason); } match num { @@ -429,23 +401,26 @@ impl Shard { self.session_id = None; }, Some(other) if !clean => { - warn!("Unknown unclean close {}: {:?}", other, message); + warn!("Unknown unclean close {}: {:?}", other, reason); }, _ => {}, } - let resume = num.map(|x| x != 1000 && x != 4004 && self.session_id.is_some()) - .unwrap_or(false); + let resume = num.map(|num| { + num != 1000 && num != 4004 && self.session_id.is_some() + }).unwrap_or(false); if resume { info!("Attempting to resume"); if self.session_id.is_some() { - match self.resume(receiver) { - Ok((ev, rec)) => { + let _ = self.shutdown(); + + match self.resume() { + Ok(ev) => { info!("Resumed"); - return Ok(Some((ev, Some(rec)))); + return Ok(Some(ev)); }, Err(why) => { warn!("Error resuming: {:?}", why); @@ -457,7 +432,9 @@ impl Shard { info!("Reconnecting"); - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + let _ = self.shutdown(); + + self.reconnect().map(Some) }, Err(Error::WebSocket(why)) => { if let WebSocketError::NoDataAvailable = why { @@ -477,11 +454,13 @@ impl Shard { if self.session_id.is_some() { info!("Attempting to resume"); - match self.resume(&mut receiver) { - Ok((ev, rec)) => { + let _ = self.shutdown(); + + match self.resume() { + Ok(ev) => { info!("Resumed"); - return Ok(Some((ev, Some(rec)))); + return Ok(Some(ev)); }, Err(why) => { warn!("Error resuming: {:?}", why); @@ -492,7 +471,9 @@ impl Shard { info!("Reconnecting"); - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + let _ = self.shutdown(); + + self.reconnect().map(Some) }, Err(error) => Err(error), } @@ -528,25 +509,26 @@ impl Shard { // Shamelessly stolen from brayzure's commit in eris: // <https://github.com/abalabahaha/eris/commit/0ce296ae9a542bcec0edf1c999ee2d9986bed5a6> pub fn latency(&self) -> Option<StdDuration> { - self.heartbeat_instants.1.map(|send| send - *self.heartbeat_instants.0.lock().unwrap()) + self.heartbeat_instants.1.map(|send| send - self.heartbeat_instants.0) } /// Shuts down the receiver by attempting to cleanly close the /// connection. #[doc(hidden)] - pub fn shutdown_clean(receiver: &mut Receiver<WebSocketStream>) - -> Result<()> { - let r = receiver.get_mut().get_mut(); - + pub fn shutdown_clean(client: &mut WsClient) -> Result<()> { { - let mut sender = Sender::new(r.by_ref(), true); - let message = WsMessage::close_because(1000, ""); + let message = OwnedMessage::Close(Some(CloseData { + status_code: 1000, + reason: String::new(), + })); - sender.send_message(&message)?; + client.send_message(&message)?; } - r.flush()?; - r.shutdown(Shutdown::Both)?; + let mut stream = client.stream_ref().as_tcp(); + + stream.flush()?; + stream.shutdown(Shutdown::Both)?; debug!("Cleanly shutdown shard"); @@ -555,11 +537,11 @@ impl Shard { /// Uncleanly shuts down the receiver by not sending a close code. #[doc(hidden)] - pub fn shutdown(receiver: &mut Receiver<WebSocketStream>) -> Result<()> { - let r = receiver.get_mut().get_mut(); + pub fn shutdown(&mut self) -> Result<()> { + let mut stream = self.client.stream_ref().as_tcp(); - r.flush()?; - r.shutdown(Shutdown::Both)?; + stream.flush()?; + stream.shutdown(Shutdown::Both)?; Ok(()) } @@ -612,7 +594,7 @@ impl Shard { /// [`Event::GuildMembersChunk`]: ../../model/event/enum.Event.html#variant.GuildMembersChunk /// [`Guild`]: ../../model/struct.Guild.html /// [`Member`]: ../../model/struct.Member.html - pub fn chunk_guilds(&self, guild_ids: &[GuildId], limit: Option<u16>, query: Option<&str>) { + pub fn chunk_guilds(&mut self, guild_ids: &[GuildId], limit: Option<u16>, query: Option<&str>) { let msg = json!({ "op": OpCode::GetGuildMembers.num(), "d": { @@ -622,7 +604,7 @@ impl Shard { }, }); - let _ = self.keepalive_channel.send(GatewayStatus::SendMessage(msg)); + let _ = self.client.send_json(&msg); } /// Calculates the number of guilds that the shard is responsible for. @@ -685,8 +667,48 @@ impl Shard { } } - fn reconnect(&mut self, mut receiver: &mut Receiver<WebSocketStream>) - -> Result<(Event, Receiver<WebSocketStream>)> { + #[doc(hidden)] + pub fn heartbeat(&mut self) -> Result<()> { + let map = json!({ + "d": self.seq, + "op": OpCode::Heartbeat.num(), + }); + + trace!("Sending heartbeat d: {}", self.seq); + + match self.client.send_json(&map) { + Ok(_) => { + self.heartbeat_instants.0 = Instant::now(); + self.last_heartbeat_acknowledged = false; + + Ok(()) + }, + Err(why) => { + match why { + Error::WebSocket(WebSocketError::IoError(err)) => { + if err.raw_os_error() != Some(32) { + debug!("Err w/ keepalive: {:?}", err); + } + }, + other => warn!("Other err w/ keepalive: {:?}", other), + } + + Err(Error::Gateway(GatewayError::HeartbeatFailed)) + }, + } + } + + #[doc(hidden)] + pub fn heartbeat_interval(&self) -> i64 { + self.heartbeat_interval as i64 + } + + #[doc(hidden)] + pub fn last_heartbeat_acknowledged(&self) -> bool { + self.last_heartbeat_acknowledged + } + + fn reconnect(&mut self) -> Result<Event> { info!("Attempting to reconnect"); // Take a few attempts at reconnecting. @@ -697,13 +719,13 @@ impl Shard { &self.token, self.shard_info); - if let Ok((shard, ready, receiver_new)) = shard { - let _ = Shard::shutdown(&mut receiver); + if let Ok((shard, ready)) = shard { + let _ = self.shutdown(); mem::replace(self, shard); self.session_id = Some(ready.ready.session_id.clone()); - return Ok((Event::Ready(ready), receiver_new)); + return Ok(Event::Ready(ready)); } let seconds = i.pow(2); @@ -719,17 +741,15 @@ impl Shard { } #[doc(hidden)] - pub fn resume(&mut self, receiver: &mut Receiver<WebSocketStream>) - -> Result<(Event, Receiver<WebSocketStream>)> { + pub fn resume(&mut self) -> Result<Event> { let session_id = match self.session_id.clone() { Some(session_id) => session_id, None => return Err(Error::Gateway(GatewayError::NoSessionId)), }; - let _ = receiver.shutdown_all(); - let (mut sender, mut receiver) = connect(&self.ws_url)?; + self.client = connect(&self.ws_url)?; - sender.send_json(&json!({ + self.client.send_json(&json!({ "op": OpCode::Resume.num(), "d": { "session_id": session_id, @@ -743,7 +763,7 @@ impl Shard { let ev; loop { - match receiver.recv_json(GatewayEvent::decode)? { + match self.client.recv_json(GatewayEvent::decode)? { GatewayEvent::Dispatch(seq, event) => { match event { Event::Ready(ref ready) => { @@ -759,10 +779,10 @@ impl Shard { break; }, GatewayEvent::Hello(i) => { - let _ = self.keepalive_channel.send(GatewayStatus::Interval(i)); + self.heartbeat_interval = i; } GatewayEvent::InvalidateSession => { - sender.send_json(&prep::identify(&self.token, self.shard_info))?; + self.client.send_json(&prep::identify(&self.token, self.shard_info))?; }, other => { debug!("Unexpected event: {:?}", other); @@ -772,12 +792,10 @@ impl Shard { } } - let _ = self.keepalive_channel.send(GatewayStatus::Sender(sender)); - - Ok((ev, receiver)) + Ok(ev) } - fn update_presence(&self) { + fn update_presence(&mut self) { let (ref game, status, afk) = self.current_presence; let now = UTC::now().timestamp() as u64; @@ -793,7 +811,7 @@ impl Shard { }, }); - let _ = self.keepalive_channel.send(GatewayStatus::SendMessage(msg)); + let _ = self.client.send_json(&msg); #[cfg(feature="cache")] { @@ -808,34 +826,17 @@ impl Shard { } } -fn connect(base_url: &str) -> Result<(Sender<WebSocketStream>, Receiver<WebSocketStream>)> { +fn connect(base_url: &str) -> Result<WsClient> { let url = prep::build_gateway_url(base_url)?; - let response = WsClient::connect(url)?.send()?; - response.validate()?; + let client = ClientBuilder::from_url(&url).connect_secure(None)?; - let (mut sender, mut receiver) = response.begin().split(); - - let timeout = StdDuration::from_secs(90); + let timeout = StdDuration::from_secs(1); { - let mut ws_stream = receiver.get_mut().get_mut(); - let stream = match *ws_stream { - WebSocketStream::Tcp(ref mut s) => s, - WebSocketStream::Ssl(ref mut s) => s.get_mut(), - }; - + let stream = client.stream_ref().as_tcp(); stream.set_read_timeout(Some(timeout))?; - } - - { - let mut ws_stream = sender.get_mut(); - let stream = match *ws_stream { - WebSocketStream::Tcp(ref mut s) => s, - WebSocketStream::Ssl(ref mut s) => s.get_mut(), - }; - stream.set_read_timeout(Some(timeout))?; } - Ok((sender, receiver)) + Ok(client) } |