diff options
| author | Zeyla Hellyer <[email protected]> | 2017-06-07 15:01:47 -0700 |
|---|---|---|
| committer | Zeyla Hellyer <[email protected]> | 2017-06-07 15:01:47 -0700 |
| commit | 8f8a05996c5b47ec9401aabb517d96ed2af5c36b (patch) | |
| tree | ab48c3b558c396f4f6d12c98a466074f97f17acf /src | |
| parent | Ws read/write timeout after 90s (diff) | |
| download | serenity-8f8a05996c5b47ec9401aabb517d96ed2af5c36b.tar.xz serenity-8f8a05996c5b47ec9401aabb517d96ed2af5c36b.zip | |
Upgrade rust-websocket, rust-openssl, and hyper
Upgrade `rust-websocket` to v0.20, maintaining use of its sync client.
This indirectly switches from `rust-openssl` v0.7 - which required
openssl-1.0 on all platforms - to `native-tls`, which allows for use of
schannel on Windows, Secure Transport on OSX, and openssl-1.1 on other
platforms.
Additionally, since hyper is no longer even a dependency of
rust-websocket, we can safely and easily upgrade to `hyper` v0.10 and
`multipart` v0.12.
This commit is fairly experimental as it has not been tested on a
long-running bot.
Diffstat (limited to 'src')
| -rw-r--r-- | src/client/context.rs | 32 | ||||
| -rw-r--r-- | src/client/mod.rs | 81 | ||||
| -rw-r--r-- | src/error.rs | 14 | ||||
| -rw-r--r-- | src/gateway/error.rs | 8 | ||||
| -rw-r--r-- | src/gateway/mod.rs | 2 | ||||
| -rw-r--r-- | src/gateway/prep.rs | 34 | ||||
| -rw-r--r-- | src/gateway/shard.rs | 289 | ||||
| -rw-r--r-- | src/gateway/status.rs | 11 | ||||
| -rw-r--r-- | src/internal/macros.rs | 24 | ||||
| -rw-r--r-- | src/internal/ws_impl.rs | 62 | ||||
| -rw-r--r-- | src/lib.rs | 4 |
11 files changed, 302 insertions, 259 deletions
diff --git a/src/client/context.rs b/src/client/context.rs index 5d36647..ec072ca 100644 --- a/src/client/context.rs +++ b/src/client/context.rs @@ -134,7 +134,8 @@ impl Context { /// /// [`Online`]: ../model/enum.OnlineStatus.html#variant.Online pub fn online(&self) { - self.shard.lock().unwrap().set_status(OnlineStatus::Online); + let mut shard = self.shard.lock().unwrap(); + shard.set_status(OnlineStatus::Online); } /// Sets the current user as being [`Idle`]. This maintains the current @@ -157,7 +158,8 @@ impl Context { /// /// [`Idle`]: ../model/enum.OnlineStatus.html#variant.Idle pub fn idle(&self) { - self.shard.lock().unwrap().set_status(OnlineStatus::Idle); + let mut shard = self.shard.lock().unwrap(); + shard.set_status(OnlineStatus::Idle); } /// Sets the current user as being [`DoNotDisturb`]. This maintains the @@ -180,7 +182,8 @@ impl Context { /// /// [`DoNotDisturb`]: ../model/enum.OnlineStatus.html#variant.DoNotDisturb pub fn dnd(&self) { - self.shard.lock().unwrap().set_status(OnlineStatus::DoNotDisturb); + let mut shard = self.shard.lock().unwrap(); + shard.set_status(OnlineStatus::DoNotDisturb); } /// Sets the current user as being [`Invisible`]. This maintains the current @@ -203,7 +206,8 @@ impl Context { /// [`Event::Ready`]: ../model/event/enum.Event.html#variant.Ready /// [`Invisible`]: ../model/enum.OnlineStatus.html#variant.Invisible pub fn invisible(&self) { - self.shard.lock().unwrap().set_status(OnlineStatus::Invisible); + let mut shard = self.shard.lock().unwrap(); + shard.set_status(OnlineStatus::Invisible); } /// "Resets" the current user's presence, by setting the game to `None` and @@ -228,9 +232,8 @@ impl Context { /// [`Online`]: ../model/enum.OnlineStatus.html#variant.Online /// [`set_presence`]: #method.set_presence pub fn reset_presence(&self) { - self.shard.lock() - .unwrap() - .set_presence(None, OnlineStatus::Online, false) + let mut shard = self.shard.lock().unwrap(); + shard.set_presence(None, OnlineStatus::Online, false) } /// Sets the current game, defaulting to an online status of [`Online`]. @@ -260,9 +263,8 @@ impl Context { /// /// [`Online`]: ../model/enum.OnlineStatus.html#variant.Online pub fn set_game(&self, game: Game) { - self.shard.lock() - .unwrap() - .set_presence(Some(game), OnlineStatus::Online, false); + let mut shard = self.shard.lock().unwrap(); + shard.set_presence(Some(game), OnlineStatus::Online, false); } /// Sets the current game, passing in only its name. This will automatically @@ -302,9 +304,8 @@ impl Context { url: None, }; - self.shard.lock() - .unwrap() - .set_presence(Some(game), OnlineStatus::Online, false); + let mut shard = self.shard.lock().unwrap(); + shard.set_presence(Some(game), OnlineStatus::Online, false); } /// Sets the current user's presence, providing all fields to be passed. @@ -351,8 +352,7 @@ impl Context { game: Option<Game>, status: OnlineStatus, afk: bool) { - self.shard.lock() - .unwrap() - .set_presence(game, status, afk) + let mut shard = self.shard.lock().unwrap(); + shard.set_presence(game, status, afk) } } diff --git a/src/client/mod.rs b/src/client/mod.rs index f728792..91790fa 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -35,6 +35,7 @@ pub use ::http as rest; #[cfg(feature="cache")] pub use ::CACHE; +use chrono::UTC; use self::dispatch::dispatch; use self::event_store::EventStore; use std::collections::HashMap; @@ -43,9 +44,7 @@ use std::time::Duration; use std::{mem, thread}; use super::gateway::Shard; use typemap::ShareMap; -use websocket::client::Receiver; use websocket::result::WebSocketError; -use websocket::stream::WebSocketStream; use ::http; use ::internal::prelude::*; use ::internal::ws_impl::ReceiverExt; @@ -982,7 +981,7 @@ impl Client { }); match boot { - Ok((shard, ready, receiver)) => { + Ok((shard, ready)) => { #[cfg(feature="cache")] { CACHE.write() @@ -1011,7 +1010,6 @@ impl Client { event_store: self.event_store.clone(), framework: self.framework.clone(), gateway_url: gateway_url.clone(), - receiver: receiver, shard: shard, shard_info: shard_info, token: self.token.clone(), @@ -1021,7 +1019,6 @@ impl Client { data: self.data.clone(), event_store: self.event_store.clone(), gateway_url: gateway_url.clone(), - receiver: receiver, shard: shard, shard_info: shard_info, token: self.token.clone(), @@ -1254,7 +1251,6 @@ struct MonitorInfo { event_store: Arc<RwLock<EventStore>>, framework: Arc<Mutex<Framework>>, gateway_url: Arc<Mutex<String>>, - receiver: Receiver<WebSocketStream>, shard: Arc<Mutex<Shard>>, shard_info: Option<[u64; 2]>, token: String, @@ -1265,13 +1261,12 @@ struct MonitorInfo { data: Arc<Mutex<ShareMap>>, event_store: Arc<RwLock<EventStore>>, gateway_url: Arc<Mutex<String>>, - receiver: Receiver<WebSocketStream>, shard: Arc<Mutex<Shard>>, shard_info: Option<[u64; 2]>, token: String, } -fn boot_shard(info: &BootInfo) -> Result<(Shard, ReadyEvent, Receiver<WebSocketStream>)> { +fn boot_shard(info: &BootInfo) -> Result<(Shard, ReadyEvent)> { // Make ten attempts to boot the shard, exponentially backing off; if it // still doesn't boot after that, accept it as a failure. // @@ -1298,7 +1293,7 @@ fn boot_shard(info: &BootInfo) -> Result<(Shard, ReadyEvent, Receiver<WebSocketS info.shard_info); match attempt { - Ok((shard, ready, receiver)) => { + Ok((shard, ready)) => { #[cfg(feature="cache")] { CACHE.write() @@ -1308,7 +1303,7 @@ fn boot_shard(info: &BootInfo) -> Result<(Shard, ReadyEvent, Receiver<WebSocketS info!("Successfully booted shard: {:?}", info.shard_info); - return Ok((shard, ready, receiver)); + return Ok((shard, ready)); }, Err(why) => warn!("Failed to boot shard: {:?}", why), } @@ -1332,14 +1327,13 @@ fn monitor_shard(mut info: MonitorInfo) { }); match boot { - Ok((new_shard, ready, new_receiver)) => { + Ok((new_shard, ready)) => { #[cfg(feature="cache")] { CACHE.write().unwrap().update_with_ready(&ready); } *info.shard.lock().unwrap() = new_shard; - info.receiver = new_receiver; boot_successful = true; @@ -1375,16 +1369,54 @@ fn monitor_shard(mut info: MonitorInfo) { } fn handle_shard(info: &mut MonitorInfo) { + // This is currently all ducktape. Redo this. + let mut last_ack_time = UTC::now().timestamp(); + let mut last_heartbeat_sent = UTC::now().timestamp(); + loop { - let event = match info.receiver.recv_json(GatewayEvent::decode) { - Err(Error::WebSocket(WebSocketError::NoDataAvailable)) => { - debug!("Attempting to shutdown receiver/sender"); + let mut shard = info.shard.lock().unwrap(); + let in_secs = shard.heartbeat_interval() / 1000; - match info.shard.lock().unwrap().resume(&mut info.receiver) { - Ok((_, receiver)) => { + if UTC::now().timestamp() - last_heartbeat_sent > in_secs { + // If the last heartbeat didn't receive an acknowledgement, then + // shutdown and auto-reconnect. + if !shard.last_heartbeat_acknowledged() { + debug!("Last heartbeat not acknowledged; re-connecting"); + + match shard.resume() { + Ok(_) => { debug!("Successfully resumed shard"); - info.receiver = receiver; + continue; + }, + Err(why) => { + warn!("Err resuming shard: {:?}", why); + + return; + }, + } + } + + let _ = shard.heartbeat(); + last_heartbeat_sent = UTC::now().timestamp(); + } + + let event = match shard.client.recv_json(GatewayEvent::decode) { + Ok(GatewayEvent::HeartbeatAck) => { + last_ack_time = UTC::now().timestamp(); + + Ok(GatewayEvent::HeartbeatAck) + }, + Err(Error::WebSocket(WebSocketError::IoError(_))) => { + if shard.last_heartbeat_acknowledged() || UTC::now().timestamp() - 90 < last_ack_time { + continue; + } + + debug!("Attempting to shutdown receiver/sender"); + + match shard.resume() { + Ok(_) => { + debug!("Successfully resumed shard"); continue; }, @@ -1395,21 +1427,14 @@ fn handle_shard(info: &mut MonitorInfo) { }, } }, + Err(Error::WebSocket(WebSocketError::NoDataAvailable)) => continue, other => other, }; trace!("Received event on shard handler: {:?}", event); - // This will only lock when _updating_ the shard, resuming, etc. Most - // of the time, this won't be locked (i.e. when receiving an event over - // the receiver, separate from the shard itself). - let event = match info.shard.lock().unwrap().handle_event(event, &mut info.receiver) { - Ok(Some((event, Some(new_receiver)))) => { - info.receiver = new_receiver; - - event - }, - Ok(Some((event, None))) => event, + let event = match shard.handle_event(event) { + Ok(Some(event)) => event, Ok(None) => continue, Err(why) => { error!("Shard handler received err: {:?}", why); diff --git a/src/error.rs b/src/error.rs index 7795db5..6aa65e7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,6 +8,8 @@ use ::model::ModelError; #[cfg(feature="hyper")] use hyper::Error as HyperError; +#[cfg(feature="native-tls")] +use native_tls::Error as TlsError; #[cfg(feature="voice")] use opus::Error as OpusError; #[cfg(feature="websocket")] @@ -83,6 +85,9 @@ pub enum Error { /// An error from the `hyper` crate. #[cfg(feature="hyper")] Hyper(HyperError), + /// An error from the `native-tls` crate. + #[cfg(feature="native-tls")] + Tls(TlsError), /// An error from the `rust-websocket` crate. #[cfg(feature="gateway")] WebSocket(WebSocketError), @@ -141,6 +146,13 @@ impl From<OpusError> for Error { } } +#[cfg(feature="native-tls")] +impl From<TlsError> for Error { + fn from(e: TlsError) -> Error { + Error::Tls(e) + } +} + #[cfg(feature="gateway")] impl From<WebSocketError> for Error { fn from(e: WebSocketError) -> Error { @@ -184,6 +196,8 @@ impl StdError for Error { Error::Hyper(ref inner) => inner.description(), #[cfg(feature="voice")] Error::Opus(ref inner) => inner.description(), + #[cfg(feature="native-tls")] + Error::Tls(ref inner) => inner.description(), #[cfg(feature="voice")] Error::Voice(_) => "Voice error", #[cfg(feature="gateway")] diff --git a/src/gateway/error.rs b/src/gateway/error.rs index 57fa1cf..374c4a3 100644 --- a/src/gateway/error.rs +++ b/src/gateway/error.rs @@ -1,5 +1,6 @@ use std::error::Error as StdError; use std::fmt::{self, Display}; +use websocket::message::CloseData; /// An error that occurred while attempting to deal with the gateway. /// @@ -10,9 +11,11 @@ pub enum Error { /// There was an error building a URL. BuildingUrl, /// The connection closed, potentially uncleanly. - Closed(Option<u16>, String), + Closed(Option<CloseData>), /// Expected a Hello during a handshake ExpectedHello, + /// When there was an error sending a heartbeat. + HeartbeatFailed, /// Expected a Ready or an InvalidateSession InvalidHandshake, /// An indicator that an unknown opcode was received from the gateway. @@ -33,8 +36,9 @@ impl StdError for Error { fn description(&self) -> &str { match *self { Error::BuildingUrl => "Error building url", - Error::Closed(_, _) => "Connection closed", + Error::Closed(_) => "Connection closed", Error::ExpectedHello => "Expected a Hello", + Error::HeartbeatFailed => "Failed sending a heartbeat", Error::InvalidHandshake => "Expected a valid Handshake", Error::InvalidOpCode => "Invalid OpCode", Error::NoSessionId => "No Session Id present when required", diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 38cdaa1..f45522a 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -52,8 +52,6 @@ mod error; mod prep; mod shard; -mod status; pub use self::error::Error as GatewayError; pub use self::shard::Shard; -pub use self::status::Status as GatewayStatus; diff --git a/src/gateway/prep.rs b/src/gateway/prep.rs index 8e04df4..91b0a48 100644 --- a/src/gateway/prep.rs +++ b/src/gateway/prep.rs @@ -1,18 +1,9 @@ -use chrono::{Duration, UTC}; use serde_json::Value; -use std::sync::mpsc::{ - Receiver as MpscReceiver, - Sender as MpscSender, - TryRecvError, -}; -use std::sync::{Arc, Mutex}; -use std::time::{Duration as StdDuration, Instant}; -use std::{env, thread}; -use super::{GatewayError, GatewayStatus}; -use websocket::client::request::Url as RequestUrl; -use websocket::client::{Receiver, Sender}; -use websocket::result::WebSocketError as WsError; -use websocket::stream::WebSocketStream; +use std::env; +use super::GatewayError; +use websocket::client::Url; +use websocket::sync::stream::{TcpStream, TlsStream}; +use websocket::sync::Client as WsClient; use ::constants::{self, LARGE_THRESHOLD, OpCode}; use ::error::{Error, Result}; use ::internal::ws_impl::{ReceiverExt, SenderExt}; @@ -20,9 +11,8 @@ use ::model::event::{Event, GatewayEvent, ReadyEvent}; #[inline] pub fn parse_ready(event: GatewayEvent, - tx: &MpscSender<GatewayStatus>, - receiver: &mut Receiver<WebSocketStream>, - identification: Value) + client: &mut WsClient<TlsStream<TcpStream>>, + identification: &Value) -> Result<(ReadyEvent, u64)> { match event { GatewayEvent::Dispatch(seq, Event::Ready(event)) => { @@ -31,9 +21,9 @@ pub fn parse_ready(event: GatewayEvent, GatewayEvent::InvalidateSession => { debug!("Session invalidation"); - let _ = tx.send(GatewayStatus::SendMessage(identification)); + let _ = client.send_json(identification); - match receiver.recv_json(GatewayEvent::decode)? { + match client.recv_json(GatewayEvent::decode)? { GatewayEvent::Dispatch(seq, Event::Ready(event)) => { Ok((event, seq)) }, @@ -70,11 +60,12 @@ pub fn identify(token: &str, shard_info: Option<[u64; 2]>) -> Value { }) } -pub fn build_gateway_url(base: &str) -> Result<RequestUrl> { - RequestUrl::parse(&format!("{}?v={}", base, constants::GATEWAY_VERSION)) +pub fn build_gateway_url(base: &str) -> Result<Url> { + Url::parse(&format!("{}?v={}", base, constants::GATEWAY_VERSION)) .map_err(|_| Error::Gateway(GatewayError::BuildingUrl)) } +/* pub fn keepalive(interval: u64, heartbeat_sent: Arc<Mutex<Instant>>, last_ack: Arc<Mutex<bool>>, @@ -169,3 +160,4 @@ pub fn keepalive(interval: u64, }, } } +*/ diff --git a/src/gateway/shard.rs b/src/gateway/shard.rs index 72652b9..1e298b9 100644 --- a/src/gateway/shard.rs +++ b/src/gateway/shard.rs @@ -1,17 +1,15 @@ use chrono::UTC; use std::io::Write; use std::net::Shutdown; -use std::sync::mpsc::{self, Sender as MpscSender}; -use std::sync::{Arc, Mutex}; -use std::thread::{self, Builder as ThreadBuilder}; +use std::thread; use std::time::{Duration as StdDuration, Instant}; use std::mem; -use super::{GatewayError, GatewayStatus, prep}; -use websocket::client::{Client as WsClient, Sender, Receiver}; -use websocket::message::Message as WsMessage; -use websocket::result::WebSocketError; -use websocket::stream::WebSocketStream; -use websocket::ws::sender::Sender as WsSender; +use super::{GatewayError, prep}; +use websocket::message::{CloseData, OwnedMessage}; +use websocket::stream::sync::AsTcpStream; +use websocket::sync::client::{Client, ClientBuilder}; +use websocket::sync::stream::{TcpStream, TlsStream}; +use websocket::WebSocketError; use ::constants::OpCode; use ::http; use ::internal::prelude::*; @@ -26,6 +24,8 @@ use ::ext::voice::Manager as VoiceManager; #[cfg(feature="cache")] use ::utils; +pub type WsClient = Client<TlsStream<TcpStream>>; + type CurrentPresence = (Option<Game>, OnlineStatus, bool); /// A Shard is a higher-level handler for a websocket connection to Discord's @@ -62,8 +62,8 @@ type CurrentPresence = (Option<Game>, OnlineStatus, bool); /// [`receive`]: #method.receive /// [docs]: https://discordapp.com/developers/docs/topics/gateway#sharding /// [module docs]: index.html#sharding -#[derive(Debug)] pub struct Shard { + pub client: WsClient, current_presence: CurrentPresence, /// A tuple of: /// @@ -73,13 +73,13 @@ pub struct Shard { /// This can be used to calculate [`latency`]. /// /// [`latency`]: fn.latency.html - heartbeat_instants: (Arc<Mutex<Instant>>, Option<Instant>), - keepalive_channel: MpscSender<GatewayStatus>, - /// This is used by the keepalive thread to determine whether the last + heartbeat_instants: (Instant, Option<Instant>), + heartbeat_interval: u64, + /// This is used by the heartbeater to determine whether the last /// heartbeat was sent without an acknowledgement, and whether to reconnect. // This _must_ be set to `true` in `Shard::handle_event`'s // `Ok(GatewayEvent::HeartbeatAck)` arm. - last_heartbeat_acknowledged: Arc<Mutex<bool>>, + last_heartbeat_acknowledged: bool, seq: u64, session_id: Option<String>, shard_info: Option<[u64; 2]>, @@ -118,13 +118,13 @@ impl Shard { pub fn new(base_url: &str, token: &str, shard_info: Option<[u64; 2]>) - -> Result<(Shard, ReadyEvent, Receiver<WebSocketStream>)> { - let (mut sender, mut receiver) = connect(base_url)?; + -> Result<(Shard, ReadyEvent)> { + let mut client = connect(base_url)?; let identification = prep::identify(token, shard_info); - sender.send_json(&identification)?; + client.send_json(&identification)?; - let heartbeat_interval = match receiver.recv_json(GatewayEvent::decode)? { + let heartbeat_interval = match client.recv_json(GatewayEvent::decode)? { GatewayEvent::Hello(interval) => interval, other => { debug!("Unexpected event during shard start: {:?}", other); @@ -132,47 +132,19 @@ impl Shard { return Err(Error::Gateway(GatewayError::ExpectedHello)); }, }; - - let (tx, rx) = mpsc::channel(); - let thread_name = match shard_info { - Some(info) => format!("serenity keepalive [shard {}/{}]", - info[0], - info[1] - 1), - None => "serenity keepalive [unsharded]".to_owned(), - }; - - let heartbeat_sent = Arc::new(Mutex::new(Instant::now())); - let heartbeat_clone = heartbeat_sent.clone(); - - // Set this to true: when the keepalive thread sends a heartbeat, it - // will check if the value is `false`. - // - // If it is, it will reconnect. This enters the bot into a reconnect - // loop. Set this to `true` to give Discord the first heartbeat to - // acknowledge first. - let last_ack = Arc::new(Mutex::new(true)); - let last_ack_clone = last_ack.clone(); - - ThreadBuilder::new() - .name(thread_name) - .spawn(move || { - prep::keepalive(heartbeat_interval, heartbeat_clone, last_ack_clone, sender, &rx) - })?; + let heartbeat_sent = Instant::now(); // Parse READY - let event = receiver.recv_json(GatewayEvent::decode)?; - let (ready, sequence) = prep::parse_ready(event, - &tx, - &mut receiver, - identification)?; - + let event = client.recv_json(GatewayEvent::decode)?; + let (ready, sequence) = prep::parse_ready(event, &mut client, &identification)?; Ok((feature_voice! {{ Shard { + client: client, current_presence: (None, OnlineStatus::Online, false), heartbeat_instants: (heartbeat_sent, None), - last_heartbeat_acknowledged: last_ack, - keepalive_channel: tx.clone(), + heartbeat_interval: heartbeat_interval, + last_heartbeat_acknowledged: true, seq: sequence, token: token.to_owned(), session_id: Some(ready.ready.session_id.clone()), @@ -182,17 +154,18 @@ impl Shard { } } else { Shard { + client: client, current_presence: (None, OnlineStatus::Online, false), heartbeat_instants: (heartbeat_sent, None), - last_heartbeat_acknowledged: last_ack, - keepalive_channel: tx.clone(), + heartbeat_interval: heartbeat_interval, + last_heartbeat_acknowledged: true, seq: sequence, token: token.to_owned(), session_id: Some(ready.ready.session_id.clone()), shard_info: shard_info, ws_url: base_url.to_owned(), } - }}, ready, receiver)) + }}, ready)) } /// Retrieves a copy of the current shard information. @@ -328,20 +301,14 @@ impl Shard { /// enabled. #[allow(cyclomatic_complexity)] #[doc(hidden)] - pub fn handle_event(&mut self, - event: Result<GatewayEvent>, - mut receiver: &mut Receiver<WebSocketStream>) - -> Result<Option<(Event, Option<Receiver<WebSocketStream>>)>> { + pub fn handle_event(&mut self, event: Result<GatewayEvent>) -> Result<Option<Event>> { match event { Ok(GatewayEvent::Dispatch(seq, event)) => { - let status = GatewayStatus::Sequence(seq); - let _ = self.keepalive_channel.send(status); - self.seq = seq; self.handle_dispatch(&event); - Ok(Some((event, None))) + Ok(Some(event)) }, Ok(GatewayEvent::Heartbeat(s)) => { info!("Received shard heartbeat"); @@ -352,10 +319,12 @@ impl Shard { s, self.seq); + let _ = self.shutdown(); + return if self.session_id.is_some() { - self.resume(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + self.resume().map(Some) } else { - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + self.reconnect().map(Some) }; } @@ -363,27 +332,27 @@ impl Shard { "d": Value::Null, "op": OpCode::Heartbeat.num(), }); - let status = GatewayStatus::SendMessage(map); - let _ = self.keepalive_channel.send(status); + self.client.send_json(&map)?; Ok(None) }, Ok(GatewayEvent::HeartbeatAck) => { self.heartbeat_instants.1 = Some(Instant::now()); - *self.last_heartbeat_acknowledged.lock().unwrap() = true; + self.last_heartbeat_acknowledged = true; Ok(None) }, Ok(GatewayEvent::Hello(interval)) => { if interval > 0 { - let status = GatewayStatus::Interval(interval); - let _ = self.keepalive_channel.send(status); + self.heartbeat_interval = interval; } + let _ = self.shutdown(); + if self.session_id.is_some() { - self.resume(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + self.resume().map(Some) } else { - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + self.reconnect().map(Some) } }, Ok(GatewayEvent::InvalidateSession) => { @@ -392,21 +361,24 @@ impl Shard { self.session_id = None; let identification = prep::identify(&self.token, self.shard_info); - let status = GatewayStatus::SendMessage(identification); - let _ = self.keepalive_channel.send(status); + let _ = self.client.send_json(&identification); Ok(None) }, Ok(GatewayEvent::Reconnect) => { - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + let _ = self.shutdown(); + + self.reconnect().map(Some) }, - Err(Error::Gateway(GatewayError::Closed(num, message))) => { + Err(Error::Gateway(GatewayError::Closed(data))) => { + let num = data.as_ref().map(|d| d.status_code); + let reason = data.map(|d| d.reason); let clean = num == Some(1000); { let kind = if clean { "Cleanly" } else { "Uncleanly" }; - info!("{} closing with {:?}: {}", kind, num, message); + info!("{} closing with {:?}: {:?}", kind, num, reason); } match num { @@ -429,23 +401,26 @@ impl Shard { self.session_id = None; }, Some(other) if !clean => { - warn!("Unknown unclean close {}: {:?}", other, message); + warn!("Unknown unclean close {}: {:?}", other, reason); }, _ => {}, } - let resume = num.map(|x| x != 1000 && x != 4004 && self.session_id.is_some()) - .unwrap_or(false); + let resume = num.map(|num| { + num != 1000 && num != 4004 && self.session_id.is_some() + }).unwrap_or(false); if resume { info!("Attempting to resume"); if self.session_id.is_some() { - match self.resume(receiver) { - Ok((ev, rec)) => { + let _ = self.shutdown(); + + match self.resume() { + Ok(ev) => { info!("Resumed"); - return Ok(Some((ev, Some(rec)))); + return Ok(Some(ev)); }, Err(why) => { warn!("Error resuming: {:?}", why); @@ -457,7 +432,9 @@ impl Shard { info!("Reconnecting"); - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + let _ = self.shutdown(); + + self.reconnect().map(Some) }, Err(Error::WebSocket(why)) => { if let WebSocketError::NoDataAvailable = why { @@ -477,11 +454,13 @@ impl Shard { if self.session_id.is_some() { info!("Attempting to resume"); - match self.resume(&mut receiver) { - Ok((ev, rec)) => { + let _ = self.shutdown(); + + match self.resume() { + Ok(ev) => { info!("Resumed"); - return Ok(Some((ev, Some(rec)))); + return Ok(Some(ev)); }, Err(why) => { warn!("Error resuming: {:?}", why); @@ -492,7 +471,9 @@ impl Shard { info!("Reconnecting"); - self.reconnect(receiver).map(|(ev, rec)| Some((ev, Some(rec)))) + let _ = self.shutdown(); + + self.reconnect().map(Some) }, Err(error) => Err(error), } @@ -528,25 +509,26 @@ impl Shard { // Shamelessly stolen from brayzure's commit in eris: // <https://github.com/abalabahaha/eris/commit/0ce296ae9a542bcec0edf1c999ee2d9986bed5a6> pub fn latency(&self) -> Option<StdDuration> { - self.heartbeat_instants.1.map(|send| send - *self.heartbeat_instants.0.lock().unwrap()) + self.heartbeat_instants.1.map(|send| send - self.heartbeat_instants.0) } /// Shuts down the receiver by attempting to cleanly close the /// connection. #[doc(hidden)] - pub fn shutdown_clean(receiver: &mut Receiver<WebSocketStream>) - -> Result<()> { - let r = receiver.get_mut().get_mut(); - + pub fn shutdown_clean(client: &mut WsClient) -> Result<()> { { - let mut sender = Sender::new(r.by_ref(), true); - let message = WsMessage::close_because(1000, ""); + let message = OwnedMessage::Close(Some(CloseData { + status_code: 1000, + reason: String::new(), + })); - sender.send_message(&message)?; + client.send_message(&message)?; } - r.flush()?; - r.shutdown(Shutdown::Both)?; + let mut stream = client.stream_ref().as_tcp(); + + stream.flush()?; + stream.shutdown(Shutdown::Both)?; debug!("Cleanly shutdown shard"); @@ -555,11 +537,11 @@ impl Shard { /// Uncleanly shuts down the receiver by not sending a close code. #[doc(hidden)] - pub fn shutdown(receiver: &mut Receiver<WebSocketStream>) -> Result<()> { - let r = receiver.get_mut().get_mut(); + pub fn shutdown(&mut self) -> Result<()> { + let mut stream = self.client.stream_ref().as_tcp(); - r.flush()?; - r.shutdown(Shutdown::Both)?; + stream.flush()?; + stream.shutdown(Shutdown::Both)?; Ok(()) } @@ -612,7 +594,7 @@ impl Shard { /// [`Event::GuildMembersChunk`]: ../../model/event/enum.Event.html#variant.GuildMembersChunk /// [`Guild`]: ../../model/struct.Guild.html /// [`Member`]: ../../model/struct.Member.html - pub fn chunk_guilds(&self, guild_ids: &[GuildId], limit: Option<u16>, query: Option<&str>) { + pub fn chunk_guilds(&mut self, guild_ids: &[GuildId], limit: Option<u16>, query: Option<&str>) { let msg = json!({ "op": OpCode::GetGuildMembers.num(), "d": { @@ -622,7 +604,7 @@ impl Shard { }, }); - let _ = self.keepalive_channel.send(GatewayStatus::SendMessage(msg)); + let _ = self.client.send_json(&msg); } /// Calculates the number of guilds that the shard is responsible for. @@ -685,8 +667,48 @@ impl Shard { } } - fn reconnect(&mut self, mut receiver: &mut Receiver<WebSocketStream>) - -> Result<(Event, Receiver<WebSocketStream>)> { + #[doc(hidden)] + pub fn heartbeat(&mut self) -> Result<()> { + let map = json!({ + "d": self.seq, + "op": OpCode::Heartbeat.num(), + }); + + trace!("Sending heartbeat d: {}", self.seq); + + match self.client.send_json(&map) { + Ok(_) => { + self.heartbeat_instants.0 = Instant::now(); + self.last_heartbeat_acknowledged = false; + + Ok(()) + }, + Err(why) => { + match why { + Error::WebSocket(WebSocketError::IoError(err)) => { + if err.raw_os_error() != Some(32) { + debug!("Err w/ keepalive: {:?}", err); + } + }, + other => warn!("Other err w/ keepalive: {:?}", other), + } + + Err(Error::Gateway(GatewayError::HeartbeatFailed)) + }, + } + } + + #[doc(hidden)] + pub fn heartbeat_interval(&self) -> i64 { + self.heartbeat_interval as i64 + } + + #[doc(hidden)] + pub fn last_heartbeat_acknowledged(&self) -> bool { + self.last_heartbeat_acknowledged + } + + fn reconnect(&mut self) -> Result<Event> { info!("Attempting to reconnect"); // Take a few attempts at reconnecting. @@ -697,13 +719,13 @@ impl Shard { &self.token, self.shard_info); - if let Ok((shard, ready, receiver_new)) = shard { - let _ = Shard::shutdown(&mut receiver); + if let Ok((shard, ready)) = shard { + let _ = self.shutdown(); mem::replace(self, shard); self.session_id = Some(ready.ready.session_id.clone()); - return Ok((Event::Ready(ready), receiver_new)); + return Ok(Event::Ready(ready)); } let seconds = i.pow(2); @@ -719,17 +741,15 @@ impl Shard { } #[doc(hidden)] - pub fn resume(&mut self, receiver: &mut Receiver<WebSocketStream>) - -> Result<(Event, Receiver<WebSocketStream>)> { + pub fn resume(&mut self) -> Result<Event> { let session_id = match self.session_id.clone() { Some(session_id) => session_id, None => return Err(Error::Gateway(GatewayError::NoSessionId)), }; - let _ = receiver.shutdown_all(); - let (mut sender, mut receiver) = connect(&self.ws_url)?; + self.client = connect(&self.ws_url)?; - sender.send_json(&json!({ + self.client.send_json(&json!({ "op": OpCode::Resume.num(), "d": { "session_id": session_id, @@ -743,7 +763,7 @@ impl Shard { let ev; loop { - match receiver.recv_json(GatewayEvent::decode)? { + match self.client.recv_json(GatewayEvent::decode)? { GatewayEvent::Dispatch(seq, event) => { match event { Event::Ready(ref ready) => { @@ -759,10 +779,10 @@ impl Shard { break; }, GatewayEvent::Hello(i) => { - let _ = self.keepalive_channel.send(GatewayStatus::Interval(i)); + self.heartbeat_interval = i; } GatewayEvent::InvalidateSession => { - sender.send_json(&prep::identify(&self.token, self.shard_info))?; + self.client.send_json(&prep::identify(&self.token, self.shard_info))?; }, other => { debug!("Unexpected event: {:?}", other); @@ -772,12 +792,10 @@ impl Shard { } } - let _ = self.keepalive_channel.send(GatewayStatus::Sender(sender)); - - Ok((ev, receiver)) + Ok(ev) } - fn update_presence(&self) { + fn update_presence(&mut self) { let (ref game, status, afk) = self.current_presence; let now = UTC::now().timestamp() as u64; @@ -793,7 +811,7 @@ impl Shard { }, }); - let _ = self.keepalive_channel.send(GatewayStatus::SendMessage(msg)); + let _ = self.client.send_json(&msg); #[cfg(feature="cache")] { @@ -808,34 +826,17 @@ impl Shard { } } -fn connect(base_url: &str) -> Result<(Sender<WebSocketStream>, Receiver<WebSocketStream>)> { +fn connect(base_url: &str) -> Result<WsClient> { let url = prep::build_gateway_url(base_url)?; - let response = WsClient::connect(url)?.send()?; - response.validate()?; + let client = ClientBuilder::from_url(&url).connect_secure(None)?; - let (mut sender, mut receiver) = response.begin().split(); - - let timeout = StdDuration::from_secs(90); + let timeout = StdDuration::from_secs(1); { - let mut ws_stream = receiver.get_mut().get_mut(); - let stream = match *ws_stream { - WebSocketStream::Tcp(ref mut s) => s, - WebSocketStream::Ssl(ref mut s) => s.get_mut(), - }; - + let stream = client.stream_ref().as_tcp(); stream.set_read_timeout(Some(timeout))?; - } - - { - let mut ws_stream = sender.get_mut(); - let stream = match *ws_stream { - WebSocketStream::Tcp(ref mut s) => s, - WebSocketStream::Ssl(ref mut s) => s.get_mut(), - }; - stream.set_read_timeout(Some(timeout))?; } - Ok((sender, receiver)) + Ok(client) } diff --git a/src/gateway/status.rs b/src/gateway/status.rs deleted file mode 100644 index f6e5ec2..0000000 --- a/src/gateway/status.rs +++ /dev/null @@ -1,11 +0,0 @@ -use serde_json::Value; -use websocket::client::Sender; -use websocket::stream::WebSocketStream; - -#[doc(hidden)] -pub enum Status { - Interval(u64), - Sender(Sender<WebSocketStream>), - SendMessage(Value), - Sequence(u64), -} diff --git a/src/internal/macros.rs b/src/internal/macros.rs index 1e09b2e..c2475d1 100644 --- a/src/internal/macros.rs +++ b/src/internal/macros.rs @@ -2,29 +2,45 @@ macro_rules! request { ($route:expr, $method:ident($body:expr), $url:expr, $($rest:tt)*) => {{ - let client = HyperClient::new(); + let client = request_client!(); + request($route, || client .$method(&format!(api!($url), $($rest)*)) .body(&$body))? }}; ($route:expr, $method:ident($body:expr), $url:expr) => {{ - let client = HyperClient::new(); + let client = request_client!(); + request($route, || client .$method(api!($url)) .body(&$body))? }}; ($route:expr, $method:ident, $url:expr, $($rest:tt)*) => {{ - let client = HyperClient::new(); + let client = request_client!(); + request($route, || client .$method(&format!(api!($url), $($rest)*)))? }}; ($route:expr, $method:ident, $url:expr) => {{ - let client = HyperClient::new(); + let client = request_client!(); + request($route, || client .$method(api!($url)))? }}; } +macro_rules! request_client { + () => {{ + use hyper::net::HttpsConnector; + use hyper_native_tls::NativeTlsClient; + + let tc = NativeTlsClient::new()?; + let connector = HttpsConnector::new(tc); + + HyperClient::with_connector(connector) + }} +} + macro_rules! cdn { ($e:expr) => { concat!("https://cdn.discordapp.com", $e) diff --git a/src/internal/ws_impl.rs b/src/internal/ws_impl.rs index 5475a3e..0db40ee 100644 --- a/src/internal/ws_impl.rs +++ b/src/internal/ws_impl.rs @@ -1,10 +1,8 @@ use flate2::read::ZlibDecoder; use serde_json; -use websocket::client::{Receiver, Sender}; -use websocket::message::{Message as WsMessage, Type as WsType}; -use websocket::stream::WebSocketStream; -use websocket::ws::receiver::Receiver as WsReceiver; -use websocket::ws::sender::Sender as WsSender; +use websocket::message::OwnedMessage; +use websocket::sync::stream::{TcpStream, TlsStream}; +use websocket::sync::Client as WsClient; use ::gateway::GatewayError; use ::internal::prelude::*; @@ -17,43 +15,45 @@ pub trait SenderExt { fn send_json(&mut self, value: &Value) -> Result<()>; } -impl ReceiverExt for Receiver<WebSocketStream> { +impl ReceiverExt for WsClient<TlsStream<TcpStream>> { fn recv_json<F, T>(&mut self, decode: F) -> Result<T> where F: FnOnce(Value) -> Result<T> { - let message: WsMessage = self.recv_message()?; + match self.recv_message()? { + OwnedMessage::Binary(bytes) => { + let value = serde_json::from_reader(ZlibDecoder::new(&bytes[..]))?; - if message.opcode == WsType::Close { - let r = String::from_utf8_lossy(&message.payload).into_owned(); - - Err(Error::Gateway(GatewayError::Closed(message.cd_status_code, r))) - } else if message.opcode == WsType::Binary || message.opcode == WsType::Text { - let json: Value = if message.opcode == WsType::Binary { - serde_json::from_reader(ZlibDecoder::new(&message.payload[..]))? - } else { - serde_json::from_reader(&message.payload[..])? - }; - - match decode(json) { - Ok(v) => Ok(v), - Err(why) => { - let s = String::from_utf8_lossy(&message.payload); + decode(value).map_err(|why| { + let s = String::from_utf8_lossy(&bytes); warn!("(╯°□°)╯︵ ┻━┻ Error decoding: {}", s); - Err(why) - } - } - } else { - let r = String::from_utf8_lossy(&message.payload).into_owned(); - - Err(Error::Gateway(GatewayError::Closed(None, r))) + why + }) + }, + OwnedMessage::Close(data) => { + Err(Error::Gateway(GatewayError::Closed(data))) + }, + OwnedMessage::Text(payload) => { + let value = serde_json::from_str(&payload)?; + + decode(value).map_err(|why| { + warn!("(╯°□°)╯︵ ┻━┻ Error decoding: {}", payload); + + why + }) + }, + OwnedMessage::Ping(x) | OwnedMessage::Pong(x) => { + warn!("Unexpectly got ping/pong: {:?}", x); + + Err(Error::Gateway(GatewayError::Closed(None))) + }, } } } -impl SenderExt for Sender<WebSocketStream> { +impl SenderExt for WsClient<TlsStream<TcpStream>> { fn send_json(&mut self, value: &Value) -> Result<()> { serde_json::to_string(value) - .map(WsMessage::text) + .map(OwnedMessage::Text) .map_err(Error::from) .and_then(|m| self.send_message(&m).map_err(Error::from)) } @@ -109,8 +109,12 @@ extern crate serde; extern crate byteorder; #[cfg(feature="hyper")] extern crate hyper; +#[cfg(feature="hyper-native-tls")] +extern crate hyper_native_tls; #[cfg(feature="http")] extern crate multipart; +#[cfg(feature="native-tls")] +extern crate native_tls; #[cfg(feature="voice")] extern crate opus; #[cfg(feature="voice")] |