diff options
| author | Lakelezz <[email protected]> | 2018-11-19 23:27:34 +0100 |
|---|---|---|
| committer | GitHub <[email protected]> | 2018-11-19 23:27:34 +0100 |
| commit | 1f71892c8714eff8b69908989f91acb32391acea (patch) | |
| tree | 1d00a1a16d68f322bb55b155e3eda90f94d2ccbf /src/http | |
| parent | Remove everything marked `deprecated` since `v0.5.x` or older (#441) (diff) | |
| download | serenity-1f71892c8714eff8b69908989f91acb32391acea.tar.xz serenity-1f71892c8714eff8b69908989f91acb32391acea.zip | |
Replace `hyper` with `reqwest` (#440)
Diffstat (limited to 'src/http')
| -rw-r--r-- | src/http/error.rs | 34 | ||||
| -rw-r--r-- | src/http/mod.rs | 31 | ||||
| -rw-r--r-- | src/http/ratelimiting.rs | 43 | ||||
| -rw-r--r-- | src/http/raw.rs | 129 | ||||
| -rw-r--r-- | src/http/request.rs | 32 | ||||
| -rw-r--r-- | src/http/routing.rs | 2 |
6 files changed, 150 insertions, 121 deletions
diff --git a/src/http/error.rs b/src/http/error.rs index 76a1d4f..b62ce15 100644 --- a/src/http/error.rs +++ b/src/http/error.rs @@ -1,4 +1,9 @@ -use hyper::client::Response; +use reqwest::{ + Error as ReqwestError, + header::InvalidHeaderValue, + Response, + UrlError +}; use std::{ error::Error as StdError, fmt::{ @@ -18,6 +23,30 @@ pub enum Error { /// When the decoding of a ratelimit header could not be properly decoded /// from UTF-8. RateLimitUtf8, + /// When parsing an URL failed due to invalid input. + Url(UrlError), + /// Header value contains invalid input. + InvalidHeader(InvalidHeaderValue), + /// Reqwest's Error contain information on why sending a request failed. + Request(ReqwestError), +} + +impl From<ReqwestError> for Error { + fn from(error: ReqwestError) -> Error { + Error::Request(error) + } +} + +impl From<UrlError> for Error { + fn from(error: UrlError) -> Error { + Error::Url(error) + } +} + +impl From<InvalidHeaderValue> for Error { + fn from(error: InvalidHeaderValue) -> Error { + Error::InvalidHeader(error) + } } impl Display for Error { @@ -30,6 +59,9 @@ impl StdError for Error { Error::UnsuccessfulRequest(_) => "A non-successful response status code was received", Error::RateLimitI64 => "Error decoding a header into an i64", Error::RateLimitUtf8 => "Error decoding a header from UTF-8", + Error::Url(_) => "Provided URL is incorrect.", + Error::InvalidHeader(_) => "Provided value is an invalid header value.", + Error::Request(_) => "Error while sending HTTP request.", } } } diff --git a/src/http/mod.rs b/src/http/mod.rs index e7e2bea..8093843 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -30,16 +30,14 @@ pub mod routing; mod error; -pub use hyper::status::{StatusClass, StatusCode}; +pub use reqwest::StatusCode; pub use self::error::Error as HttpError; pub use self::raw::*; -use hyper::{ - client::Client as HyperClient, - method::Method, - net::HttpsConnector, +use reqwest::{ + Client as ReqwestClient, + Method, }; -use hyper_native_tls::NativeTlsClient; use model::prelude::*; use parking_lot::Mutex; use self::{request::Request}; @@ -51,17 +49,12 @@ use std::{ }; lazy_static! { - static ref CLIENT: HyperClient = { - let tc = NativeTlsClient::new().expect("Unable to make http client"); - let connector = HttpsConnector::new(tc); - - HyperClient::with_connector(connector) - }; + static ref CLIENT: ReqwestClient = ReqwestClient::new(); } /// An method used for ratelimiting special routes. /// -/// This is needed because `hyper`'s `Method` enum does not derive Copy. +/// This is needed because `reqwest`'s `Method` enum does not derive Copy. #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum LightMethod { /// Indicates that a route is for the `DELETE` method only. @@ -77,13 +70,13 @@ pub enum LightMethod { } impl LightMethod { - pub fn hyper_method(&self) -> Method { + pub fn reqwest_method(&self) -> Method { match *self { - LightMethod::Delete => Method::Delete, - LightMethod::Get => Method::Get, - LightMethod::Patch => Method::Patch, - LightMethod::Post => Method::Post, - LightMethod::Put => Method::Put, + LightMethod::Delete => Method::DELETE, + LightMethod::Get => Method::GET, + LightMethod::Patch => Method::PATCH, + LightMethod::Post => Method::POST, + LightMethod::Put => Method::PUT, } } } diff --git a/src/http/ratelimiting.rs b/src/http/ratelimiting.rs index 7713b27..c2c2bba 100644 --- a/src/http/ratelimiting.rs +++ b/src/http/ratelimiting.rs @@ -43,9 +43,11 @@ pub use super::routing::Route; use chrono::{DateTime, Utc}; -use hyper::client::Response; -use hyper::header::Headers; -use hyper::status::StatusCode; +use reqwest::{ + Response, + header::HeaderMap as Headers, + StatusCode, +}; use internal::prelude::*; use parking_lot::Mutex; use std::{ @@ -54,7 +56,7 @@ use std::{ time::Duration, str, thread, - i64 + i64, }; use super::{HttpError, Request}; @@ -152,7 +154,7 @@ pub(super) fn perform(req: Request) -> Result<Response> { // This should probably only be a one-time check, although we may want // to choose to check this often in the future. if unsafe { OFFSET }.is_none() { - calculate_offset(response.headers.get_raw("date")); + calculate_offset(&response.headers().get("date").and_then(|d| Some(d.as_bytes()))); } // Check if the request got ratelimited by checking for status 429, @@ -171,11 +173,11 @@ pub(super) fn perform(req: Request) -> Result<Response> { if route == Route::None { return Ok(response); } else { - let redo = if response.headers.get_raw("x-ratelimit-global").is_some() { + let redo = if response.headers().get("x-ratelimit-global").is_some() { let _ = GLOBAL.lock(); Ok( - if let Some(retry_after) = parse_header(&response.headers, "retry-after")? { + if let Some(retry_after) = parse_header(&response.headers(), "retry-after")? { debug!("Ratelimited on route {:?} for {:?}ms", route, retry_after); thread::sleep(Duration::from_millis(retry_after as u64)); @@ -256,21 +258,21 @@ impl RateLimit { } pub(crate) fn post_hook(&mut self, response: &Response, route: &Route) -> Result<bool> { - if let Some(limit) = parse_header(&response.headers, "x-ratelimit-limit")? { + if let Some(limit) = parse_header(&response.headers(), "x-ratelimit-limit")? { self.limit = limit; } - if let Some(remaining) = parse_header(&response.headers, "x-ratelimit-remaining")? { + if let Some(remaining) = parse_header(&response.headers(), "x-ratelimit-remaining")? { self.remaining = remaining; } - if let Some(reset) = parse_header(&response.headers, "x-ratelimit-reset")? { + if let Some(reset) = parse_header(&response.headers(), "x-ratelimit-reset")? { self.reset = reset; } - Ok(if response.status != StatusCode::TooManyRequests { + Ok(if response.status() != StatusCode::TOO_MANY_REQUESTS { false - } else if let Some(retry_after) = parse_header(&response.headers, "retry-after")? { + } else if let Some(retry_after) = parse_header(&response.headers(), "retry-after")? { debug!("Ratelimited on route {:?} for {:?}ms", route, retry_after); thread::sleep(Duration::from_millis(retry_after as u64)); @@ -302,20 +304,18 @@ pub fn offset() -> Option<i64> { unsafe { OFFSET } } -fn calculate_offset(header: Option<&[Vec<u8>]>) { +fn calculate_offset(header: &Option<&[u8]>) { // Get the current time as soon as possible. let now = Utc::now().timestamp(); - // First get the `Date` header's value and parse it as UTF8. - let header = header - .and_then(|h| h.get(0)) - .and_then(|x| str::from_utf8(x).ok()); + let header = header.and_then(|x| str::from_utf8(x).ok()); - if let Some(date) = header { + if let Some(header) = header { // Replace the `GMT` timezone with an offset, and then parse it // into a chrono DateTime. If it parses correctly, calculate the // diff and then set it as the offset. - let s = date.replace("GMT", "+0000"); + let s = header.replace("GMT", "+0000"); + let parsed = DateTime::parse_from_str(&s, "%a, %d %b %Y %T %z"); if let Ok(parsed) = parsed { @@ -330,11 +330,12 @@ fn calculate_offset(header: Option<&[Vec<u8>]>) { } } } + } fn parse_header(headers: &Headers, header: &str) -> Result<Option<i64>> { - headers.get_raw(header).map_or(Ok(None), |header| { - str::from_utf8(&header[0]) + headers.get(header).map_or(Ok(None), |header| { + str::from_utf8(&header.as_bytes()) .map_err(|_| Error::Http(HttpError::RateLimitUtf8)) .and_then(|v| { v.parse::<i64>() diff --git a/src/http/raw.rs b/src/http/raw.rs index b4a9cc9..742ca5c 100644 --- a/src/http/raw.rs +++ b/src/http/raw.rs @@ -1,22 +1,14 @@ use constants; -use hyper::{ - client::{ - Request as HyperRequest, - Response as HyperResponse - }, - header::{ContentType, Headers}, - method::Method, - mime::{Mime, SubLevel, TopLevel}, - net::HttpsConnector, - header, - Error as HyperError, - Result as HyperResult, - Url +use reqwest::{ + Client as ReqwestClient, + header::{AUTHORIZATION, USER_AGENT, CONTENT_TYPE, HeaderValue, HeaderMap as Headers}, + multipart::Part, + Response as ReqwestResponse, + StatusCode, + Url, }; -use hyper_native_tls::NativeTlsClient; use internal::prelude::*; use model::prelude::*; -use multipart::client::Multipart; use super::{ TOKEN, ratelimiting, @@ -25,8 +17,6 @@ use super::{ AttachmentType, GuildPagination, HttpError, - StatusClass, - StatusCode, }; use serde::de::DeserializeOwned; use serde_json; @@ -691,7 +681,7 @@ pub fn edit_profile(map: &JsonMap) -> Result<CurrentUser> { route: RouteInfo::EditProfile, })?; - let mut value = serde_json::from_reader::<HyperResponse, Value>(response)?; + let mut value = serde_json::from_reader::<ReqwestResponse, Value>(response)?; if let Some(map) = value.as_object_mut() { if !TOKEN.lock().starts_with("Bot ") { @@ -880,9 +870,7 @@ pub fn execute_webhook(webhook_id: u64, let body = serde_json::to_vec(map)?; let mut headers = Headers::new(); - headers.set(ContentType( - Mime(TopLevel::Application, SubLevel::Json, vec![]), - )); + headers.insert(CONTENT_TYPE, HeaderValue::from_static(&"application/json")); let response = request(Request { body: Some(&body), @@ -890,11 +878,11 @@ pub fn execute_webhook(webhook_id: u64, route: RouteInfo::ExecuteWebhook { token, wait, webhook_id }, })?; - if response.status == StatusCode::NoContent { + if response.status() == StatusCode::NO_CONTENT { return Ok(None); } - serde_json::from_reader::<HyperResponse, Message>(response) + serde_json::from_reader::<ReqwestResponse, Message>(response) .map(Some) .map_err(From::from) } @@ -1086,7 +1074,7 @@ pub fn get_guild_vanity_url(guild_id: u64) -> Result<String> { route: RouteInfo::GetGuildVanityUrl { guild_id }, })?; - serde_json::from_reader::<HyperResponse, GuildVanityUrl>(response) + serde_json::from_reader::<ReqwestResponse, GuildVanityUrl>(response) .map(|x| x.code) .map_err(From::from) } @@ -1103,7 +1091,7 @@ pub fn get_guild_members(guild_id: u64, route: RouteInfo::GetGuildMembers { after, guild_id, limit }, })?; - let mut v = serde_json::from_reader::<HyperResponse, Value>(response)?; + let mut v = serde_json::from_reader::<ReqwestResponse, Value>(response)?; if let Some(values) = v.as_array_mut() { let num = Value::Number(Number::from(guild_id)); @@ -1241,7 +1229,7 @@ pub fn get_member(guild_id: u64, user_id: u64) -> Result<Member> { route: RouteInfo::GetMember { guild_id, user_id }, })?; - let mut v = serde_json::from_reader::<HyperResponse, Value>(response)?; + let mut v = serde_json::from_reader::<ReqwestResponse, Value>(response)?; if let Some(map) = v.as_object_mut() { map.insert("guild_id".to_string(), Value::Number(Number::from(guild_id))); @@ -1470,31 +1458,31 @@ pub fn send_files<'a, T, It: IntoIterator<Item=T>>(channel_id: u64, files: It, m Err(_) => return Err(Error::Url(uri)), }; - let tc = NativeTlsClient::new()?; - let connector = HttpsConnector::new(tc); - let mut request = HyperRequest::with_connector(Method::Post, url, &connector)?; - request - .headers_mut() - .set(header::Authorization(TOKEN.lock().clone())); - request - .headers_mut() - .set(header::UserAgent(constants::USER_AGENT.to_string())); - - let mut request = Multipart::from_request(request)?; + let client = ReqwestClient::new() + .post(url) + .header(AUTHORIZATION, HeaderValue::from_str(&TOKEN.lock())?) + .header(USER_AGENT, HeaderValue::from_static(&constants::USER_AGENT)); + + let mut multipart = reqwest::multipart::Form::new(); let mut file_num = "0".to_string(); for file in files { + match file.into() { - AttachmentType::Bytes((mut bytes, filename)) => { - request - .write_stream(&file_num, &mut bytes, Some(filename), None)?; + AttachmentType::Bytes((bytes, filename)) => { + multipart = multipart + .part(file_num.to_string(), Part::bytes(bytes.to_vec()) + .file_name(filename.to_string())); }, - AttachmentType::File((mut f, filename)) => { - request - .write_stream(&file_num, &mut f, Some(filename), None)?; + AttachmentType::File((file, filename)) => { + multipart = multipart + .part(file_num.to_string(), + Part::reader(file.try_clone()?) + .file_name(filename.to_string())); }, - AttachmentType::Path(p) => { - request.write_file(&file_num, &p)?; + AttachmentType::Path(path) => { + multipart = multipart + .file(file_num.to_string(), path)?; }, } @@ -1506,19 +1494,19 @@ pub fn send_files<'a, T, It: IntoIterator<Item=T>>(channel_id: u64, files: It, m for (k, v) in map { match v { - Value::Bool(false) => request.write_text(&k, "false")?, - Value::Bool(true) => request.write_text(&k, "true")?, - Value::Number(inner) => request.write_text(&k, inner.to_string())?, - Value::String(inner) => request.write_text(&k, inner)?, - Value::Object(inner) => request.write_text(&k, serde_json::to_string(&inner)?)?, + Value::Bool(false) => multipart = multipart.text(k.clone(), "false"), + Value::Bool(true) => multipart = multipart.text(k.clone(), "true"), + Value::Number(inner) => multipart = multipart.text(k.clone(), inner.to_string()), + Value::String(inner) => multipart = multipart.text(k.clone(), inner), + Value::Object(inner) =>multipart = multipart.text(k.clone(), serde_json::to_string(&inner)?), _ => continue, }; } - let response = request.send()?; + let response = client.multipart(multipart).send()?; - if response.status.class() != StatusClass::Success { - return Err(Error::Http(HttpError::UnsuccessfulRequest(response))); + if !response.status().is_success() { + return Err(HttpError::UnsuccessfulRequest(response).into()); } serde_json::from_reader(response).map_err(From::from) @@ -1665,7 +1653,7 @@ pub fn fire<T: DeserializeOwned>(req: Request) -> Result<T> { /// Performs a request, ratelimiting it if necessary. /// -/// Returns the raw hyper Response. Use [`fire`] to deserialize the response +/// Returns the raw reqwest Response. Use [`fire`] to deserialize the response /// into some type. /// /// # Examples @@ -1696,7 +1684,7 @@ pub fn fire<T: DeserializeOwned>(req: Request) -> Result<T> { /// /// let response = http::request(request.build())?; /// -/// println!("Response successful?: {}", response.status.is_success()); +/// println!("Response successful?: {}", response.status().is_success()); /// # /// # Ok(()) /// # } @@ -1707,29 +1695,40 @@ pub fn fire<T: DeserializeOwned>(req: Request) -> Result<T> { /// ``` /// /// [`fire`]: fn.fire.html -pub fn request(req: Request) -> Result<HyperResponse> { +pub fn request(req: Request) -> Result<ReqwestResponse> { let response = ratelimiting::perform(req)?; - if response.status.class() == StatusClass::Success { + if response.status().is_success() { Ok(response) } else { Err(Error::Http(HttpError::UnsuccessfulRequest(response))) } } -pub(super) fn retry(request: &Request) -> HyperResult<HyperResponse> { +pub(super) fn retry(request: &Request) -> Result<ReqwestResponse> { // Retry the request twice in a loop until it succeeds. // // If it doesn't and the loop breaks, try one last time. for _ in 0..3 { - match request.build().send() { - Err(HyperError::Io(ref io)) - if io.kind() == IoErrorKind::ConnectionAborted => continue, - other => return other, + + match request.build()?.send() { + Ok(response) => return Ok(response), + Err(reqwest_error) => { + if let Some(io_error) = reqwest_error.get_ref().and_then(|e| e.downcast_ref::<std::io::Error>()) { + + if let IoErrorKind::ConnectionAborted = io_error.kind() { + continue; + } + } + + return Err(reqwest_error.into()); + }, } } - request.build().send() + request.build() + .map_err(Into::into) + .and_then(|b| Ok(b.send()?)) } /// Performs a request and then verifies that the response status code is equal @@ -1740,11 +1739,11 @@ pub(super) fn retry(request: &Request) -> HyperResult<HyperResponse> { pub(super) fn wind(expected: u16, req: Request) -> Result<()> { let resp = request(req)?; - if resp.status.to_u16() == expected { + if resp.status().as_u16() == expected { return Ok(()); } - debug!("Expected {}, got {}", expected, resp.status); + debug!("Expected {}, got {}", expected, resp.status()); trace!("Unsuccessful response: {:?}", resp); Err(Error::Http(HttpError::UnsuccessfulRequest(resp))) diff --git a/src/http/request.rs b/src/http/request.rs index 92dd073..cbb8f79 100644 --- a/src/http/request.rs +++ b/src/http/request.rs @@ -1,11 +1,13 @@ use constants; -use hyper::{ - client::{Body, RequestBuilder as HyperRequestBuilder}, - header::{Authorization, ContentType, Headers, UserAgent}, +use reqwest::{ + RequestBuilder as ReqwestRequestBuilder, + header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT, HeaderMap as Headers, HeaderValue}, + Url, }; use super::{ CLIENT, TOKEN, + HttpError, routing::RouteInfo, }; @@ -61,33 +63,35 @@ impl<'a> Request<'a> { Self { body, headers, route } } - pub fn build(&'a self) -> HyperRequestBuilder<'a> { + pub fn build(&'a self) -> Result<ReqwestRequestBuilder, HttpError> { let Request { body, headers: ref request_headers, route: ref route_info, } = *self; + let (method, _, path) = route_info.deconstruct(); let mut builder = CLIENT.request( - method.hyper_method(), - &path.into_owned(), + method.reqwest_method(), + Url::parse(&path)?, ); if let Some(ref bytes) = body { - builder = builder.body(Body::BufBody(bytes, bytes.len())); + builder = builder.body(Vec::from(*bytes)); } - let mut headers = Headers::new(); - headers.set(UserAgent(constants::USER_AGENT.to_string())); - headers.set(Authorization(TOKEN.lock().clone())); - headers.set(ContentType::json()); + let mut headers = Headers::with_capacity(3); + headers.insert(USER_AGENT, HeaderValue::from_static(&constants::USER_AGENT)); + headers.insert(AUTHORIZATION, + HeaderValue::from_str(&TOKEN.lock()).map_err(|e| HttpError::InvalidHeader(e))?); + headers.insert(CONTENT_TYPE, HeaderValue::from_static(&"application/json")); - if let Some(request_headers) = request_headers.clone() { - headers.extend(request_headers.iter()); + if let Some(ref request_headers) = request_headers { + headers.extend(request_headers.clone()); } - builder.headers(headers) + Ok(builder.headers(headers)) } pub fn body_ref(&self) -> &Option<&'a [u8]> { diff --git a/src/http/routing.rs b/src/http/routing.rs index d1d4bc6..e9ebaba 100644 --- a/src/http/routing.rs +++ b/src/http/routing.rs @@ -46,7 +46,7 @@ pub enum Route { // Refer to the docs on [Rate Limits] in the yellow warning section. // // Additionally, this needs to be a `LightMethod` from the parent module - // and _not_ a `hyper` `Method` due to `hyper`'s not deriving `Copy`. + // and _not_ a `reqwest` `Method` due to `reqwest`'s not deriving `Copy`. // // [Rate Limits]: https://discordapp.com/developers/docs/topics/rate-limits ChannelsIdMessagesId(LightMethod, u64), |