diff options
| author | Zeyla Hellyer <[email protected]> | 2017-09-24 15:48:02 -0700 |
|---|---|---|
| committer | Zeyla Hellyer <[email protected]> | 2017-09-24 15:53:23 -0700 |
| commit | 6c43fed3702be3fdc1eafed26a2f6335acd71843 (patch) | |
| tree | e3dd142b36f221f33fb8e35c511bbf4e9e9471b6 /src/client/bridge | |
| parent | Use $crate for CommandError (diff) | |
| download | serenity-6c43fed3702be3fdc1eafed26a2f6335acd71843.tar.xz serenity-6c43fed3702be3fdc1eafed26a2f6335acd71843.zip | |
Add a shard manager
The shard manager will queue up shards for booting.
Diffstat (limited to 'src/client/bridge')
| -rw-r--r-- | src/client/bridge/gateway/mod.rs | 46 | ||||
| -rw-r--r-- | src/client/bridge/gateway/shard_manager.rs | 171 | ||||
| -rw-r--r-- | src/client/bridge/gateway/shard_queuer.rs | 118 | ||||
| -rw-r--r-- | src/client/bridge/gateway/shard_runner.rs | 171 | ||||
| -rw-r--r-- | src/client/bridge/mod.rs | 1 |
5 files changed, 507 insertions, 0 deletions
diff --git a/src/client/bridge/gateway/mod.rs b/src/client/bridge/gateway/mod.rs new file mode 100644 index 0000000..0bfd4e6 --- /dev/null +++ b/src/client/bridge/gateway/mod.rs @@ -0,0 +1,46 @@ +mod shard_manager; +mod shard_queuer; +mod shard_runner; + +pub use self::shard_manager::ShardManager; +pub use self::shard_queuer::ShardQueuer; +pub use self::shard_runner::ShardRunner; + +use gateway::Shard; +use parking_lot::Mutex; +use std::fmt::{Display, Formatter, Result as FmtResult}; +use std::sync::mpsc::Sender; +use std::sync::Arc; + +type Parked<T> = Arc<Mutex<T>>; +type LockedShard = Parked<Shard>; + +#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub enum ShardManagerMessage { + Restart(ShardId), + Shutdown(ShardId), + ShutdownAll, +} + +pub enum ShardQueuerMessage { + /// Message to start a shard, where the 0-index element is the ID of the + /// Shard to start and the 1-index element is the total shards in use. + Start(ShardId, ShardId), + /// Message to shutdown the shard queuer. + Shutdown, +} + +// A light tuplestruct wrapper around a u64 to verify type correctness when +// working with the IDs of shards. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub struct ShardId(pub u64); + +impl Display for ShardId { + fn fmt(&self, f: &mut Formatter) -> FmtResult { + write!(f, "{}", self.0) + } +} + +pub struct ShardRunnerInfo { + runner_tx: Sender<ShardManagerMessage>, +} diff --git a/src/client/bridge/gateway/shard_manager.rs b/src/client/bridge/gateway/shard_manager.rs new file mode 100644 index 0000000..0ac9d19 --- /dev/null +++ b/src/client/bridge/gateway/shard_manager.rs @@ -0,0 +1,171 @@ +use internal::prelude::*; +use parking_lot::Mutex as ParkingLotMutex; +use std::collections::HashMap; +use std::sync::mpsc::{self, Receiver, Sender}; +use std::sync::{Arc, Mutex}; +use std::thread; +use super::super::super::EventHandler; +use super::{ + ShardId, + ShardManagerMessage, + ShardQueuer, + ShardQueuerMessage, + ShardRunnerInfo, +}; +use typemap::ShareMap; + +#[cfg(feature = "framework")] +use framework::Framework; + +pub struct ShardManager { + #[cfg(feature = "framework")] + runners: Arc<ParkingLotMutex<HashMap<ShardId, ShardRunnerInfo>>>, + /// The index of the first shard to initialize, 0-indexed. + shard_index: u64, + /// The number of shards to initialize. + shard_init: u64, + /// The total shards in use, 1-indexed. + shard_total: u64, + shard_queuer: Sender<ShardQueuerMessage>, + thread_rx: Receiver<ShardManagerMessage>, +} + +impl ShardManager { + #[cfg(feature = "framework")] + pub fn new<H>( + shard_index: u64, + shard_init: u64, + shard_total: u64, + ws_url: Arc<Mutex<String>>, + token: Arc<Mutex<String>>, + data: Arc<ParkingLotMutex<ShareMap>>, + event_handler: Arc<H>, + framework: Arc<Mutex<Option<Box<Framework + Send>>>>, + ) -> Self where H: EventHandler + Send + Sync + 'static { + let (thread_tx, thread_rx) = mpsc::channel(); + let (shard_queue_tx, shard_queue_rx) = mpsc::channel(); + + let runners = Arc::new(ParkingLotMutex::new(HashMap::new())); + + let mut shard_queuer = feature_framework! {{ + ShardQueuer { + data: data.clone(), + event_handler: event_handler.clone(), + framework: framework.clone(), + last_start: None, + manager_tx: thread_tx.clone(), + runners: runners.clone(), + rx: shard_queue_rx, + token: token.clone(), + ws_url: ws_url.clone(), + } + } else { + ShardQueuer { + data: data.clone(), + event_handler: event_handler.clone(), + last_start: None, + manager_tx: thread_tx.clone(), + runners: runners.clone(), + rx: shard_queue_rx, + rx: shard_queue_rx, + token: token.clone(), + ws_url: ws_url.clone(), + } + }}; + + thread::spawn(move || { + shard_queuer.run(); + }); + + Self { + shard_queuer: shard_queue_tx, + thread_rx: thread_rx, + runners, + shard_index, + shard_init, + shard_total, + } + } + + pub fn initialize(&mut self) -> Result<()> { + let shard_to = self.shard_index + self.shard_init; + + debug!("{}, {}", self.shard_index, self.shard_init); + + for shard_id in self.shard_index..shard_to { + let shard_total = self.shard_total; + + self.boot([ShardId(shard_id), ShardId(shard_total)]); + } + + Ok(()) + } + + pub fn run(&mut self) { + loop { + let value = match self.thread_rx.recv() { + Ok(value) => value, + Err(_) => break, + }; + + match value { + ShardManagerMessage::Restart(shard_id) => self.restart(shard_id), + ShardManagerMessage::Shutdown(shard_id) => self.shutdown(shard_id), + ShardManagerMessage::ShutdownAll => { + self.shutdown_all(); + + break; + }, + } + } + } + + pub fn shutdown_all(&mut self) { + info!("Shutting down all shards"); + let keys = { + self.runners.lock().keys().cloned().collect::<Vec<ShardId>>() + }; + + for shard_id in keys { + self.shutdown(shard_id); + } + } + + fn boot(&mut self, shard_info: [ShardId; 2]) { + info!("Telling shard queuer to start shard {}", shard_info[0]); + + let msg = ShardQueuerMessage::Start(shard_info[0], shard_info[1]); + let _ = self.shard_queuer.send(msg); + } + + fn restart(&mut self, shard_id: ShardId) { + info!("Restarting shard {}", shard_id); + self.shutdown(shard_id); + + let shard_total = self.shard_total; + + self.boot([shard_id, ShardId(shard_total)]); + } + + fn shutdown(&mut self, shard_id: ShardId) { + info!("Shutting down shard {}", shard_id); + + if let Some(runner) = self.runners.lock().get(&shard_id) { + let msg = ShardManagerMessage::Shutdown(shard_id); + + if let Err(why) = runner.runner_tx.send(msg) { + warn!("Failed to cleanly shutdown shard {}: {:?}", shard_id, why); + } + } + + self.runners.lock().remove(&shard_id); + } +} + +impl Drop for ShardManager { + fn drop(&mut self) { + if let Err(why) = self.shard_queuer.send(ShardQueuerMessage::Shutdown) { + warn!("Failed to send shutdown to shard queuer: {:?}", why); + } + } +} diff --git a/src/client/bridge/gateway/shard_queuer.rs b/src/client/bridge/gateway/shard_queuer.rs new file mode 100644 index 0000000..8d3dbe1 --- /dev/null +++ b/src/client/bridge/gateway/shard_queuer.rs @@ -0,0 +1,118 @@ +use framework::Framework; +use gateway::Shard; +use internal::prelude::*; +use parking_lot::Mutex as ParkingLotMutex; +use std::collections::HashMap; +use std::sync::mpsc::{Receiver, Sender}; +use std::sync::{Arc, Mutex}; +use std::thread; +use std::time::{Duration, Instant}; +use super::super::super::EventHandler; +use super::{ + ShardId, + ShardManagerMessage, + ShardQueuerMessage, + ShardRunner, + ShardRunnerInfo, +}; +use typemap::ShareMap; + +/// The shard queuer is a simple loop that runs indefinitely to manage the +/// startup of shards. +/// +/// A shard queuer instance _should_ be run in its own thread, due to the +/// blocking nature of the loop itself as well as a 5 second thread sleep +/// between shard starts. +pub struct ShardQueuer<H: EventHandler + Send + Sync + 'static> { + pub data: Arc<ParkingLotMutex<ShareMap>>, + pub event_handler: Arc<H>, + #[cfg(feature = "framework")] + pub framework: Arc<Mutex<Option<Box<Framework + Send>>>>, + pub last_start: Option<Instant>, + pub manager_tx: Sender<ShardManagerMessage>, + pub runners: Arc<ParkingLotMutex<HashMap<ShardId, ShardRunnerInfo>>>, + pub rx: Receiver<ShardQueuerMessage>, + pub token: Arc<Mutex<String>>, + pub ws_url: Arc<Mutex<String>>, +} + +impl<H: EventHandler + Send + Sync + 'static> ShardQueuer<H> { + pub fn run(&mut self) { + loop { + let msg = match self.rx.recv() { + Ok(msg) => msg, + Err(_) => { + break; + } + }; + + match msg { + ShardQueuerMessage::Shutdown => break, + ShardQueuerMessage::Start(shard_id, shard_total) => { + self.check_last_start(); + + if let Err(why) = self.start(shard_id, shard_total) { + warn!("Err starting shard {}: {:?}", shard_id, why); + } + + self.last_start = Some(Instant::now()); + }, + } + } + } + + fn check_last_start(&mut self) { + let instant = match self.last_start { + Some(instant) => instant, + None => return, + }; + + // We must wait 5 seconds between IDENTIFYs to avoid session + // invalidations. + let duration = Duration::from_secs(5); + let elapsed = instant.elapsed(); + + if elapsed >= duration { + return; + } + + let to_sleep = duration - elapsed; + + thread::sleep(to_sleep); + } + + fn start(&mut self, shard_id: ShardId, shard_total: ShardId) -> Result<()> { + let shard_info = [shard_id.0, shard_total.0]; + let shard = Shard::new(self.ws_url.clone(), self.token.clone(), shard_info)?; + let locked = Arc::new(ParkingLotMutex::new(shard)); + + let mut runner = feature_framework! {{ + ShardRunner::new( + locked.clone(), + self.manager_tx.clone(), + self.framework.clone(), + self.data.clone(), + self.event_handler.clone(), + ) + } else { + ShardRunner::new( + locked.clone(), + self.manager_tx.clone(), + self.data.clone(), + self.event_handler.clone(), + ) + }}; + + let runner_info = ShardRunnerInfo { + runner_tx: runner.runner_tx(), + }; + + thread::spawn(move || { + let _ = runner.run(); + }); + + self.runners.lock().insert(shard_id, runner_info); + + Ok(()) + } +} diff --git a/src/client/bridge/gateway/shard_runner.rs b/src/client/bridge/gateway/shard_runner.rs new file mode 100644 index 0000000..8bdbb35 --- /dev/null +++ b/src/client/bridge/gateway/shard_runner.rs @@ -0,0 +1,171 @@ +use internal::prelude::*; +use internal::ws_impl::ReceiverExt; +use model::event::{Event, GatewayEvent}; +use parking_lot::Mutex as ParkingLotMutex; +use std::sync::mpsc::{self, Receiver, Sender}; +use std::sync::{Arc, Mutex}; +use super::super::super::{EventHandler, dispatch}; +use super::{LockedShard, ShardId, ShardManagerMessage}; +use typemap::ShareMap; +use websocket::WebSocketError; + +#[cfg(feature = "framework")] +use framework::Framework; + +enum EventRetrieval { + Some() +} + +pub struct ShardRunner<H: EventHandler + 'static> { + data: Arc<ParkingLotMutex<ShareMap>>, + event_handler: Arc<H>, + #[cfg(feature = "framework")] + framework: Arc<Mutex<Option<Box<Framework + Send>>>>, + manager_tx: Sender<ShardManagerMessage>, + runner_rx: Receiver<ShardManagerMessage>, + runner_tx: Sender<ShardManagerMessage>, + shard: LockedShard, +} + +impl<H: EventHandler + 'static> ShardRunner<H> { + pub fn new(shard: LockedShard, + manager_tx: Sender<ShardManagerMessage>, + framework: Arc<Mutex<Option<Box<Framework + Send>>>>, + data: Arc<ParkingLotMutex<ShareMap>>, + event_handler: Arc<H>) -> Self { + let (tx, rx) = mpsc::channel(); + + Self { + runner_rx: rx, + runner_tx: tx, + data, + event_handler, + framework, + manager_tx, + shard, + } + } + + pub fn run(&mut self) -> Result<()> { + loop { + { + let mut shard = self.shard.lock(); + let incoming = self.runner_rx.try_recv(); + + // Check for an incoming message over the runner channel. + // + // If the message is to shutdown, first verify the ID so we know + // for certain this runner is to shutdown. + if let Ok(ShardManagerMessage::Shutdown(id)) = incoming { + if id.0 == shard.shard_info()[0] { + let _ = shard.shutdown_clean(); + + return Ok(()); + } + } + + if let Err(why) = shard.check_heartbeat() { + error!("Failed to heartbeat and reconnect: {:?}", why); + + let msg = ShardManagerMessage::Restart(ShardId(shard.shard_info()[0])); + let _ = self.manager_tx.send(msg); + + return Ok(()); + } + + #[cfg(feature = "voice")] + { + shard.cycle_voice_recv(); + } + } + + let events = self.recv_events(); + + for event in events { + feature_framework! {{ + dispatch(event, + &self.shard, + &self.framework, + &self.data, + &self.event_handler); + } else { + dispatch(event, + &info.shard, + &info.data, + &info.event_handler, + &handle); + }} + } + } + } + + pub(super) fn runner_tx(&self) -> Sender<ShardManagerMessage> { + self.runner_tx.clone() + } + + fn recv_events(&mut self) -> Vec<Event> { + let mut shard = self.shard.lock(); + + let mut events = vec![]; + + loop { + let gw_event = match shard.client.recv_json(GatewayEvent::decode) { + Err(Error::WebSocket(WebSocketError::IoError(_))) => { + // Check that an amount of time at least double the + // heartbeat_interval has passed. + // + // If not, continue on trying to receive messages. + // + // If it has, attempt to auto-reconnect. + let last = shard.last_heartbeat_ack(); + let interval = shard.heartbeat_interval(); + + if let (Some(last_heartbeat_ack), Some(interval)) = (last, interval) { + let seconds_passed = last_heartbeat_ack.elapsed().as_secs(); + let interval_in_secs = interval / 1000; + + if seconds_passed <= interval_in_secs * 2 { + break; + } + } else { + break; + } + + debug!("Attempting to auto-reconnect"); + + if let Err(why) = shard.autoreconnect() { + error!("Failed to auto-reconnect: {:?}", why); + } + + break; + }, + Err(Error::WebSocket(WebSocketError::NoDataAvailable)) => break, + other => other, + }; + + let event = match gw_event { + Ok(Some(event)) => Ok(event), + Ok(None) => break, + Err(why) => Err(why), + }; + + let event = match shard.handle_event(event) { + Ok(Some(event)) => event, + Ok(None) => continue, + Err(why) => { + error!("Shard handler received err: {:?}", why); + + continue; + }, + }; + + events.push(event); + + if events.len() > 5 { + break; + } + }; + + events + } +} diff --git a/src/client/bridge/mod.rs b/src/client/bridge/mod.rs new file mode 100644 index 0000000..4f27526 --- /dev/null +++ b/src/client/bridge/mod.rs @@ -0,0 +1 @@ +pub mod gateway; |