diff options
| author | Steven Fackler <[email protected]> | 2016-08-02 20:52:07 -0700 |
|---|---|---|
| committer | Steven Fackler <[email protected]> | 2016-08-02 20:52:07 -0700 |
| commit | c5b2ede2829869915852b30bf9600bc3cb1fbdc9 (patch) | |
| tree | db723196c8a5d63fa569721de26be59ec559f8d9 /openssl/src/ssl/mod.rs | |
| parent | Merge pull request #433 from tmiasko/binop-different-lifetimes (diff) | |
| parent | Restructure PEM input/output methods (diff) | |
| download | rust-openssl-c5b2ede2829869915852b30bf9600bc3cb1fbdc9.tar.xz rust-openssl-c5b2ede2829869915852b30bf9600bc3cb1fbdc9.zip | |
Merge remote-tracking branch 'origin/breaks'
Diffstat (limited to 'openssl/src/ssl/mod.rs')
| -rw-r--r-- | openssl/src/ssl/mod.rs | 796 |
1 files changed, 217 insertions, 579 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 0252b114..3d1ec6e5 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -5,9 +5,9 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; +use std::error as stderror; use std::mem; use std::str; -use std::net; use std::path::Path; use std::ptr; use std::sync::{Once, ONCE_INIT, Mutex, Arc}; @@ -18,17 +18,13 @@ use libc::{c_uchar, c_uint}; #[cfg(any(feature = "npn", feature = "alpn"))] use std::slice; use std::marker::PhantomData; -#[cfg(unix)] -use std::os::unix::io::{AsRawFd, RawFd}; -#[cfg(windows)] -use std::os::windows::io::{AsRawSocket, RawSocket}; use ffi; use ffi_extras; use dh::DH; -use ssl::error::{NonblockingSslError, SslError, OpenSslError, OpensslError}; use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; +use error::ErrorStack; pub mod error; mod bio; @@ -45,27 +41,11 @@ extern "C" { fn rust_SSL_CTX_clone(cxt: *mut ffi::SSL_CTX); } -static mut VERIFY_IDX: c_int = -1; -static mut SNI_IDX: c_int = -1; - /// Manually initialize SSL. /// It is optional to call this function and safe to do so more than once. pub fn init() { static mut INIT: Once = ONCE_INIT; - - unsafe { - INIT.call_once(|| { - ffi::init(); - - let verify_idx = ffi::SSL_CTX_get_ex_new_index(0, ptr::null(), None, None, None); - assert!(verify_idx >= 0); - VERIFY_IDX = verify_idx; - - let sni_idx = ffi::SSL_CTX_get_ex_new_index(0, ptr::null(), None, None, None); - assert!(sni_idx >= 0); - SNI_IDX = sni_idx; - }); - } + unsafe { INIT.call_once(|| ffi::init()); } } bitflags! { @@ -186,30 +166,6 @@ impl SslMethod { _ => None, } } - - #[cfg(feature = "dtlsv1")] - pub fn is_dtlsv1(&self) -> bool { - *self == SslMethod::Dtlsv1 - } - - #[cfg(feature = "dtlsv1_2")] - pub fn is_dtlsv1_2(&self) -> bool { - *self == SslMethod::Dtlsv1_2 - } - - pub fn is_dtls(&self) -> bool { - self.is_dtlsv1() || self.is_dtlsv1_2() - } - - #[cfg(not(feature = "dtlsv1"))] - pub fn is_dtlsv1(&self) -> bool { - false - } - - #[cfg(not(feature = "dtlsv1_2"))] - pub fn is_dtlsv1_2(&self) -> bool { - false - } } /// Determines the type of certificate verification used @@ -292,47 +248,19 @@ fn get_new_ssl_idx<T>() -> c_int { } } -extern "C" fn raw_verify(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_CTX) -> c_int { - unsafe { - let idx = ffi::SSL_get_ex_data_X509_STORE_CTX_idx(); - let ssl = ffi::X509_STORE_CTX_get_ex_data(x509_ctx, idx); - let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); - let verify = ffi::SSL_CTX_get_ex_data(ssl_ctx, VERIFY_IDX); - let verify: Option<VerifyCallback> = mem::transmute(verify); - - let ctx = X509StoreContext::new(x509_ctx); - - match verify { - None => preverify_ok, - Some(verify) => verify(preverify_ok != 0, &ctx) as c_int, - } - } -} - -extern "C" fn raw_verify_with_data<T>(preverify_ok: c_int, - x509_ctx: *mut ffi::X509_STORE_CTX) - -> c_int - where T: Any + 'static +extern "C" fn raw_verify<F>(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_CTX) -> c_int + where F: Fn(bool, &X509StoreContext) -> bool + Any + 'static + Sync + Send { unsafe { let idx = ffi::SSL_get_ex_data_X509_STORE_CTX_idx(); let ssl = ffi::X509_STORE_CTX_get_ex_data(x509_ctx, idx); let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); - - let verify = ffi::SSL_CTX_get_ex_data(ssl_ctx, VERIFY_IDX); - let verify: Option<VerifyCallbackData<T>> = mem::transmute(verify); - - let data = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_verify_data_idx::<T>()); - let data: &T = mem::transmute(data); + let verify = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_verify_data_idx::<F>()); + let verify: &F = mem::transmute(verify); let ctx = X509StoreContext::new(x509_ctx); - let res = match verify { - None => preverify_ok, - Some(verify) => verify(preverify_ok != 0, &ctx, data) as c_int, - }; - - res + verify(preverify_ok != 0, &ctx) as c_int } } @@ -351,50 +279,31 @@ extern "C" fn ssl_raw_verify<F>(preverify_ok: c_int, x509_ctx: *mut ffi::X509_ST } } -extern "C" fn raw_sni(ssl: *mut ffi::SSL, ad: &mut c_int, _arg: *mut c_void) -> c_int { - unsafe { - let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); - let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, SNI_IDX); - let callback: Option<ServerNameCallback> = mem::transmute(callback); - rust_SSL_clone(ssl); - let mut s = Ssl { ssl: ssl }; - - let res = match callback { - None => ffi::SSL_TLSEXT_ERR_ALERT_FATAL, - Some(callback) => callback(&mut s, ad), - }; - - res - } -} - -extern "C" fn raw_sni_with_data<T>(ssl: *mut ffi::SSL, ad: &mut c_int, arg: *mut c_void) -> c_int - where T: Any + 'static +extern "C" fn raw_sni<F>(ssl: *mut ffi::SSL, al: *mut c_int, _arg: *mut c_void) -> c_int + where F: Fn(&mut Ssl) -> Result<(), SniError> + Any + 'static + Sync + Send { unsafe { let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); - - let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, SNI_IDX); - let callback: Option<ServerNameCallbackData<T>> = mem::transmute(callback); + let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_verify_data_idx::<F>()); + let callback: &F = mem::transmute(callback); rust_SSL_clone(ssl); - let mut s = Ssl { ssl: ssl }; - - let data: &T = mem::transmute(arg); - - let res = match callback { - None => ffi::SSL_TLSEXT_ERR_ALERT_FATAL, - Some(callback) => callback(&mut s, ad, &*data), - }; + let mut ssl = Ssl { ssl: ssl }; - // Since data might be required on the next verification - // it is time to forget about it and avoid dropping - // data will be freed once OpenSSL considers it is time - // to free all context data - res + match callback(&mut ssl) { + Ok(()) => ffi::SSL_TLSEXT_ERR_OK, + Err(SniError::Fatal(e)) => { + *al = e; + ffi::SSL_TLSEXT_ERR_ALERT_FATAL + } + Err(SniError::Warning(e)) => { + *al = e; + ffi::SSL_TLSEXT_ERR_ALERT_WARNING + } + Err(SniError::NoAck) => ffi::SSL_TLSEXT_ERR_NOACK, + } } } - #[cfg(any(feature = "npn", feature = "alpn"))] unsafe fn select_proto_using(ssl: *mut ffi::SSL, out: *mut *mut c_uchar, @@ -499,24 +408,18 @@ fn ssl_encode_byte_strings(strings: &[&[u8]]) -> Vec<u8> { enc } -/// The signature of functions that can be used to manually verify certificates -pub type VerifyCallback = fn(preverify_ok: bool, x509_ctx: &X509StoreContext) -> bool; - -/// The signature of functions that can be used to manually verify certificates -/// when user-data should be carried for all verification process -pub type VerifyCallbackData<T> = fn(preverify_ok: bool, x509_ctx: &X509StoreContext, data: &T) - -> bool; - -/// The signature of functions that can be used to choose the context depending on the server name -pub type ServerNameCallback = fn(ssl: &mut Ssl, ad: &mut i32) -> i32; - -pub type ServerNameCallbackData<T> = fn(ssl: &mut Ssl, ad: &mut i32, data: &T) -> i32; +/// An error returned from an SNI callback. +pub enum SniError { + Fatal(c_int), + Warning(c_int), + NoAck, +} // FIXME: macro may be instead of inlining? #[inline] -fn wrap_ssl_result(res: c_int) -> Result<(), SslError> { +fn wrap_ssl_result(res: c_int) -> Result<(), ErrorStack> { if res == 0 { - Err(SslError::get()) + Err(ErrorStack::get()) } else { Ok(()) } @@ -559,80 +462,56 @@ impl SslContext { } /// Creates a new SSL context. - pub fn new(method: SslMethod) -> Result<SslContext, SslError> { + pub fn new(method: SslMethod) -> Result<SslContext, ErrorStack> { init(); let ctx = try_ssl_null!(unsafe { ffi::SSL_CTX_new(method.to_raw()) }); - let ctx = SslContext { ctx: ctx }; + let mut ctx = SslContext { ctx: ctx }; + match method { + #[cfg(feature = "dtlsv1")] + SslMethod::Dtlsv1 => ctx.set_read_ahead(1), + #[cfg(feature = "dtlsv1_2")] + SslMethod::Dtlsv1_2 => ctx.set_read_ahead(1), + _ => {} + } // this is a bit dubious (?) try!(ctx.set_mode(ffi::SSL_MODE_AUTO_RETRY | ffi::SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER)); - if method.is_dtls() { - ctx.set_read_ahead(1); - } - Ok(ctx) } /// Configures the certificate verification method for new connections. - pub fn set_verify(&mut self, mode: SslVerifyMode, verify: Option<VerifyCallback>) { + pub fn set_verify(&mut self, mode: SslVerifyMode) { unsafe { - ffi::SSL_CTX_set_ex_data(self.ctx, VERIFY_IDX, mem::transmute(verify)); - let f: extern "C" fn(c_int, *mut ffi::X509_STORE_CTX) -> c_int = raw_verify; - - ffi::SSL_CTX_set_verify(self.ctx, mode.bits as c_int, Some(f)); + ffi::SSL_CTX_set_verify(self.ctx, mode.bits as c_int, None); } } - /// Configures the certificate verification method for new connections also - /// carrying supplied data. - // Note: no option because there is no point to set data without providing - // a function handling it - pub fn set_verify_with_data<T>(&mut self, - mode: SslVerifyMode, - verify: VerifyCallbackData<T>, - data: T) - where T: Any + 'static + /// Configures the certificate verification method for new connections and + /// registers a verification callback. + pub fn set_verify_callback<F>(&mut self, mode: SslVerifyMode, verify: F) + where F: Fn(bool, &X509StoreContext) -> bool + Any + 'static + Sync + Send { - let data = Box::new(data); unsafe { - ffi::SSL_CTX_set_ex_data(self.ctx, VERIFY_IDX, mem::transmute(Some(verify))); - ffi::SSL_CTX_set_ex_data(self.ctx, get_verify_data_idx::<T>(), mem::transmute(data)); - let f: extern "C" fn(c_int, *mut ffi::X509_STORE_CTX) -> c_int = - raw_verify_with_data::<T>; - - ffi::SSL_CTX_set_verify(self.ctx, mode.bits as c_int, Some(f)); + let verify = Box::new(verify); + ffi::SSL_CTX_set_ex_data(self.ctx, get_verify_data_idx::<F>(), mem::transmute(verify)); + ffi::SSL_CTX_set_verify(self.ctx, mode.bits as c_int, Some(raw_verify::<F>)); } } /// Configures the server name indication (SNI) callback for new connections /// - /// obtain the server name with `get_servername` then set the corresponding context + /// Obtain the server name with `servername` then set the corresponding context /// with `set_ssl_context` - pub fn set_servername_callback(&mut self, callback: Option<ServerNameCallback>) { - unsafe { - ffi::SSL_CTX_set_ex_data(self.ctx, SNI_IDX, mem::transmute(callback)); - let f: extern "C" fn(_, _, _) -> _ = raw_sni; - let f: extern "C" fn() = mem::transmute(f); - ffi_extras::SSL_CTX_set_tlsext_servername_callback(self.ctx, Some(f)); - } - } - - /// Configures the server name indication (SNI) callback for new connections - /// carrying supplied data - pub fn set_servername_callback_with_data<T>(&mut self, - callback: ServerNameCallbackData<T>, - data: T) - where T: Any + 'static + pub fn set_servername_callback<F>(&mut self, callback: F) + where F: Fn(&mut Ssl) -> Result<(), SniError> + Any + 'static + Sync + Send { - let data = Box::new(data); unsafe { - ffi::SSL_CTX_set_ex_data(self.ctx, SNI_IDX, mem::transmute(Some(callback))); - - ffi_extras::SSL_CTX_set_tlsext_servername_arg(self.ctx, mem::transmute(data)); - let f: extern "C" fn(_, _, _) -> _ = raw_sni_with_data::<T>; + let callback = Box::new(callback); + ffi::SSL_CTX_set_ex_data(self.ctx, get_verify_data_idx::<F>(), mem::transmute(callback)); + let f: extern "C" fn(_, _, _) -> _ = raw_sni::<F>; let f: extern "C" fn() = mem::transmute(f); ffi_extras::SSL_CTX_set_tlsext_servername_callback(self.ctx, Some(f)); } @@ -645,18 +524,18 @@ impl SslContext { } } - pub fn set_read_ahead(&self, m: u32) { + pub fn set_read_ahead(&mut self, m: u32) { unsafe { ffi_extras::SSL_CTX_set_read_ahead(self.ctx, m as c_long); } } - fn set_mode(&self, mode: c_long) -> Result<(), SslError> { + fn set_mode(&mut self, mode: c_long) -> Result<(), ErrorStack> { wrap_ssl_result(unsafe { ffi_extras::SSL_CTX_set_mode(self.ctx, mode) as c_int }) } - pub fn set_tmp_dh(&self, dh: DH) -> Result<(), SslError> { - wrap_ssl_result(unsafe { ffi_extras::SSL_CTX_set_tmp_dh(self.ctx, dh.raw()) as c_int }) + pub fn set_tmp_dh(&mut self, dh: DH) -> Result<(), ErrorStack> { + wrap_ssl_result(unsafe { ffi_extras::SSL_CTX_set_tmp_dh(self.ctx, dh.raw()) as i32 }) } /// Use the default locations of trusted certificates for verification. @@ -664,13 +543,13 @@ impl SslContext { /// These locations are read from the `SSL_CERT_FILE` and `SSL_CERT_DIR` /// environment variables if present, or defaults specified at OpenSSL /// build time otherwise. - pub fn set_default_verify_paths(&mut self) -> Result<(), SslError> { + pub fn set_default_verify_paths(&mut self) -> Result<(), ErrorStack> { wrap_ssl_result(unsafe { ffi::SSL_CTX_set_default_verify_paths(self.ctx) }) } #[allow(non_snake_case)] /// Specifies the file that contains trusted CA certificates. - pub fn set_CA_file<P: AsRef<Path>>(&mut self, file: P) -> Result<(), SslError> { + pub fn set_CA_file<P: AsRef<Path>>(&mut self, file: P) -> Result<(), ErrorStack> { let file = CString::new(file.as_ref().as_os_str().to_str().expect("invalid utf8")).unwrap(); wrap_ssl_result(unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, file.as_ptr() as *const _, ptr::null()) @@ -685,7 +564,7 @@ impl SslContext { /// /// This value should be set when using client certificates, or each request will fail /// handshake and need to be restarted. - pub fn set_session_id_context(&mut self, sid_ctx: &[u8]) -> Result<(), SslError> { + pub fn set_session_id_context(&mut self, sid_ctx: &[u8]) -> Result<(), ErrorStack> { wrap_ssl_result(unsafe { ffi::SSL_CTX_set_session_id_context(self.ctx, sid_ctx.as_ptr(), sid_ctx.len() as u32) }) @@ -695,7 +574,7 @@ impl SslContext { pub fn set_certificate_file<P: AsRef<Path>>(&mut self, file: P, file_type: X509FileType) - -> Result<(), SslError> { + -> Result<(), ErrorStack> { let file = CString::new(file.as_ref().as_os_str().to_str().expect("invalid utf8")).unwrap(); wrap_ssl_result(unsafe { ffi::SSL_CTX_use_certificate_file(self.ctx, @@ -708,7 +587,7 @@ impl SslContext { pub fn set_certificate_chain_file<P: AsRef<Path>>(&mut self, file: P, file_type: X509FileType) - -> Result<(), SslError> { + -> Result<(), ErrorStack> { let file = CString::new(file.as_ref().as_os_str().to_str().expect("invalid utf8")).unwrap(); wrap_ssl_result(unsafe { ffi::SSL_CTX_use_certificate_chain_file(self.ctx, @@ -718,13 +597,13 @@ impl SslContext { } /// Specifies the certificate - pub fn set_certificate(&mut self, cert: &X509) -> Result<(), SslError> { + pub fn set_certificate(&mut self, cert: &X509) -> Result<(), ErrorStack> { wrap_ssl_result(unsafe { ffi::SSL_CTX_use_certificate(self.ctx, cert.get_handle()) }) } /// Adds a certificate to the certificate chain presented together with the /// certificate specified using set_certificate() - pub fn add_extra_chain_cert(&mut self, cert: &X509) -> Result<(), SslError> { + pub fn add_extra_chain_cert(&mut self, cert: &X509) -> Result<(), ErrorStack> { wrap_ssl_result(unsafe { ffi_extras::SSL_CTX_add_extra_chain_cert(self.ctx, cert.get_handle()) as c_int }) @@ -734,7 +613,7 @@ impl SslContext { pub fn set_private_key_file<P: AsRef<Path>>(&mut self, file: P, file_type: X509FileType) - -> Result<(), SslError> { + -> Result<(), ErrorStack> { let file = CString::new(file.as_ref().as_os_str().to_str().expect("invalid utf8")).unwrap(); wrap_ssl_result(unsafe { ffi::SSL_CTX_use_PrivateKey_file(self.ctx, @@ -744,16 +623,16 @@ impl SslContext { } /// Specifies the private key - pub fn set_private_key(&mut self, key: &PKey) -> Result<(), SslError> { + pub fn set_private_key(&mut self, key: &PKey) -> Result<(), ErrorStack> { wrap_ssl_result(unsafe { ffi::SSL_CTX_use_PrivateKey(self.ctx, key.get_handle()) }) } /// Check consistency of private key and certificate - pub fn check_private_key(&mut self) -> Result<(), SslError> { + pub fn check_private_key(&mut self) -> Result<(), ErrorStack> { wrap_ssl_result(unsafe { ffi::SSL_CTX_check_private_key(self.ctx) }) } - pub fn set_cipher_list(&mut self, cipher_list: &str) -> Result<(), SslError> { + pub fn set_cipher_list(&mut self, cipher_list: &str) -> Result<(), ErrorStack> { wrap_ssl_result(unsafe { let cipher_list = CString::new(cipher_list).unwrap(); ffi::SSL_CTX_set_cipher_list(self.ctx, cipher_list.as_ptr() as *const _) @@ -765,7 +644,7 @@ impl SslContext { /// /// This method requires OpenSSL >= 1.0.2 or LibreSSL and the `ecdh_auto` feature. #[cfg(feature = "ecdh_auto")] - pub fn set_ecdh_auto(&mut self, onoff: bool) -> Result<(), SslError> { + pub fn set_ecdh_auto(&mut self, onoff: bool) -> Result<(), ErrorStack> { wrap_ssl_result(unsafe { ffi_extras::SSL_CTX_set_ecdh_auto(self.ctx, onoff as c_int) }) } @@ -775,7 +654,7 @@ impl SslContext { SslContextOptions::from_bits(ret).unwrap() } - pub fn get_options(&mut self) -> SslContextOptions { + pub fn options(&self) -> SslContextOptions { let ret = unsafe { ffi_extras::SSL_CTX_get_options(self.ctx) }; SslContextOptions::from_bits(ret).unwrap() } @@ -935,17 +814,8 @@ impl Drop for Ssl { } } -impl Clone for Ssl { - /// # Deprecated - fn clone(&self) -> Ssl { - unsafe { rust_SSL_clone(self.ssl) }; - Ssl { ssl: self.ssl } - - } -} - impl Ssl { - pub fn new(ctx: &SslContext) -> Result<Ssl, SslError> { + pub fn new(ctx: &SslContext) -> Result<Ssl, ErrorStack> { let ssl = try_ssl_null!(unsafe { ffi::SSL_new(ctx.ctx) }); let ssl = Ssl { ssl: ssl }; Ok(ssl) @@ -963,6 +833,10 @@ impl Ssl { unsafe { ffi::SSL_accept(self.ssl) } } + fn handshake(&self) -> c_int { + unsafe { ffi::SSL_do_handshake(self.ssl) } + } + fn read(&self, buf: &mut [u8]) -> c_int { let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; unsafe { ffi::SSL_read(self.ssl, buf.as_ptr() as *mut c_void, len) } @@ -973,12 +847,8 @@ impl Ssl { unsafe { ffi::SSL_write(self.ssl, buf.as_ptr() as *const c_void, len) } } - fn get_error(&self, ret: c_int) -> LibSslError { - let err = unsafe { ffi::SSL_get_error(self.ssl, ret) }; - match LibSslError::from_i32(err as i32) { - Some(err) => err, - None => unreachable!(), - } + fn get_error(&self, ret: c_int) -> c_int { + unsafe { ffi::SSL_get_error(self.ssl, ret) } } /// Sets the verification mode to be used during the handshake process. @@ -1007,7 +877,7 @@ impl Ssl { } } - pub fn get_current_cipher<'a>(&'a self) -> Option<SslCipher<'a>> { + pub fn current_cipher<'a>(&'a self) -> Option<SslCipher<'a>> { unsafe { let ptr = ffi::SSL_get_current_cipher(self.ssl); @@ -1041,7 +911,7 @@ impl Ssl { } /// Sets the host name to be used with SNI (Server Name Indication). - pub fn set_hostname(&self, hostname: &str) -> Result<(), SslError> { + pub fn set_hostname(&self, hostname: &str) -> Result<(), ErrorStack> { let cstr = CString::new(hostname).unwrap(); let ret = unsafe { ffi_extras::SSL_set_tlsext_host_name(self.ssl, cstr.as_ptr() as *const _) @@ -1049,7 +919,7 @@ impl Ssl { // For this case, 0 indicates failure. if ret == 0 { - Err(SslError::get()) + Err(ErrorStack::get()) } else { Ok(()) } @@ -1147,7 +1017,7 @@ impl Ssl { Some(s) } - pub fn get_ssl_method(&self) -> Option<SslMethod> { + pub fn ssl_method(&self) -> Option<SslMethod> { unsafe { let method = ffi::SSL_get_ssl_method(self.ssl); SslMethod::from_raw(method) @@ -1155,7 +1025,7 @@ impl Ssl { } /// Returns the server's name for the current connection - pub fn get_servername(&self) -> Option<String> { + pub fn servername(&self) -> Option<String> { let name = unsafe { ffi::SSL_get_servername(self.ssl, ffi::TLSEXT_NAMETYPE_host_name) }; if name == ptr::null() { return None; @@ -1164,63 +1034,23 @@ impl Ssl { unsafe { String::from_utf8(CStr::from_ptr(name as *const _).to_bytes().to_vec()).ok() } } - /// change the context corresponding to the current connection - /// - /// Returns a clone of the SslContext @ctx (ie: the new context). The old context is freed. - pub fn set_ssl_context(&self, ctx: &SslContext) -> SslContext { - // If duplication of @ctx's cert fails, this returns NULL. This _appears_ to only occur on - // allocation failures (meaning panicing is probably appropriate), but it might be nice to - // propogate the error. - assert!(unsafe { ffi::SSL_set_SSL_CTX(self.ssl, ctx.ctx) } != ptr::null_mut()); - - // FIXME: we return this reference here for compatibility, but it isn't actually required. - // This should be removed when a api-incompatabile version is to be released. - // - // ffi:SSL_set_SSL_CTX() returns copy of the ctx pointer passed to it, so it's easier for - // us to do the clone directly. - ctx.clone() - } - - /// obtain the context corresponding to the current connection - pub fn get_ssl_context(&self) -> SslContext { + /// Changes the context corresponding to the current connection. + pub fn set_ssl_context(&self, ctx: &SslContext) -> Result<(), ErrorStack> { unsafe { - let ssl_ctx = ffi::SSL_get_SSL_CTX(self.ssl); - SslContext::new_ref(ssl_ctx) + try_ssl_null!(ffi::SSL_set_SSL_CTX(self.ssl, ctx.ctx)); } + Ok(()) } -} -macro_rules! make_LibSslError { - ($($variant:ident = $value:ident),+) => { - #[derive(Debug)] - #[repr(i32)] - enum LibSslError { - $($variant = ffi::$value),+ - } - - impl LibSslError { - fn from_i32(val: i32) -> Option<LibSslError> { - match val { - $(ffi::$value => Some(LibSslError::$variant),)+ - _ => None - } - } + /// Returns the context corresponding to the current connection + pub fn ssl_context(&self) -> SslContext { + unsafe { + let ssl_ctx = ffi::SSL_get_SSL_CTX(self.ssl); + SslContext::new_ref(ssl_ctx) } } } -make_LibSslError! { - ErrorNone = SSL_ERROR_NONE, - ErrorSsl = SSL_ERROR_SSL, - ErrorWantRead = SSL_ERROR_WANT_READ, - ErrorWantWrite = SSL_ERROR_WANT_WRITE, - ErrorWantX509Lookup = SSL_ERROR_WANT_X509_LOOKUP, - ErrorSyscall = SSL_ERROR_SYSCALL, - ErrorZeroReturn = SSL_ERROR_ZERO_RETURN, - ErrorWantConnect = SSL_ERROR_WANT_CONNECT, - ErrorWantAccept = SSL_ERROR_WANT_ACCEPT -} - /// A stream wrapper which handles SSL encryption for an underlying stream. pub struct SslStream<S> { ssl: Ssl, @@ -1228,19 +1058,7 @@ pub struct SslStream<S> { _p: PhantomData<S>, } -/// # Deprecated -/// -/// This method does not behave as expected and will be removed in a future -/// release. -impl<S: Clone + Read + Write> Clone for SslStream<S> { - fn clone(&self) -> SslStream<S> { - SslStream { - ssl: self.ssl.clone(), - _method: self._method.clone(), - _p: PhantomData, - } - } -} +unsafe impl<S: Send> Send for SslStream<S> {} impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug @@ -1253,20 +1071,6 @@ impl<S> fmt::Debug for SslStream<S> } } -#[cfg(unix)] -impl<S: AsRawFd> AsRawFd for SslStream<S> { - fn as_raw_fd(&self) -> RawFd { - self.get_ref().as_raw_fd() - } -} - -#[cfg(windows)] -impl<S: AsRawSocket> AsRawSocket for SslStream<S> { - fn as_raw_socket(&self) -> RawSocket { - self.get_ref().as_raw_socket() - } -} - impl<S: Read + Write> SslStream<S> { fn new_base(ssl: Ssl, stream: S) -> Self { unsafe { @@ -1282,49 +1086,53 @@ impl<S: Read + Write> SslStream<S> { } /// Creates an SSL/TLS client operating over the provided stream. - pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, SslError> { - let ssl = try!(ssl.into_ssl()); + pub fn connect<T: IntoSsl>(ssl: T, stream: S) + -> Result<Self, HandshakeError<S>>{ + let ssl = try!(ssl.into_ssl().map_err(|e| { + HandshakeError::Failure(Error::Ssl(e)) + })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.connect(); if ret > 0 { Ok(stream) } else { - match stream.make_old_error(ret) { - Some(err) => Err(err), - None => Ok(stream), + match stream.make_error(ret) { + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + Err(HandshakeError::Interrupted(MidHandshakeSslStream { + stream: stream, + error: e, + })) + } + err => Err(HandshakeError::Failure(err)), } } } /// Creates an SSL/TLS server operating over the provided stream. - pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, SslError> { - let ssl = try!(ssl.into_ssl()); + pub fn accept<T: IntoSsl>(ssl: T, stream: S) + -> Result<Self, HandshakeError<S>> { + let ssl = try!(ssl.into_ssl().map_err(|e| { + HandshakeError::Failure(Error::Ssl(e)) + })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.accept(); if ret > 0 { Ok(stream) } else { - match stream.make_old_error(ret) { - Some(err) => Err(err), - None => Ok(stream), + match stream.make_error(ret) { + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + Err(HandshakeError::Interrupted(MidHandshakeSslStream { + stream: stream, + error: e, + })) + } + err => Err(HandshakeError::Failure(err)), } } } - /// ### Deprecated - /// - /// Use `connect`. - pub fn connect_generic<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { - Self::connect(ssl, stream) - } - - /// ### Deprecated - /// - /// Use `accept`. - pub fn accept_generic<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { - Self::accept(ssl, stream) - } - /// Like `read`, but returns an `ssl::Error` rather than an `io::Error`. /// /// This is particularly useful with a nonblocking socket, where the error @@ -1352,63 +1160,112 @@ impl<S: Read + Write> SslStream<S> { } } -impl<S> SslStream<S> { - fn make_error(&mut self, ret: c_int) -> Error { - self.check_panic(); +/// An error or intermediate state after a TLS handshake attempt. +#[derive(Debug)] +pub enum HandshakeError<S> { + /// The handshake failed. + Failure(Error), + /// The handshake was interrupted midway through. + Interrupted(MidHandshakeSslStream<S>), +} - match self.ssl.get_error(ret) { - LibSslError::ErrorSsl => Error::Ssl(OpenSslError::get_stack()), - LibSslError::ErrorSyscall => { - let errs = OpenSslError::get_stack(); - if errs.is_empty() { - if ret == 0 { - Error::Stream(io::Error::new(io::ErrorKind::ConnectionAborted, - "unexpected EOF observed")) - } else { - Error::Stream(self.get_bio_error()) - } - } else { - Error::Ssl(errs) +impl<S: Any + fmt::Debug> stderror::Error for HandshakeError<S> { + fn description(&self) -> &str { + match *self { + HandshakeError::Failure(ref e) => e.description(), + HandshakeError::Interrupted(ref e) => e.error.description(), + } + } + + fn cause(&self) -> Option<&stderror::Error> { + match *self { + HandshakeError::Failure(ref e) => Some(e), + HandshakeError::Interrupted(ref e) => Some(&e.error), + } + } +} + +impl<S: Any + fmt::Debug> fmt::Display for HandshakeError<S> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + try!(f.write_str(stderror::Error::description(self))); + if let Some(e) = stderror::Error::cause(self) { + try!(write!(f, ": {}", e)); + } + Ok(()) + } +} + +/// An SSL stream midway through the handshake process. +#[derive(Debug)] +pub struct MidHandshakeSslStream<S> { + stream: SslStream<S>, + error: Error, +} + +impl<S> MidHandshakeSslStream<S> { + /// Returns a shared reference to the inner stream. + pub fn get_ref(&self) -> &S { + self.stream.get_ref() + } + + /// Returns a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut S { + self.stream.get_mut() + } + + /// Returns a shared reference to the `SslContext` of the stream. + pub fn ssl(&self) -> &Ssl { + self.stream.ssl() + } + + /// Returns the underlying error which interrupted this handshake. + pub fn error(&self) -> &Error { + &self.error + } + + /// Restarts the handshake process. + pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> { + let ret = self.stream.ssl.handshake(); + if ret > 0 { + Ok(self.stream) + } else { + match self.stream.make_error(ret) { + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + self.error = e; + Err(HandshakeError::Interrupted(self)) } - } - LibSslError::ErrorZeroReturn => Error::ZeroReturn, - LibSslError::ErrorWantWrite => Error::WantWrite(self.get_bio_error()), - LibSslError::ErrorWantRead => Error::WantRead(self.get_bio_error()), - err => { - Error::Stream(io::Error::new(io::ErrorKind::Other, - format!("unexpected error {:?}", err))) + err => Err(HandshakeError::Failure(err)), } } } +} - fn make_old_error(&mut self, ret: c_int) -> Option<SslError> { +impl<S> SslStream<S> { + fn make_error(&mut self, ret: c_int) -> Error { self.check_panic(); match self.ssl.get_error(ret) { - LibSslError::ErrorSsl => Some(SslError::get()), - LibSslError::ErrorSyscall => { - let err = SslError::get(); - let count = match err { - SslError::OpenSslErrors(ref v) => v.len(), - _ => unreachable!(), - }; - if count == 0 { + ffi::SSL_ERROR_SSL => Error::Ssl(ErrorStack::get()), + ffi::SSL_ERROR_SYSCALL => { + let errs = ErrorStack::get(); + if errs.errors().is_empty() { if ret == 0 { - Some(SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, - "unexpected EOF observed"))) + Error::Stream(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) } else { - Some(SslError::StreamError(self.get_bio_error())) + Error::Stream(self.get_bio_error()) } } else { - Some(err) + Error::Ssl(errs) } } - LibSslError::ErrorZeroReturn => Some(SslError::SslSessionClosed), - LibSslError::ErrorWantWrite | - LibSslError::ErrorWantRead => None, + ffi::SSL_ERROR_ZERO_RETURN => Error::ZeroReturn, + ffi::SSL_ERROR_WANT_WRITE => Error::WantWrite(self.get_bio_error()), + ffi::SSL_ERROR_WANT_READ => Error::WantRead(self.get_bio_error()), err => { - Some(SslError::StreamError(io::Error::new(io::ErrorKind::Other, - format!("unexpected error {:?}", err)))) + Error::Stream(io::Error::new(io::ErrorKind::InvalidData, + format!("unexpected error {}", err))) } } } @@ -1461,20 +1318,6 @@ impl<S> SslStream<S> { } } -impl SslStream<::std::net::TcpStream> { - /// # Deprecated - /// - /// This method does not behave as expected and will be removed in a future - /// release. - pub fn try_clone(&self) -> io::Result<SslStream<::std::net::TcpStream>> { - Ok(SslStream { - ssl: self.ssl.clone(), - _method: self._method.clone(), - _p: PhantomData, - }) - } -} - impl<S: Read + Write> Read for SslStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { match self.ssl_read(buf) { @@ -1506,222 +1349,17 @@ impl<S: Read + Write> Write for SslStream<S> { } pub trait IntoSsl { - fn into_ssl(self) -> Result<Ssl, SslError>; + fn into_ssl(self) -> Result<Ssl, ErrorStack>; } impl IntoSsl for Ssl { - fn into_ssl(self) -> Result<Ssl, SslError> { + fn into_ssl(self) -> Result<Ssl, ErrorStack> { Ok(self) } } impl<'a> IntoSsl for &'a SslContext { - fn into_ssl(self) -> Result<Ssl, SslError> { + fn into_ssl(self) -> Result<Ssl, ErrorStack> { Ssl::new(self) } } - -/// A utility type to help in cases where the use of SSL is decided at runtime. -#[derive(Debug)] -pub enum MaybeSslStream<S> - where S: Read + Write -{ - /// A connection using SSL - Ssl(SslStream<S>), - /// A connection not using SSL - Normal(S), -} - -impl<S> Read for MaybeSslStream<S> - where S: Read + Write -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - match *self { - MaybeSslStream::Ssl(ref mut s) => s.read(buf), - MaybeSslStream::Normal(ref mut s) => s.read(buf), - } - } -} - -impl<S> Write for MaybeSslStream<S> - where S: Read + Write -{ - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - match *self { - MaybeSslStream::Ssl(ref mut s) => s.write(buf), - MaybeSslStream::Normal(ref mut s) => s.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match *self { - MaybeSslStream::Ssl(ref mut s) => s.flush(), - MaybeSslStream::Normal(ref mut s) => s.flush(), - } - } -} - -impl<S> MaybeSslStream<S> - where S: Read + Write -{ - /// Returns a reference to the underlying stream. - pub fn get_ref(&self) -> &S { - match *self { - MaybeSslStream::Ssl(ref s) => s.get_ref(), - MaybeSslStream::Normal(ref s) => s, - } - } - - /// Returns a mutable reference to the underlying stream. - /// - /// ## Warning - /// - /// It is inadvisable to read from or write to the underlying stream. - pub fn get_mut(&mut self) -> &mut S { - match *self { - MaybeSslStream::Ssl(ref mut s) => s.get_mut(), - MaybeSslStream::Normal(ref mut s) => s, - } - } -} - -impl MaybeSslStream<net::TcpStream> { - /// Like `TcpStream::try_clone`. - pub fn try_clone(&self) -> io::Result<MaybeSslStream<net::TcpStream>> { - match *self { - MaybeSslStream::Ssl(ref s) => s.try_clone().map(MaybeSslStream::Ssl), - MaybeSslStream::Normal(ref s) => s.try_clone().map(MaybeSslStream::Normal), - } - } -} - -/// # Deprecated -/// -/// Use `SslStream` with `ssl_read` and `ssl_write`. -pub struct NonblockingSslStream<S>(SslStream<S>); - -impl<S: Clone + Read + Write> Clone for NonblockingSslStream<S> { - fn clone(&self) -> Self { - NonblockingSslStream(self.0.clone()) - } -} - -#[cfg(unix)] -impl<S: AsRawFd> AsRawFd for NonblockingSslStream<S> { - fn as_raw_fd(&self) -> RawFd { - self.0.as_raw_fd() - } -} - -#[cfg(windows)] -impl<S: AsRawSocket> AsRawSocket for NonblockingSslStream<S> { - fn as_raw_socket(&self) -> RawSocket { - self.0.as_raw_socket() - } -} - -impl NonblockingSslStream<net::TcpStream> { - pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> { - self.0.try_clone().map(NonblockingSslStream) - } -} - -impl<S> NonblockingSslStream<S> { - /// Returns a reference to the underlying stream. - pub fn get_ref(&self) -> &S { - self.0.get_ref() - } - - /// Returns a mutable reference to the underlying stream. - /// - /// ## Warning - /// - /// It is inadvisable to read from or write to the underlying stream as it - /// will most likely corrupt the SSL session. - pub fn get_mut(&mut self) -> &mut S { - self.0.get_mut() - } - - /// Returns a reference to the Ssl. - pub fn ssl(&self) -> &Ssl { - self.0.ssl() - } -} - -impl<S: Read + Write> NonblockingSslStream<S> { - /// Create a new nonblocking client ssl connection on wrapped `stream`. - /// - /// Note that this method will most likely not actually complete the SSL - /// handshake because doing so requires several round trips; the handshake will - /// be completed in subsequent read/write calls managed by your event loop. - pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { - SslStream::connect(ssl, stream).map(NonblockingSslStream) - } - - /// Create a new nonblocking server ssl connection on wrapped `stream`. - /// - /// Note that this method will most likely not actually complete the SSL - /// handshake because doing so requires several round trips; the handshake will - /// be completed in subsequent read/write calls managed by your event loop. - pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { - SslStream::accept(ssl, stream).map(NonblockingSslStream) - } - - fn convert_err(&self, err: Error) -> NonblockingSslError { - match err { - Error::ZeroReturn => SslError::SslSessionClosed.into(), - Error::WantRead(_) => NonblockingSslError::WantRead, - Error::WantWrite(_) => NonblockingSslError::WantWrite, - Error::WantX509Lookup => unreachable!(), - Error::Stream(e) => SslError::StreamError(e).into(), - Error::Ssl(e) => { - SslError::OpenSslErrors(e.iter() - .map(|e| OpensslError::from_error_code(e.error_code())) - .collect()) - .into() - } - } - } - - /// Read bytes from the SSL stream into `buf`. - /// - /// Given the SSL state machine, this method may return either `WantWrite` - /// or `WantRead` to indicate that your event loop should respectively wait - /// for write or read readiness on the underlying stream. Upon readiness, - /// repeat your `read()` call with the same arguments each time until you - /// receive an `Ok(count)`. - /// - /// An `SslError` return value, is terminal; do not re-attempt your read. - /// - /// As expected of a nonblocking API, this method will never block your - /// thread on I/O. - /// - /// On a return value of `Ok(count)`, count is the number of decrypted - /// plaintext bytes copied into the `buf` slice. - pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> { - match self.0.ssl_read(buf) { - Ok(n) => Ok(n), - Err(Error::ZeroReturn) => Ok(0), - Err(e) => Err(self.convert_err(e)), - } - } - - /// Write bytes from `buf` to the SSL stream. - /// - /// Given the SSL state machine, this method may return either `WantWrite` - /// or `WantRead` to indicate that your event loop should respectively wait - /// for write or read readiness on the underlying stream. Upon readiness, - /// repeat your `write()` call with the same arguments each time until you - /// receive an `Ok(count)`. - /// - /// An `SslError` return value, is terminal; do not re-attempt your write. - /// - /// As expected of a nonblocking API, this method will never block your - /// thread on I/O. - /// - /// Given a return value of `Ok(count)`, count is the number of plaintext bytes - /// from the `buf` slice that were encrypted and written onto the stream. - pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> { - self.0.ssl_write(buf).map_err(|e| self.convert_err(e)) - } -} |