diff options
Diffstat (limited to 'src/client/mod.rs')
| -rw-r--r-- | src/client/mod.rs | 184 |
1 files changed, 109 insertions, 75 deletions
diff --git a/src/client/mod.rs b/src/client/mod.rs index 6711287..a227920 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -37,13 +37,11 @@ pub use http as rest; #[cfg(feature = "cache")] pub use CACHE; -use self::bridge::gateway::{ShardId, ShardManager, ShardRunnerInfo}; +use self::bridge::gateway::{ShardManager, ShardManagerMonitor}; use self::dispatch::dispatch; -use std::sync::{self, Arc}; +use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering, ATOMIC_BOOL_INIT}; use parking_lot::Mutex; -use std::collections::HashMap; -use std::mem; use threadpool::ThreadPool; use typemap::ShareMap; use http; @@ -103,8 +101,7 @@ impl CloseHandle { /// [`on_message`]: #method.on_message /// [`Event::MessageCreate`]: ../model/event/enum.Event.html#variant.MessageCreate /// [sharding docs]: gateway/index.html#sharding -#[derive(Clone)] -pub struct Client<H: EventHandler + Send + Sync + 'static> { +pub struct Client { /// A ShareMap which requires types to be Send + Sync. This is a map that /// can be safely shared across contexts. /// @@ -191,8 +188,7 @@ pub struct Client<H: EventHandler + Send + Sync + 'static> { /// /// [`Event::Ready`]: ../model/event/enum.Event.html#variant.Ready /// [`on_ready`]: #method.on_ready - event_handler: Arc<H>, - #[cfg(feature = "framework")] framework: Arc<sync::Mutex<Option<Box<Framework + Send>>>>, + #[cfg(feature = "framework")] framework: Arc<Mutex<Option<Box<Framework + Send>>>>, /// A HashMap of all shards instantiated by the Client. /// /// The key is the shard ID and the value is the shard itself. @@ -222,14 +218,14 @@ pub struct Client<H: EventHandler + Send + Sync + 'static> { /// /// impl EventHandler for Handler { } /// - /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler); + /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler)?; /// - /// let shard_runners = client.shard_runners.clone(); + /// let shard_manager = client.shard_manager.clone(); /// /// thread::spawn(move || { /// loop { /// println!("Shard count instantiated: {}", - /// shard_runners.lock().len()); + /// shard_manager.lock().shards_instantiated().len()); /// /// thread::sleep(Duration::from_millis(5000)); /// } @@ -244,16 +240,26 @@ pub struct Client<H: EventHandler + Send + Sync + 'static> { /// /// [`Client::start_shard`]: #method.start_shard /// [`Client::start_shards`]: #method.start_shards - pub shard_runners: Arc<Mutex<HashMap<ShardId, ShardRunnerInfo>>>, + pub shard_manager: Arc<Mutex<ShardManager>>, + shard_manager_worker: ShardManagerMonitor, /// The threadpool shared by all shards. /// /// Defaults to 5 threads, which should suffice small bots. Consider /// increasing this number as your bot grows. pub threadpool: ThreadPool, - token: Arc<sync::Mutex<String>>, + /// The token in use by the client. + pub token: Arc<Mutex<String>>, + /// URI that the client's shards will use to connect to the gateway. + /// + /// This is likely not important for production usage and is, at best, used + /// for debugging. + /// + /// This is wrapped in an `Arc<Mutex<T>>` so all shards will have an updated + /// value available. + pub ws_uri: Arc<Mutex<String>>, } -impl<H: EventHandler + Send + Sync + 'static> Client<H> { +impl Client { /// Creates a Client for a bot user. /// /// Discord has a requirement of prefixing bot tokens with `"Bot "`, which @@ -275,7 +281,7 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// use std::env; /// /// let token = env::var("DISCORD_TOKEN")?; - /// let client = Client::new(&token, Handler); + /// let client = Client::new(&token, Handler)?; /// # Ok(()) /// # } /// # @@ -283,7 +289,8 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// # try_main().unwrap(); /// # } /// ``` - pub fn new(token: &str, handler: H) -> Self { + pub fn new<H>(token: &str, handler: H) -> Result<Self> + where H: EventHandler + Send + Sync + 'static { let token = if token.starts_with("Bot ") { token.to_string() } else { @@ -291,29 +298,59 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { }; http::set_token(&token); - let locked = Arc::new(sync::Mutex::new(token)); + let locked = Arc::new(Mutex::new(token)); let name = "serenity client".to_owned(); let threadpool = ThreadPool::with_name(name, 5); + let url = Arc::new(Mutex::new(http::get_gateway()?.url)); + let data = Arc::new(Mutex::new(ShareMap::custom())); + let event_handler = Arc::new(handler); + + Ok(feature_framework! {{ + let framework = Arc::new(Mutex::new(None)); + + let (shard_manager, shard_manager_worker) = ShardManager::new( + 0, + 0, + 0, + Arc::clone(&url), + Arc::clone(&locked), + Arc::clone(&data), + Arc::clone(&event_handler), + Arc::clone(&framework), + threadpool.clone(), + ); - feature_framework! {{ Client { - data: Arc::new(Mutex::new(ShareMap::custom())), - event_handler: Arc::new(handler), - framework: Arc::new(sync::Mutex::new(None)), - shard_runners: Arc::new(Mutex::new(HashMap::new())), - threadpool, token: locked, + ws_uri: url, + framework, + data, + shard_manager, + shard_manager_worker, + threadpool, } } else { + let (shard_manager, shard_manager_worker) = ShardManager::new( + 0, + 0, + 0, + Arc::clone(&url), + locked.clone(), + data.clone(), + Arc::clone(&event_handler), + threadpool.clone(), + ); + Client { - data: Arc::new(Mutex::new(ShareMap::custom())), - event_handler: Arc::new(handler), - shard_runners: Arc::new(Mutex::new(HashMap::new())), - threadpool, token: locked, + ws_uri: url, + data, + shard_manager, + shard_manager_worker, + threadpool, } - }} + }}) } /// Sets a framework to be used with the client. All message events will be @@ -340,7 +377,7 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// use serenity::Client; /// use std::env; /// - /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler); + /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler)?; /// client.with_framework(StandardFramework::new() /// .configure(|c| c.prefix("~")) /// .command("ping", |c| c.exec_str("Pong!"))); @@ -392,12 +429,11 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// /// impl EventHandler for Handler {} /// - /// /// # fn try_main() -> Result<(), Box<Error>> { /// use serenity::Client; /// use std::env; /// - /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler); + /// let mut client = Client::new(&token, Handler).unwrap(); /// client.with_framework(MyFramework { commands: { /// let mut map = HashMap::new(); /// map.insert("ping".to_string(), Box::new(|msg, _| msg.channel_id.say("pong!"))); @@ -417,7 +453,7 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// [framework docs]: ../framework/index.html #[cfg(feature = "framework")] pub fn with_framework<F: Framework + Send + 'static>(&mut self, f: F) { - self.framework = Arc::new(sync::Mutex::new(Some(Box::new(f)))); + *self.framework.lock() = Some(Box::new(f)); } /// Establish the connection and start listening for events. @@ -447,7 +483,8 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// use serenity::client::Client; /// use std::env; /// - /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler); + /// let token = env::var("DISCORD_TOKEN")?; + /// let mut client = Client::new(&token, Handler).unwrap(); /// /// if let Err(why) = client.start() { /// println!("Err with client: {:?}", why); @@ -462,7 +499,7 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// /// [gateway docs]: gateway/index.html#sharding pub fn start(&mut self) -> Result<()> { - self.start_connection([0, 0, 1], http::get_gateway()?.url) + self.start_connection([0, 0, 1]) } /// Establish the connection(s) and start listening for events. @@ -492,7 +529,8 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// use serenity::client::Client; /// use std::env; /// - /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler); + /// let token = env::var("DISCORD_TOKEN")?; + /// let mut client = Client::new(&token, Handler).unwrap(); /// /// if let Err(why) = client.start_autosharded() { /// println!("Err with client: {:?}", why); @@ -513,15 +551,13 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// [`ClientError::Shutdown`]: enum.ClientError.html#variant.Shutdown /// [gateway docs]: gateway/index.html#sharding pub fn start_autosharded(&mut self) -> Result<()> { - let mut res = http::get_bot_gateway()?; - - let x = res.shards as u64 - 1; - let y = res.shards as u64; - let url = mem::replace(&mut res.url, String::default()); + let (x, y) = { + let res = http::get_bot_gateway()?; - drop(res); + (res.shards as u64 - 1, res.shards as u64) + }; - self.start_connection([0, x, y], url) + self.start_connection([0, x, y]) } /// Establish a sharded connection and start listening for events. @@ -551,7 +587,8 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// use serenity::client::Client; /// use std::env; /// - /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler); + /// let token = env::var("DISCORD_TOKEN")?; + /// let mut client = Client::new(&token, Handler).unwrap(); /// /// if let Err(why) = client.start_shard(3, 5) { /// println!("Err with client: {:?}", why); @@ -578,7 +615,7 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// use serenity::client::Client; /// use std::env; /// - /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler); + /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler)?; /// /// if let Err(why) = client.start_shard(0, 1) { /// println!("Err with client: {:?}", why); @@ -601,7 +638,7 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// [`start_autosharded`]: #method.start_autosharded /// [gateway docs]: gateway/index.html#sharding pub fn start_shard(&mut self, shard: u64, shards: u64) -> Result<()> { - self.start_connection([shard, shard, shards], http::get_gateway()?.url) + self.start_connection([shard, shard, shards]) } /// Establish sharded connections and start listening for events. @@ -631,7 +668,8 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// use serenity::client::Client; /// use std::env; /// - /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler); + /// let token = env::var("DISCORD_TOKEN")?; + /// let mut client = Client::new(&token, Handler).unwrap(); /// /// if let Err(why) = client.start_shards(8) { /// println!("Err with client: {:?}", why); @@ -654,10 +692,7 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// [`start_shard_range`]: #method.start_shard_range /// [Gateway docs]: gateway/index.html#sharding pub fn start_shards(&mut self, total_shards: u64) -> Result<()> { - self.start_connection( - [0, total_shards - 1, total_shards], - http::get_gateway()?.url, - ) + self.start_connection([0, total_shards - 1, total_shards]) } /// Establish a range of sharded connections and start listening for events. @@ -702,7 +737,8 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// use serenity::client::Client; /// use std::env; /// - /// let mut client = Client::new(&env::var("DISCORD_TOKEN")?, Handler); + /// let token = env::var("DISCORD_TOKEN")?; + /// let mut client = Client::new(&token, Handler).unwrap(); /// /// if let Err(why) = client.start_shard_range([4, 7], 10) { /// println!("Err with client: {:?}", why); @@ -726,7 +762,7 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { /// [`start_shards`]: #method.start_shards /// [Gateway docs]: gateway/index.html#sharding pub fn start_shard_range(&mut self, range: [u64; 2], total_shards: u64) -> Result<()> { - self.start_connection([range[0], range[1], total_shards], http::get_gateway()?.url) + self.start_connection([range[0], range[1], total_shards]) } /// Returns a thread-safe handle for closing shards. @@ -745,7 +781,7 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { // an error. // // [`ClientError::Shutdown`]: enum.ClientError.html#variant.Shutdown - fn start_connection(&mut self, shard_data: [u64; 3], url: String) -> Result<()> { + fn start_connection(&mut self, shard_data: [u64; 3]) -> Result<()> { HANDLE_STILL.store(true, Ordering::Relaxed); // Update the framework's current user if the feature is enabled. @@ -755,44 +791,42 @@ impl<H: EventHandler + Send + Sync + 'static> Client<H> { { let user = http::get_current_user()?; - if let Some(ref mut framework) = *self.framework.lock().unwrap() { + if let Some(ref mut framework) = *self.framework.lock() { framework.update_current_user(user.id, user.bot); } } - let gateway_url = Arc::new(sync::Mutex::new(url)); + { + let mut manager = self.shard_manager.lock(); + + let init = shard_data[1] - shard_data[0] + 1; - let mut manager = ShardManager::new( - shard_data[0], - shard_data[1] - shard_data[0] + 1, - shard_data[2], - Arc::clone(&gateway_url), - Arc::clone(&self.token), - Arc::clone(&self.data), - Arc::clone(&self.event_handler), - #[cfg(feature = "framework")] - Arc::clone(&self.framework), - self.threadpool.clone(), - ); + manager.set_shards(shard_data[0], init, shard_data[2]); - self.shard_runners = Arc::clone(&manager.runners); + debug!( + "Initializing shard info: {} - {}/{}", + shard_data[0], + init, + shard_data[2], + ); - if let Err(why) = manager.initialize() { - error!("Failed to boot a shard: {:?}", why); - info!("Shutting down all shards"); + if let Err(why) = manager.initialize() { + error!("Failed to boot a shard: {:?}", why); + info!("Shutting down all shards"); - manager.shutdown_all(); + manager.shutdown_all(); - return Err(Error::Client(ClientError::ShardBootFailure)); + return Err(Error::Client(ClientError::ShardBootFailure)); + } } - manager.run(); + self.shard_manager_worker.run(); Err(Error::Client(ClientError::Shutdown)) } } -impl<H: EventHandler + Send + Sync + 'static> Drop for Client<H> { +impl Drop for Client { fn drop(&mut self) { self.close_handle().close(); } } |