aboutsummaryrefslogtreecommitdiff
path: root/src/ext/framework/mod.rs
diff options
context:
space:
mode:
authorIllia <[email protected]>2016-12-13 21:26:29 +0200
committerzeyla <[email protected]>2016-12-13 11:26:29 -0800
commitdaf92eda815b8f539f6d759ab48cf7a70513915f (patch)
tree36145f5095e7af6fb725635dd104e9d9d3f0ea62 /src/ext/framework/mod.rs
parentFix readme typo (diff)
downloadserenity-daf92eda815b8f539f6d759ab48cf7a70513915f.tar.xz
serenity-daf92eda815b8f539f6d759ab48cf7a70513915f.zip
Implement command groups and buckets
* Implement command groups * change to ref mut * Implement framework API. * Remove commands field * Make it all work * Make example use command groups * Requested changes * Implement adding buckets * Add ratelimit check function * Finish everything * Fix voice example * Actually fix it * Fix doc tests * Switch to result * Savage examples * Fix docs * Fixes * Accidental push * 👀 * Fix an example * fix some example * Small cleanup * Abstract ratelimit bucket logic
Diffstat (limited to 'src/ext/framework/mod.rs')
-rw-r--r--src/ext/framework/mod.rs289
1 files changed, 217 insertions, 72 deletions
diff --git a/src/ext/framework/mod.rs b/src/ext/framework/mod.rs
index fa7fdac..6cd222f 100644
--- a/src/ext/framework/mod.rs
+++ b/src/ext/framework/mod.rs
@@ -30,7 +30,7 @@
//! Configuring a Client with a framework, which has a prefix of `"~"` and a
//! ping and about command:
//!
-//! ```rust,no_run
+//! ```rust,ignore
//! use serenity::client::{Client, Context};
//! use serenity::model::Message;
//! use std::env;
@@ -42,34 +42,39 @@
//! .command("about", |c| c.exec_str("A simple test bot"))
//! .command("ping", |c| c.exec(ping)));
//!
-//! fn about(context: &Context, _message: &Message, _args: Vec<String>) {
+//! command!(about(context) {
//! let _ = context.say("A simple test bot");
-//! }
+//! });
//!
-//! fn ping(context: &Context, _message: &Message, _args: Vec<String>) {
+//! command!(ping(context) {
//! let _ = context.say("Pong!");
-//! }
+//! });
//! ```
//!
//! [`Client::with_framework`]: ../../client/struct.Client.html#method.with_framework
+pub mod help_commands;
+
mod command;
mod configuration;
mod create_command;
+mod create_group;
+mod buckets;
-pub use self::command::{Command, CommandType};
+pub use self::command::{Command, CommandType, CommandGroup};
pub use self::configuration::{AccountType, Configuration};
pub use self::create_command::CreateCommand;
+pub use self::create_group::CreateGroup;
+pub use self::buckets::{Bucket, MemberRatelimit, Ratelimit};
-use self::command::{Hook, InternalCommand};
+use self::command::{AfterHook, Hook};
use std::collections::HashMap;
+use std::default::Default;
use std::sync::Arc;
use std::thread;
-use ::client::Context;
+use ::client::{CACHE, Context};
use ::model::Message;
use ::utils;
-use ::client::CACHE;
-use ::model::Permissions;
/// A macro to generate "named parameters". This is useful to avoid manually
/// using the "arguments" parameter and manually parsing types.
@@ -104,28 +109,46 @@ use ::model::Permissions;
/// [`Framework`]: ext/framework/index.html
#[macro_export]
macro_rules! command {
+ ($fname:ident($c:ident) $b:block) => {
+ pub fn $fname($c: &Context, _: &Message, _: Vec<String>) -> Result<(), String> {
+ $b
+
+ Ok(())
+ }
+ };
+ ($fname:ident($c:ident, $m:ident) $b:block) => {
+ pub fn $fname($c: &Context, $m: &Message, _: Vec<String>) -> Result<(), String> {
+ $b
+
+ Ok(())
+ }
+ };
($fname:ident($c:ident, $m:ident, $a:ident) $b:block) => {
- pub fn $fname($c: &Context, $m: &Message, $a: Vec<String>) {
+ pub fn $fname($c: &Context, $m: &Message, $a: Vec<String>) -> Result<(), String> {
$b
+
+ Ok(())
}
};
($fname:ident($c:ident, $m:ident, $a:ident, $($name:ident: $t:ty),*) $b:block) => {
- pub fn $fname($c: &Context, $m: &Message, $a: Vec<String>) {
+ pub fn $fname($c: &Context, $m: &Message, $a: Vec<String>) -> Result<(), String> {
let mut i = $a.iter();
$(
let $name = match i.next() {
Some(v) => match v.parse::<$t>() {
Ok(v) => v,
- Err(_why) => return,
+ Err(_) => return Err(format!("Failed to parse {:?}", stringify!($t))),
},
- None => return,
+ None => return Err(format!("Failed to parse {:?}", stringify!($t))),
};
)*
drop(i);
$b
+
+ Ok(())
}
};
}
@@ -139,9 +162,10 @@ macro_rules! command {
#[derive(Default)]
pub struct Framework {
configuration: Configuration,
- commands: HashMap<String, InternalCommand>,
+ groups: HashMap<String, Arc<CommandGroup>>,
before: Option<Arc<Hook>>,
- after: Option<Arc<Hook>>,
+ buckets: HashMap<String, Bucket>,
+ after: Option<Arc<AfterHook>>,
/// Whether the framework has been "initialized".
///
/// The framework is initialized once one of the following occurs:
@@ -194,6 +218,44 @@ impl Framework {
self
}
+ /// Defines a bucket with `delay` between each command, and the `limit` of uses
+ /// per `time_span`.
+ pub fn bucket<S>(mut self, s: S, delay: i64, time_span: i64, limit: i32) -> Self
+ where S: Into<String> {
+ self.buckets.insert(s.into(), Bucket {
+ ratelimit: Ratelimit {
+ delay: delay,
+ limit: Some((time_span, limit))
+ },
+ limits: HashMap::new()
+ });
+
+ self
+ }
+
+ /// Defines a bucket just with `delay` between each command.
+ pub fn simple_bucket<S>(mut self, s: S, delay: i64) -> Self
+ where S: Into<String> {
+ self.buckets.insert(s.into(), Bucket {
+ ratelimit: Ratelimit {
+ delay: delay,
+ limit: None
+ },
+ limits: HashMap::new()
+ });
+
+ self
+ }
+
+ #[allow(map_entry)]
+ fn ratelimit_time(&mut self, bucket_name: &str, user_id: u64) -> i64 {
+ self.buckets
+ .get_mut(bucket_name)
+ .map(|bucket| bucket.take(user_id))
+ .unwrap_or(0)
+ }
+
+ #[allow(cyclomatic_complexity)]
#[doc(hidden)]
pub fn dispatch(&mut self, context: Context, message: Message) {
match self.configuration.account_type {
@@ -209,6 +271,7 @@ impl Framework {
},
AccountType::Automatic => {
let cache = CACHE.read().unwrap();
+
if cache.user.bot {
if message.author.bot {
return;
@@ -254,31 +317,64 @@ impl Framework {
None => continue,
});
- if let Some(command) = self.commands.get(&built) {
- if message.is_private() {
- if command.guild_only {
- return;
+ let groups = self.groups.clone();
+
+ for group in groups.values() {
+ let to_check = if let Some(ref prefix) = group.prefix {
+ if built.starts_with(prefix) && built.len() > prefix.len() + 1 {
+ built[(prefix.len() + 1)..].to_owned()
+ } else {
+ continue;
}
- } else if command.dm_only {
- return;
- }
+ } else {
+ built.clone()
+ };
+
+ if let Some(command) = group.commands.get(&to_check) {
+ if let Some(ref bucket_name) = command.bucket {
+ let rate_limit = self.ratelimit_time(bucket_name, message.author.id.0);
+
+ if rate_limit > 0 {
+ if let Some(ref message) = self.configuration.rate_limit_message {
+ let _ = context.say(
+ &message.replace("%time%", &rate_limit.to_string()));
+ }
- for check in &command.checks {
- if !(check)(&context, &message) {
- continue 'outer;
+ return;
+ }
}
- }
- let before = self.before.clone();
- let command = command.clone();
- let after = self.after.clone();
- let commands = self.commands.clone();
+ if message.is_private() {
+ if command.guild_only {
+ if let Some(ref message) = self.configuration.no_guild_message {
+ let _ = context.say(message);
+ }
+
+ return;
+ }
+ } else if command.dm_only {
+ if let Some(ref message) = self.configuration.no_dm_message {
+ let _ = context.say(message);
+ }
- thread::spawn(move || {
- if let Some(before) = before {
- (before)(&context, &message, &built);
+ return;
}
+ for check in &command.checks {
+ if !(check)(&context, &message) {
+ if let Some(ref message) = self.configuration.invalid_check_message {
+ let _ = context.say(message);
+ }
+
+ continue 'outer;
+ }
+ }
+
+ let before = self.before.clone();
+ let command = command.clone();
+ let after = self.after.clone();
+ let groups = self.groups.clone();
+
let args = if command.use_quotes {
utils::parse_quotes(&message.content[position + built.len()..])
} else {
@@ -290,12 +386,24 @@ impl Framework {
if let Some(x) = command.min_args {
if args.len() < x as usize {
+ if let Some(ref message) = self.configuration.not_enough_args_message {
+ let _ = context.say(
+ &message.replace("%min%", &x.to_string())
+ .replace("%given%", &args.len().to_string()));
+ }
+
return;
}
}
if let Some(x) = command.max_args {
if args.len() > x as usize {
+ if let Some(ref message) = self.configuration.too_many_args_message {
+ let _ = context.say(
+ &message.replace("%max%", &x.to_string())
+ .replace("%given%", &args.len().to_string()));
+ }
+
return;
}
}
@@ -316,28 +424,40 @@ impl Framework {
}
if !permissions_fulfilled {
+ if let Some(ref message) = self.configuration.invalid_permission_message {
+ let _ = context.say(message);
+ }
+
return;
}
}
- match command.exec {
- CommandType::StringResponse(ref x) => {
- let _ = &context.say(x);
- },
- CommandType::Basic(ref x) => {
- (x)(&context, &message, args);
- },
- CommandType::WithCommands(ref x) => {
- (x)(&context, &message, commands, args);
+ thread::spawn(move || {
+ if let Some(before) = before {
+ (before)(&context, &message, &built);
}
- }
- if let Some(after) = after {
- (after)(&context, &message, &built);
- }
- });
+ let result = match command.exec {
+ CommandType::StringResponse(ref x) => {
+ let _ = &context.say(x);
+
+ Ok(())
+ },
+ CommandType::Basic(ref x) => {
+ (x)(&context, &message, args)
+ },
+ CommandType::WithCommands(ref x) => {
+ (x)(&context, &message, groups, args)
+ }
+ };
- return;
+ if let Some(after) = after {
+ (after)(&context, &message, &built, result);
+ }
+ });
+
+ return;
+ }
}
}
}
@@ -359,21 +479,17 @@ impl Framework {
/// [`command`]: #method.command
/// [module-level documentation]: index.html
pub fn on<F, S>(mut self, command_name: S, f: F) -> Self
- where F: Fn(&Context, &Message, Vec<String>) + Send + Sync + 'static,
+ where F: Fn(&Context, &Message, Vec<String>) -> Result<(), String> + Send + Sync + 'static,
S: Into<String> {
- self.commands.insert(command_name.into(), Arc::new(Command {
- checks: Vec::default(),
- exec: CommandType::Basic(Box::new(f)),
- desc: None,
- usage: None,
- use_quotes: false,
- dm_only: false,
- guild_only: false,
- help_available: true,
- min_args: None,
- max_args: None,
- required_permissions: Permissions::empty()
- }));
+ if !self.groups.contains_key("Ungrouped") {
+ self.groups.insert("Ungrouped".to_string(), Arc::new(CommandGroup::default()));
+ }
+
+ if let Some(ref mut x) = self.groups.get_mut("Ungrouped") {
+ if let Some(ref mut y) = Arc::get_mut(x) {
+ y.commands.insert(command_name.into(), Arc::new(Command::new(f)));
+ }
+ }
self.initialized = true;
@@ -395,8 +511,28 @@ impl Framework {
where F: FnOnce(CreateCommand) -> CreateCommand,
S: Into<String> {
let cmd = f(CreateCommand(Command::default())).0;
- self.commands.insert(command_name.into(), Arc::new(cmd));
+ if !self.groups.contains_key("Ungrouped") {
+ self.groups.insert("Ungrouped".to_string(), Arc::new(CommandGroup::default()));
+ }
+
+ if let Some(ref mut x) = self.groups.get_mut("Ungrouped") {
+ if let Some(ref mut y) = Arc::get_mut(x) {
+ y.commands.insert(command_name.into(), Arc::new(cmd));
+ }
+ }
+
+ self.initialized = true;
+
+ self
+ }
+
+ pub fn group<F, S>(mut self, group_name: S, f: F) -> Self
+ where F: FnOnce(CreateGroup) -> CreateGroup,
+ S: Into<String> {
+ let group = f(CreateGroup(CommandGroup::default())).0;
+
+ self.groups.insert(group_name.into(), Arc::new(group));
self.initialized = true;
self
@@ -411,8 +547,9 @@ impl Framework {
}
/// Specify the function to be called after every command's execution.
+ /// Fourth argument exists if command returned an error which you can handle.
pub fn after<F>(mut self, f: F) -> Self
- where F: Fn(&Context, &Message, &String) + Send + Sync + 'static {
+ where F: Fn(&Context, &Message, &String, Result<(), String>) + Send + Sync + 'static {
self.after = Some(Arc::new(f));
self
@@ -426,7 +563,7 @@ impl Framework {
/// Ensure that the user who created a message, calling a "ping" command,
/// is the owner.
///
- /// ```rust,no_run
+ /// ```rust,ignore
/// use serenity::client::{Client, Context};
/// use serenity::model::Message;
/// use std::env;
@@ -438,9 +575,9 @@ impl Framework {
/// .on("ping", ping)
/// .set_check("ping", owner_check));
///
- /// fn ping(context: &Context, _message: &Message, _args: Vec<String>) {
- /// context.say("Pong!");
- /// }
+ /// command!(ping(context) {
+ /// let _ = context.say("Pong!");
+ /// });
///
/// fn owner_check(_context: &Context, message: &Message) -> bool {
/// // replace with your user ID
@@ -451,9 +588,17 @@ impl Framework {
pub fn set_check<F, S>(mut self, command: S, check: F) -> Self
where F: Fn(&Context, &Message) -> bool + Send + Sync + 'static,
S: Into<String> {
- if let Some(command) = self.commands.get_mut(&command.into()) {
- if let Some(c) = Arc::get_mut(command) {
- c.checks.push(Box::new(check));
+ if !self.groups.contains_key("Ungrouped") {
+ self.groups.insert("Ungrouped".to_string(), Arc::new(CommandGroup::default()));
+ }
+
+ if let Some(ref mut group) = self.groups.get_mut("Ungrouped") {
+ if let Some(group_mut) = Arc::get_mut(group) {
+ if let Some(ref mut command) = group_mut.commands.get_mut(&command.into()) {
+ if let Some(c) = Arc::get_mut(command) {
+ c.checks.push(Box::new(check));
+ }
+ }
}
}