diff options
Diffstat (limited to 'openssl/src/ssl/mod.rs')
| -rw-r--r-- | openssl/src/ssl/mod.rs | 904 |
1 files changed, 375 insertions, 529 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index cca369a2..3b22c755 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -10,28 +10,42 @@ use std::str; use std::net; use std::path::Path; use std::ptr; -use std::sync::{Once, ONCE_INIT, Arc, Mutex}; -use std::ops::{Deref, DerefMut}; +use std::sync::{Once, ONCE_INIT, Mutex}; use std::cmp; use std::any::Any; #[cfg(any(feature = "npn", feature = "alpn"))] use libc::{c_uchar, c_uint}; #[cfg(any(feature = "npn", feature = "alpn"))] use std::slice; +use std::marker::PhantomData; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; -use bio::{MemBio}; use ffi; use ffi_extras; use dh::DH; -use ssl::error::{NonblockingSslError, SslError, SslSessionClosed, StreamError, OpenSslErrors}; +use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError, + OpensslError}; use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; pub mod error; +mod bio; #[cfg(test)] mod tests; +#[doc(inline)] +pub use ssl::error::Error; + +extern "C" { + fn rust_SSL_clone(ssl: *mut ffi::SSL); + fn rust_SSL_CTX_clone(cxt: *mut ffi::SSL_CTX); +} + static mut VERIFY_IDX: c_int = -1; +static mut SNI_IDX: c_int = -1; /// Manually initialize SSL. /// It is optional to call this function and safe to do so more than once. @@ -46,6 +60,11 @@ pub fn init() { None, None); assert!(verify_idx >= 0); VERIFY_IDX = verify_idx; + + let sni_idx = ffi::SSL_CTX_get_ex_new_index(0, ptr::null(), None, + None, None); + assert!(sni_idx >= 0); + SNI_IDX = sni_idx; }); } } @@ -279,24 +298,63 @@ extern fn raw_verify_with_data<T>(preverify_ok: c_int, let verify: Option<VerifyCallbackData<T>> = mem::transmute(verify); let data = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_verify_data_idx::<T>()); - let data: Box<T> = mem::transmute(data); + let data: &T = mem::transmute(data); let ctx = X509StoreContext::new(x509_ctx); let res = match verify { None => preverify_ok, - Some(verify) => verify(preverify_ok != 0, &ctx, &*data) as c_int + Some(verify) => verify(preverify_ok != 0, &ctx, data) as c_int + }; + + res + } +} + +extern fn raw_sni(ssl: *mut ffi::SSL, ad: &mut c_int, _arg: *mut c_void) + -> c_int { + unsafe { + let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); + let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, SNI_IDX); + let callback: Option<ServerNameCallback> = mem::transmute(callback); + rust_SSL_clone(ssl); + let mut s = Ssl { ssl: ssl }; + + let res = match callback { + None => ffi::SSL_TLSEXT_ERR_ALERT_FATAL, + Some(callback) => callback(&mut s, ad) + }; + + res + } +} + +extern fn raw_sni_with_data<T>(ssl: *mut ffi::SSL, ad: &mut c_int, arg: *mut c_void) -> c_int + where T: Any + 'static { + unsafe { + let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); + + let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, SNI_IDX); + let callback: Option<ServerNameCallbackData<T>> = mem::transmute(callback); + rust_SSL_clone(ssl); + let mut s = Ssl { ssl: ssl }; + + let data: &T = mem::transmute(arg); + + let res = match callback { + None => ffi::SSL_TLSEXT_ERR_ALERT_FATAL, + Some(callback) => callback(&mut s, ad, &*data) }; // Since data might be required on the next verification // it is time to forget about it and avoid dropping // data will be freed once OpenSSL considers it is time // to free all context data - mem::forget(data); res } } + #[cfg(any(feature = "npn", feature = "alpn"))] unsafe fn select_proto_using(ssl: *mut ffi::SSL, out: *mut *mut c_uchar, outlen: *mut c_uchar, @@ -404,6 +462,11 @@ pub type VerifyCallbackData<T> = fn(preverify_ok: bool, x509_ctx: &X509StoreContext, data: &T) -> bool; +/// The signature of functions that can be used to choose the context depending on the server name +pub type ServerNameCallback = fn(ssl: &mut Ssl, ad: &mut i32) -> i32; + +pub type ServerNameCallbackData<T> = fn(ssl: &mut Ssl, ad: &mut i32, data: &T) -> i32; + // FIXME: macro may be instead of inlining? #[inline] fn wrap_ssl_result(res: c_int) -> Result<(),SslError> { @@ -485,6 +548,35 @@ impl SslContext { } } + /// Configures the server name indication (SNI) callback for new connections + /// + /// obtain the server name with `get_servername` then set the corresponding context + /// with `set_ssl_context` + pub fn set_servername_callback(&mut self, callback: Option<ServerNameCallback>) { + unsafe { + ffi::SSL_CTX_set_ex_data(self.ctx, SNI_IDX, + mem::transmute(callback)); + let f: extern fn() = mem::transmute(raw_sni); + ffi_extras::SSL_CTX_set_tlsext_servername_callback(self.ctx, Some(f)); + } + } + + /// Configures the server name indication (SNI) callback for new connections + /// carrying supplied data + pub fn set_servername_callback_with_data<T>(&mut self, callback: ServerNameCallbackData<T>, + data: T) + where T: Any + 'static { + let data = Box::new(data); + unsafe { + ffi::SSL_CTX_set_ex_data(self.ctx, SNI_IDX, + mem::transmute(Some(callback))); + + ffi_extras::SSL_CTX_set_tlsext_servername_arg(self.ctx, mem::transmute(data)); + let f: extern fn() = mem::transmute(raw_sni_with_data::<T>); + ffi_extras::SSL_CTX_set_tlsext_servername_callback(self.ctx, Some(f)); + } + } + /// Sets verification depth pub fn set_verify_depth(&mut self, depth: u32) { unsafe { @@ -510,7 +602,7 @@ impl SslContext { let file = CString::new(file.as_ref().as_os_str().to_str().expect("invalid utf8")).unwrap(); wrap_ssl_result( unsafe { - ffi::SSL_CTX_load_verify_locations(self.ctx, file.as_ptr(), ptr::null()) + ffi::SSL_CTX_load_verify_locations(self.ctx, file.as_ptr() as *const _, ptr::null()) }) } @@ -520,7 +612,7 @@ impl SslContext { let file = CString::new(file.as_ref().as_os_str().to_str().expect("invalid utf8")).unwrap(); wrap_ssl_result( unsafe { - ffi::SSL_CTX_use_certificate_file(self.ctx, file.as_ptr(), file_type as c_int) + ffi::SSL_CTX_use_certificate_file(self.ctx, file.as_ptr() as *const _, file_type as c_int) }) } @@ -530,7 +622,7 @@ impl SslContext { let file = CString::new(file.as_ref().as_os_str().to_str().expect("invalid utf8")).unwrap(); wrap_ssl_result( unsafe { - ffi::SSL_CTX_use_certificate_chain_file(self.ctx, file.as_ptr(), file_type as c_int) + ffi::SSL_CTX_use_certificate_chain_file(self.ctx, file.as_ptr() as *const _, file_type as c_int) }) } @@ -557,7 +649,7 @@ impl SslContext { let file = CString::new(file.as_ref().as_os_str().to_str().expect("invalid utf8")).unwrap(); wrap_ssl_result( unsafe { - ffi::SSL_CTX_use_PrivateKey_file(self.ctx, file.as_ptr(), file_type as c_int) + ffi::SSL_CTX_use_PrivateKey_file(self.ctx, file.as_ptr() as *const _, file_type as c_int) }) } @@ -581,7 +673,7 @@ impl SslContext { wrap_ssl_result( unsafe { let cipher_list = CString::new(cipher_list).unwrap(); - ffi::SSL_CTX_set_cipher_list(self.ctx, cipher_list.as_ptr()) + ffi::SSL_CTX_set_cipher_list(self.ctx, cipher_list.as_ptr() as *const _) }) } @@ -673,26 +765,7 @@ impl SslContext { ffi::SSL_CTX_set_alpn_select_cb(self.ctx, raw_alpn_select_cb, ptr::null_mut()); } } -} - -#[allow(dead_code)] -struct MemBioRef<'ssl> { - ssl: &'ssl Ssl, - bio: MemBio, -} - -impl<'ssl> Deref for MemBioRef<'ssl> { - type Target = MemBio; - - fn deref(&self) -> &MemBio { - &self.bio - } -} -impl<'ssl> DerefMut for MemBioRef<'ssl> { - fn deref_mut(&mut self) -> &mut MemBio { - &mut self.bio - } } pub struct Ssl { @@ -716,6 +789,14 @@ impl Drop for Ssl { } } +impl Clone for Ssl { + fn clone(&self) -> Ssl { + unsafe { rust_SSL_clone(self.ssl) }; + Ssl { ssl: self.ssl } + + } +} + impl Ssl { pub fn new(ctx: &SslContext) -> Result<Ssl, SslError> { let ssl = try_ssl_null!(unsafe { ffi::SSL_new(ctx.ctx) }); @@ -723,20 +804,8 @@ impl Ssl { Ok(ssl) } - fn get_rbio<'a>(&'a self) -> MemBioRef<'a> { - unsafe { self.wrap_bio(ffi::SSL_get_rbio(self.ssl)) } - } - - fn get_wbio<'a>(&'a self) -> MemBioRef<'a> { - unsafe { self.wrap_bio(ffi::SSL_get_wbio(self.ssl)) } - } - - fn wrap_bio<'a>(&'a self, bio: *mut ffi::BIO) -> MemBioRef<'a> { - assert!(bio != ptr::null_mut()); - MemBioRef { - ssl: self, - bio: MemBio::borrowed(bio) - } + fn get_raw_rbio(&self) -> *mut ffi::BIO { + unsafe { ffi::SSL_get_rbio(self.ssl) } } fn connect(&self) -> c_int { @@ -768,7 +837,7 @@ impl Ssl { pub fn state_string(&self) -> &'static str { let state = unsafe { let ptr = ffi::SSL_state_string(self.ssl); - CStr::from_ptr(ptr) + CStr::from_ptr(ptr as *const _) }; str::from_utf8(state.to_bytes()).unwrap() @@ -777,7 +846,7 @@ impl Ssl { pub fn state_string_long(&self) -> &'static str { let state = unsafe { let ptr = ffi::SSL_state_string_long(self.ssl); - CStr::from_ptr(ptr) + CStr::from_ptr(ptr as *const _) }; str::from_utf8(state.to_bytes()).unwrap() @@ -786,7 +855,7 @@ impl Ssl { /// Sets the host name to be used with SNI (Server Name Indication). pub fn set_hostname(&self, hostname: &str) -> Result<(), SslError> { let cstr = CString::new(hostname).unwrap(); - let ret = unsafe { ffi_extras::SSL_set_tlsext_host_name(self.ssl, cstr.as_ptr()) }; + let ret = unsafe { ffi_extras::SSL_set_tlsext_host_name(self.ssl, cstr.as_ptr() as *const _) }; // For this case, 0 indicates failure. if ret == 0 { @@ -874,7 +943,7 @@ impl Ssl { let meth = unsafe { ffi::SSL_COMP_get_name(ptr) }; let s = unsafe { - String::from_utf8(CStr::from_ptr(meth).to_bytes().to_vec()).unwrap() + String::from_utf8(CStr::from_ptr(meth as *const _).to_bytes().to_vec()).unwrap() }; Some(s) @@ -886,6 +955,32 @@ impl Ssl { SslMethod::from_raw(method) } } + + /// Returns the server's name for the current connection + pub fn get_servername(&self) -> Option<String> { + let name = unsafe { ffi::SSL_get_servername(self.ssl, ffi::TLSEXT_NAMETYPE_host_name) }; + if name == ptr::null() { + return None; + } + + unsafe { + String::from_utf8(CStr::from_ptr(name).to_bytes().to_vec()).ok() + } + } + + /// change the context corresponding to the current connection + pub fn set_ssl_context(&self, ctx: &SslContext) -> SslContext { + SslContext { ctx: unsafe { ffi::SSL_set_SSL_CTX(self.ssl, ctx.ctx) } } + } + + /// obtain the context corresponding to the current connection + pub fn get_ssl_context(&self) -> SslContext { + unsafe { + let ssl_ctx = ffi::SSL_get_SSL_CTX(self.ssl); + rust_SSL_CTX_clone(ssl_ctx); + SslContext { ctx: ssl_ctx } + } + } } macro_rules! make_LibSslError { @@ -919,182 +1014,152 @@ make_LibSslError! { ErrorWantAccept = SSL_ERROR_WANT_ACCEPT } -struct IndirectStream<S> { - stream: S, - ssl: Arc<Ssl>, - // Max TLS record size is 16k - buf: Box<[u8; 16 * 1024]>, +/// A stream wrapper which handles SSL encryption for an underlying stream. +pub struct SslStream<S> { + ssl: Ssl, + _method: Box<ffi::BIO_METHOD>, // :( + _p: PhantomData<S>, } -impl<S: Clone> Clone for IndirectStream<S> { - fn clone(&self) -> IndirectStream<S> { - IndirectStream { - stream: self.stream.clone(), - ssl: self.ssl.clone(), - buf: Box::new(*self.buf) - } +unsafe impl<S: Send> Send for SslStream<S> {} + +impl<S: Clone + Read + Write> Clone for SslStream<S> { + fn clone(&self) -> SslStream<S> { + let stream = self.get_ref().clone(); + Self::new_base(self.ssl.clone(), stream) } } -impl IndirectStream<net::TcpStream> { - fn try_clone(&self) -> io::Result<IndirectStream<net::TcpStream>> { - Ok(IndirectStream { - stream: try!(self.stream.try_clone()), - ssl: self.ssl.clone(), - buf: Box::new(*self.buf) - }) +impl<S> Drop for SslStream<S> { + fn drop(&mut self) { + unsafe { + let _ = bio::take_stream::<S>(self.ssl.get_raw_rbio()); + } } } -impl<S: Read+Write> IndirectStream<S> { - fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { - let ssl = try!(ssl.into_ssl()); - - let rbio = try!(MemBio::new()); - let wbio = try!(MemBio::new()); - - unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) } - - Ok(IndirectStream { - stream: stream, - ssl: Arc::new(ssl), - buf: Box::new([0; 16 * 1024]), - }) +impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("SslStream") + .field("stream", &self.get_ref()) + .field("ssl", &self.ssl()) + .finish() } +} - fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { - let mut ssl = try!(IndirectStream::new_base(ssl, stream)); - try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); - Ok(ssl) +#[cfg(unix)] +impl<S: AsRawFd> AsRawFd for SslStream<S> { + fn as_raw_fd(&self) -> RawFd { + self.get_ref().as_raw_fd() } +} - fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { - let mut ssl = try!(IndirectStream::new_base(ssl, stream)); - try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); - Ok(ssl) +#[cfg(windows)] +impl<S: AsRawSocket> AsRawSocket for SslStream<S> { + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().as_raw_socket() } +} - fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError> - where F: FnMut(&Ssl) -> c_int { - loop { - let ret = blk(&self.ssl); - if ret > 0 { - return Ok(ret); - } - - let e = self.ssl.get_error(ret); - match e { - LibSslError::ErrorWantRead => { - try_ssl_stream!(self.flush()); - let len = try_ssl_stream!(self.stream.read(&mut self.buf[..])); - +impl<S: Read+Write> SslStream<S> { + fn new_base(ssl: Ssl, stream: S) -> Self { + unsafe { + let (bio, method) = bio::new(stream).unwrap(); + ffi::SSL_set_bio(ssl.ssl, bio, bio); - if len == 0 { - let method = self.ssl.get_ssl_method(); + SslStream { + ssl: ssl, + _method: method, + _p: PhantomData, + } + } + } - if method.map(|m| m.is_dtls()).unwrap_or(false) { - return Ok(0); - } else { - self.ssl.get_rbio().set_eof(true); - } + /// Creates an SSL/TLS client operating over the provided stream. + pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, SslError> { + let ssl = try!(ssl.into_ssl()); + let mut stream = Self::new_base(ssl, stream); + let ret = stream.ssl.connect(); + if ret > 0 { + Ok(stream) + } else { + match stream.make_old_error(ret) { + SslError::StreamError(e) => { + // This is fine - nonblocking sockets will finish the handshake in read/write + if e.kind() == io::ErrorKind::WouldBlock { + Ok(stream) } else { - try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len])); + Err(SslError::StreamError(e)) } } - LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) } - LibSslError::ErrorZeroReturn => return Err(SslSessionClosed), - LibSslError::ErrorSsl => return Err(SslError::get()), - LibSslError::ErrorSyscall if ret == 0 => return Ok(0), - err => panic!("unexpected error {:?} with ret {}", err, ret), + e => Err(e) } } } - fn write_through(&mut self) -> io::Result<()> { - io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) - } -} - -impl<S: Read+Write> Read for IndirectStream<S> { - fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { - Ok(len) => Ok(len as usize), - Err(SslSessionClosed) => Ok(0), - Err(StreamError(e)) => Err(e), - Err(e @ OpenSslErrors(_)) => { - Err(io::Error::new(io::ErrorKind::Other, e)) + /// Creates an SSL/TLS server operating over the provided stream. + pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, SslError> { + let ssl = try!(ssl.into_ssl()); + let mut stream = Self::new_base(ssl, stream); + let ret = stream.ssl.accept(); + if ret > 0 { + Ok(stream) + } else { + match stream.make_old_error(ret) { + SslError::StreamError(e) => { + // This is fine - nonblocking sockets will finish the handshake in read/write + if e.kind() == io::ErrorKind::WouldBlock { + Ok(stream) + } else { + Err(SslError::StreamError(e)) + } + } + e => Err(e) } } } -} - -impl<S: Read+Write> Write for IndirectStream<S> { - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { - Ok(len) => len as usize, - Err(SslSessionClosed) => 0, - Err(StreamError(e)) => return Err(e), - Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), - }; - try!(self.write_through()); - Ok(count) - } - fn flush(&mut self) -> io::Result<()> { - try!(self.write_through()); - self.stream.flush() + /// ### Deprecated + /// + /// Use `connect`. + pub fn connect_generic<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { + Self::connect(ssl, stream) } -} - -#[derive(Clone)] -struct DirectStream<S> { - stream: S, - ssl: Arc<Ssl>, -} -impl DirectStream<net::TcpStream> { - fn try_clone(&self) -> io::Result<DirectStream<net::TcpStream>> { - Ok(DirectStream { - stream: try!(self.stream.try_clone()), - ssl: self.ssl.clone(), - }) + /// ### Deprecated + /// + /// Use `accept`. + pub fn accept_generic<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { + Self::accept(ssl, stream) } } -impl<S> DirectStream<S> { - fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { - unsafe { - let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0)); - ffi::SSL_set_bio(ssl.ssl, bio, bio); - } - - Ok(DirectStream { - stream: stream, - ssl: Arc::new(ssl), - }) - } - - fn connect(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { - let ssl = try!(DirectStream::new_base(ssl, stream, sock)); - let ret = ssl.ssl.connect(); - if ret > 0 { - Ok(ssl) - } else { - Err(ssl.make_error(ret)) - } - } - - fn accept(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { - let ssl = try!(DirectStream::new_base(ssl, stream, sock)); - let ret = ssl.ssl.accept(); - if ret > 0 { - Ok(ssl) - } else { - Err(ssl.make_error(ret)) +impl<S> SslStream<S> { + fn make_error(&mut self, ret: c_int) -> Error { + match self.ssl.get_error(ret) { + LibSslError::ErrorSsl => Error::Ssl(OpenSslError::get_stack()), + LibSslError::ErrorSyscall => { + let errs = OpenSslError::get_stack(); + if errs.is_empty() { + if ret == 0 { + Error::Stream(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) + } else { + Error::Stream(self.get_bio_error()) + } + } else { + Error::Ssl(errs) + } + } + LibSslError::ErrorZeroReturn => Error::ZeroReturn, + LibSslError::ErrorWantWrite => Error::WantWrite(self.get_bio_error()), + LibSslError::ErrorWantRead => Error::WantRead(self.get_bio_error()), + err => Error::Stream(io::Error::new(io::ErrorKind::Other, + format!("unexpected error {:?}", err))), } } - fn make_error(&self, ret: c_int) -> SslError { + fn make_old_error(&mut self, ret: c_int) -> SslError { match self.ssl.get_error(ret) { LibSslError::ErrorSsl => SslError::get(), LibSslError::ErrorSyscall => { @@ -1108,199 +1173,36 @@ impl<S> DirectStream<S> { SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, "unexpected EOF observed")) } else { - SslError::StreamError(io::Error::last_os_error()) + SslError::StreamError(self.get_bio_error()) } } else { err } } + LibSslError::ErrorZeroReturn => SslError::SslSessionClosed, LibSslError::ErrorWantWrite | LibSslError::ErrorWantRead => { - SslError::StreamError(io::Error::last_os_error()) + SslError::StreamError(self.get_bio_error()) } - err => panic!("unexpected error {:?} with ret {}", err, ret), - } - } -} - -impl<S> Read for DirectStream<S> { - fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - let ret = self.ssl.read(buf); - if ret >= 0 { - return Ok(ret as usize); - } - - match self.make_error(ret) { - SslError::StreamError(e) => Err(e), - e => Err(io::Error::new(io::ErrorKind::Other, e)), + err => SslError::StreamError(io::Error::new(io::ErrorKind::Other, + format!("unexpected error {:?}", err))), } } -} - -impl<S: Write> Write for DirectStream<S> { - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - let ret = self.ssl.write(buf); - if ret > 0 { - return Ok(ret as usize); - } - match self.make_error(ret) { - SslError::StreamError(e) => Err(e), - e => Err(io::Error::new(io::ErrorKind::Other, e)), + fn get_bio_error(&mut self) -> io::Error { + let error = unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) }; + match error { + Some(error) => error, + None => io::Error::new(io::ErrorKind::Other, + "BUG: got an ErrorSyscall without an error in the BIO?") } } - fn flush(&mut self) -> io::Result<()> { - self.stream.flush() - } -} - -#[derive(Clone)] -enum StreamKind<S> { - Indirect(IndirectStream<S>), - Direct(DirectStream<S>), -} - -impl<S> StreamKind<S> { - fn stream(&self) -> &S { - match *self { - StreamKind::Indirect(ref s) => &s.stream, - StreamKind::Direct(ref s) => &s.stream, - } - } - - fn mut_stream(&mut self) -> &mut S { - match *self { - StreamKind::Indirect(ref mut s) => &mut s.stream, - StreamKind::Direct(ref mut s) => &mut s.stream, - } - } - - fn ssl(&self) -> &Ssl { - match *self { - StreamKind::Indirect(ref s) => &s.ssl, - StreamKind::Direct(ref s) => &s.ssl, - } - } -} - -/// A stream wrapper which handles SSL encryption for an underlying stream. -#[derive(Clone)] -pub struct SslStream<S> { - kind: StreamKind<S>, -} - -impl SslStream<net::TcpStream> { - /// Create a new independently owned handle to the underlying socket. - pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> { - let kind = match self.kind { - StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())), - StreamKind::Direct(ref s) => StreamKind::Direct(try!(s.try_clone())) - }; - Ok(SslStream { - kind: kind - }) - } -} - -impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("SslStream") - .field("stream", &self.kind.stream()) - .field("ssl", &self.kind.ssl()) - .finish() - } -} - -#[cfg(unix)] -impl<S: Read+Write+::std::os::unix::io::AsRawFd> SslStream<S> { - /// Creates an SSL/TLS client operating over the provided stream. - /// - /// Streams passed to this method must implement `AsRawFd` on Unixy - /// platforms and `AsRawSocket` on Windows. Use `connect_generic` for - /// streams that do not. - pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_fd() as c_int; - let stream = try!(DirectStream::connect(ssl, stream, fd)); - Ok(SslStream { - kind: StreamKind::Direct(stream) - }) - } - - /// Creates an SSL/TLS server operating over the provided stream. - /// - /// Streams passed to this method must implement `AsRawFd` on Unixy - /// platforms and `AsRawSocket` on Windows. Use `accept_generic` for - /// streams that do not. - pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_fd() as c_int; - let stream = try!(DirectStream::accept(ssl, stream, fd)); - Ok(SslStream { - kind: StreamKind::Direct(stream) - }) - } -} - -#[cfg(windows)] -impl<S: Read+Write+::std::os::windows::io::AsRawSocket> SslStream<S> { - /// Creates an SSL/TLS client operating over the provided stream. - /// - /// Streams passed to this method must implement `AsRawFd` on Unixy - /// platforms and `AsRawSocket` on Windows. Use `connect_generic` for - /// streams that do not. - pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_socket() as c_int; - let stream = try!(DirectStream::connect(ssl, stream, fd)); - Ok(SslStream { - kind: StreamKind::Direct(stream) - }) - } - - /// Creates an SSL/TLS server operating over the provided stream. - /// - /// Streams passed to this method must implement `AsRawFd` on Unixy - /// platforms and `AsRawSocket` on Windows. Use `accept_generic` for - /// streams that do not. - pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_socket() as c_int; - let stream = try!(DirectStream::accept(ssl, stream, fd)); - Ok(SslStream { - kind: StreamKind::Direct(stream) - }) - } -} - -impl<S: Read+Write> SslStream<S> { - /// Creates an SSL/TLS client operating over the provided stream. - /// - /// `SslStream`s returned by this method will be less efficient than ones - /// returned by `connect`, so this method should only be used for streams - /// that do not implement `AsRawFd` and `AsRawSocket`. - pub fn connect_generic<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { - let stream = try!(IndirectStream::connect(ssl, stream)); - Ok(SslStream { - kind: StreamKind::Indirect(stream) - }) - } - - /// Creates an SSL/TLS server operating over the provided stream. - /// - /// `SslStream`s returned by this method will be less efficient than ones - /// returned by `accept`, so this method should only be used for streams - /// that do not implement `AsRawFd` and `AsRawSocket`. - pub fn accept_generic<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { - let stream = try!(IndirectStream::accept(ssl, stream)); - Ok(SslStream { - kind: StreamKind::Indirect(stream) - }) - } - /// Returns a reference to the underlying stream. pub fn get_ref(&self) -> &S { - self.kind.stream() + unsafe { + let bio = self.ssl.get_raw_rbio(); + bio::get_ref(bio) + } } /// Returns a mutable reference to the underlying stream. @@ -1310,37 +1212,79 @@ impl<S: Read+Write> SslStream<S> { /// It is inadvisable to read from or write to the underlying stream as it /// will most likely corrupt the SSL session. pub fn get_mut(&mut self) -> &mut S { - self.kind.mut_stream() + unsafe { + let bio = self.ssl.get_raw_rbio(); + bio::get_mut(bio) + } + } + + /// Like `read`, but returns an `ssl::Error` rather than an `io::Error`. + /// + /// This is particularly useful with a nonblocking socket, where the error + /// value will identify if OpenSSL is waiting on read or write readiness. + pub fn ssl_read(&mut self, buf: &mut [u8]) -> Result<usize, Error> { + let ret = self.ssl.read(buf); + if ret >= 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } + + /// Like `write`, but returns an `ssl::Error` rather than an `io::Error`. + /// + /// This is particularly useful with a nonblocking socket, where the error + /// value will identify if OpenSSL is waiting on read or write readiness. + pub fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, Error> { + let ret = self.ssl.write(buf); + if ret >= 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } } /// Returns the OpenSSL `Ssl` object associated with this stream. pub fn ssl(&self) -> &Ssl { - self.kind.ssl() + &self.ssl } } -impl<S: Read+Write> Read for SslStream<S> { +impl SslStream<::std::net::TcpStream> { + /// Like `TcpStream::try_clone`. + pub fn try_clone(&self) -> io::Result<SslStream<::std::net::TcpStream>> { + let stream = try!(self.get_ref().try_clone()); + Ok(Self::new_base(self.ssl.clone(), stream)) + } +} + +impl<S: Read> Read for SslStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - match self.kind { - StreamKind::Indirect(ref mut s) => s.read(buf), - StreamKind::Direct(ref mut s) => s.read(buf), + match self.ssl_read(buf) { + Ok(n) => Ok(n), + Err(Error::ZeroReturn) => Ok(0), + Err(Error::Stream(e)) => Err(e), + Err(Error::WantRead(e)) => Err(e), + Err(Error::WantWrite(e)) => Err(e), + Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)), } } } -impl<S: Read+Write> Write for SslStream<S> { +impl<S: Write> Write for SslStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - match self.kind { - StreamKind::Indirect(ref mut s) => s.write(buf), - StreamKind::Direct(ref mut s) => s.write(buf), - } + self.ssl_write(buf).map_err(|e| { + match e { + Error::Stream(e) => e, + Error::WantRead(e) => e, + Error::WantWrite(e) => e, + e => io::Error::new(io::ErrorKind::Other, e), + } + }) } fn flush(&mut self) -> io::Result<()> { - match self.kind { - StreamKind::Indirect(ref mut s) => s.flush(), - StreamKind::Direct(ref mut s) => s.flush(), - } + self.get_mut().flush() } } @@ -1426,66 +1370,41 @@ impl MaybeSslStream<net::TcpStream> { } } -/// An SSL stream wrapping a nonblocking socket. -#[derive(Clone)] -pub struct NonblockingSslStream<S> { - stream: S, - ssl: Arc<Ssl>, -} +/// # Deprecated +/// +/// Use `SslStream` with `ssl_read` and `ssl_write`. +pub struct NonblockingSslStream<S>(SslStream<S>); -impl NonblockingSslStream<net::TcpStream> { - pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> { - Ok(NonblockingSslStream { - stream: try!(self.stream.try_clone()), - ssl: self.ssl.clone(), - }) +impl<S: Clone + Read + Write> Clone for NonblockingSslStream<S> { + fn clone(&self) -> Self { + NonblockingSslStream(self.0.clone()) } } -impl<S> NonblockingSslStream<S> { - fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<NonblockingSslStream<S>, SslError> { - unsafe { - let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0)); - ffi_extras::BIO_set_nbio(bio, 1); - ffi::SSL_set_bio(ssl.ssl, bio, bio); - } +#[cfg(unix)] +impl<S: AsRawFd> AsRawFd for NonblockingSslStream<S> { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() + } +} - Ok(NonblockingSslStream { - stream: stream, - ssl: Arc::new(ssl), - }) +#[cfg(windows)] +impl<S: AsRawSocket> AsRawSocket for NonblockingSslStream<S> { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() } +} - fn make_error(&self, ret: c_int) -> NonblockingSslError { - match self.ssl.get_error(ret) { - LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()), - LibSslError::ErrorSyscall => { - let err = SslError::get(); - let count = match err { - SslError::OpenSslErrors(ref v) => v.len(), - _ => unreachable!(), - }; - let ssl_error = if count == 0 { - if ret == 0 { - SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, - "unexpected EOF observed")) - } else { - SslError::StreamError(io::Error::last_os_error()) - } - } else { - err - }; - ssl_error.into() - }, - LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite, - LibSslError::ErrorWantRead => NonblockingSslError::WantRead, - err => panic!("unexpected error {:?} with ret {}", err, ret), - } +impl NonblockingSslStream<net::TcpStream> { + pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> { + self.0.try_clone().map(NonblockingSslStream) } +} +impl<S> NonblockingSslStream<S> { /// Returns a reference to the underlying stream. pub fn get_ref(&self) -> &S { - &self.stream + self.0.get_ref() } /// Returns a mutable reference to the underlying stream. @@ -1495,91 +1414,23 @@ impl<S> NonblockingSslStream<S> { /// It is inadvisable to read from or write to the underlying stream as it /// will most likely corrupt the SSL session. pub fn get_mut(&mut self) -> &mut S { - &mut self.stream + self.0.get_mut() } /// Returns a reference to the Ssl. pub fn ssl(&self) -> &Ssl { - &self.ssl + self.0.ssl() } } -#[cfg(unix)] -impl<S: Read+Write+::std::os::unix::io::AsRawFd> NonblockingSslStream<S> { - /// Create a new nonblocking client ssl connection on wrapped `stream`. - /// - /// Note that this method will most likely not actually complete the SSL - /// handshake because doing so requires several round trips; the handshake will - /// be completed in subsequent read/write calls managed by your event loop. - pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_fd() as c_int; - let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); - let ret = ssl.ssl.connect(); - if ret > 0 { - Ok(ssl) - } else { - // WantRead/WantWrite is okay here; we'll finish the handshake in - // subsequent send/recv calls. - match ssl.make_error(ret) { - NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), - NonblockingSslError::SslError(other) => Err(other), - } - } - } - - /// Create a new nonblocking server ssl connection on wrapped `stream`. - /// - /// Note that this method will most likely not actually complete the SSL - /// handshake because doing so requires several round trips; the handshake will - /// be completed in subsequent read/write calls managed by your event loop. - pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_fd() as c_int; - let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); - let ret = ssl.ssl.accept(); - if ret > 0 { - Ok(ssl) - } else { - // WantRead/WantWrite is okay here; we'll finish the handshake in - // subsequent send/recv calls. - match ssl.make_error(ret) { - NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), - NonblockingSslError::SslError(other) => Err(other), - } - } - } -} - -#[cfg(unix)] -impl<S: ::std::os::unix::io::AsRawFd> ::std::os::unix::io::AsRawFd for NonblockingSslStream<S> { - fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd { - self.stream.as_raw_fd() - } -} - -#[cfg(windows)] -impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S> { +impl<S: Read+Write> NonblockingSslStream<S> { /// Create a new nonblocking client ssl connection on wrapped `stream`. /// /// Note that this method will most likely not actually complete the SSL /// handshake because doing so requires several round trips; the handshake will /// be completed in subsequent read/write calls managed by your event loop. pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_socket() as c_int; - let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); - let ret = ssl.ssl.connect(); - if ret > 0 { - Ok(ssl) - } else { - // WantRead/WantWrite is okay here; we'll finish the handshake in - // subsequent send/recv calls. - match ssl.make_error(ret) { - NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), - NonblockingSslError::SslError(other) => Err(other), - } - } + SslStream::connect(ssl, stream).map(NonblockingSslStream) } /// Create a new nonblocking server ssl connection on wrapped `stream`. @@ -1588,24 +1439,25 @@ impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S> /// handshake because doing so requires several round trips; the handshake will /// be completed in subsequent read/write calls managed by your event loop. pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_socket() as c_int; - let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); - let ret = ssl.ssl.accept(); - if ret > 0 { - Ok(ssl) - } else { - // WantRead/WantWrite is okay here; we'll finish the handshake in - // subsequent send/recv calls. - match ssl.make_error(ret) { - NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), - NonblockingSslError::SslError(other) => Err(other), + SslStream::accept(ssl, stream).map(NonblockingSslStream) + } + + fn convert_err(&self, err: Error) -> NonblockingSslError { + match err { + Error::ZeroReturn => SslError::SslSessionClosed.into(), + Error::WantRead(_) => NonblockingSslError::WantRead, + Error::WantWrite(_) => NonblockingSslError::WantWrite, + Error::WantX509Lookup => unreachable!(), + Error::Stream(e) => SslError::StreamError(e).into(), + Error::Ssl(e) => { + SslError::OpenSslErrors(e.iter() + .map(|e| OpensslError::from_error_code(e.error_code())) + .collect()) + .into() } } } -} -impl<S: Read+Write> NonblockingSslStream<S> { /// Read bytes from the SSL stream into `buf`. /// /// Given the SSL state machine, this method may return either `WantWrite` @@ -1622,11 +1474,10 @@ impl<S: Read+Write> NonblockingSslStream<S> { /// On a return value of `Ok(count)`, count is the number of decrypted /// plaintext bytes copied into the `buf` slice. pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> { - let ret = self.ssl.read(buf); - if ret >= 0 { - Ok(ret as usize) - } else { - Err(self.make_error(ret)) + match self.0.ssl_read(buf) { + Ok(n) => Ok(n), + Err(Error::ZeroReturn) => Ok(0), + Err(e) => Err(self.convert_err(e)) } } @@ -1646,11 +1497,6 @@ impl<S: Read+Write> NonblockingSslStream<S> { /// Given a return value of `Ok(count)`, count is the number of plaintext bytes /// from the `buf` slice that were encrypted and written onto the stream. pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> { - let ret = self.ssl.write(buf); - if ret > 0 { - Ok(ret as usize) - } else { - Err(self.make_error(ret)) - } + self.0.ssl_write(buf).map_err(|e| self.convert_err(e)) } } |