diff options
| author | Jamie Turner <[email protected]> | 2015-09-19 20:50:06 -0700 |
|---|---|---|
| committer | Jamie Turner <[email protected]> | 2015-10-20 23:14:26 -0700 |
| commit | c37767df8fc1775858cd573cbe4d5e3a17fbd370 (patch) | |
| tree | a80fcebc4891737bc833a51b4cf5034dcb63ab6f /openssl/src | |
| parent | Merge pull request #290 from jimmycuadra/master (diff) | |
| download | rust-openssl-c37767df8fc1775858cd573cbe4d5e3a17fbd370.tar.xz rust-openssl-c37767df8fc1775858cd573cbe4d5e3a17fbd370.zip | |
Nonblocking streams support.
Diffstat (limited to 'openssl/src')
| -rw-r--r-- | openssl/src/ssl/error.rs | 44 | ||||
| -rw-r--r-- | openssl/src/ssl/mod.rs | 231 | ||||
| -rw-r--r-- | openssl/src/ssl/tests.rs | 132 |
3 files changed, 405 insertions, 2 deletions
diff --git a/openssl/src/ssl/error.rs b/openssl/src/ssl/error.rs index 9ff6cae9..0126b277 100644 --- a/openssl/src/ssl/error.rs +++ b/openssl/src/ssl/error.rs @@ -17,7 +17,20 @@ pub enum SslError { /// The SSL session has been closed by the other end SslSessionClosed, /// An error in the OpenSSL library - OpenSslErrors(Vec<OpensslError>) + OpenSslErrors(Vec<OpensslError>), +} + +/// An error on a nonblocking stream. +#[derive(Debug)] +pub enum NonblockingSslError { + /// A standard SSL error occurred. + SslError(SslError), + /// The OpenSSL library wants data from the remote socket; + /// the caller should wait for read readiness. + WantRead, + /// The OpenSSL library wants to send data to the remote socket; + /// the caller should wait for write readiness. + WantWrite, } impl fmt::Display for SslError { @@ -59,6 +72,35 @@ impl error::Error for SslError { } } +impl fmt::Display for NonblockingSslError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str(error::Error::description(self)) + } +} + +impl error::Error for NonblockingSslError { + fn description(&self) -> &str { + match *self { + NonblockingSslError::SslError(ref e) => e.description(), + NonblockingSslError::WantRead => "The OpenSSL library wants data from the remote socket", + NonblockingSslError::WantWrite => "The OpenSSL library want to send data to the remote socket", + } + } + + fn cause(&self) -> Option<&error::Error> { + match *self { + NonblockingSslError::SslError(ref e) => e.cause(), + _ => None + } + } +} + +impl From<SslError> for NonblockingSslError { + fn from(e: SslError) -> NonblockingSslError { + NonblockingSslError::SslError(e) + } +} + /// An error from the OpenSSL library #[derive(Debug, Clone, PartialEq, Eq)] pub enum OpensslError { diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index e76529a5..62080056 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -22,7 +22,7 @@ use std::slice; use bio::{MemBio}; use ffi; use dh::DH; -use ssl::error::{SslError, SslSessionClosed, StreamError, OpenSslErrors}; +use ssl::error::{NonblockingSslError, SslError, SslSessionClosed, StreamError, OpenSslErrors}; use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; @@ -1465,3 +1465,232 @@ impl<S> MaybeSslStream<S> where S: Read+Write { } } } + +/// An SSL stream wrapping a nonblocking socket. +#[derive(Clone)] +pub struct NonblockingSslStream<S> { + stream: S, + ssl: Arc<Ssl>, +} + +impl NonblockingSslStream<net::TcpStream> { + pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> { + Ok(NonblockingSslStream { + stream: try!(self.stream.try_clone()), + ssl: self.ssl.clone(), + }) + } +} + +impl<S> NonblockingSslStream<S> { + fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<NonblockingSslStream<S>, SslError> { + unsafe { + let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0)); + ffi::BIO_set_nbio(bio, 1); + ffi::SSL_set_bio(ssl.ssl, bio, bio); + } + + Ok(NonblockingSslStream { + stream: stream, + ssl: Arc::new(ssl), + }) + } + + fn make_error(&self, ret: c_int) -> NonblockingSslError { + match self.ssl.get_error(ret) { + LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()), + LibSslError::ErrorSyscall => { + let err = SslError::get(); + let count = match err { + SslError::OpenSslErrors(ref v) => v.len(), + _ => unreachable!(), + }; + let ssl_error = if count == 0 { + if ret == 0 { + SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) + } else { + SslError::StreamError(io::Error::last_os_error()) + } + } else { + err + }; + ssl_error.into() + }, + LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite, + LibSslError::ErrorWantRead => NonblockingSslError::WantRead, + err => panic!("unexpected error {:?} with ret {}", err, ret), + } + } + + /// Returns a reference to the underlying stream. + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Returns a mutable reference to the underlying stream. + /// + /// ## Warning + /// + /// It is inadvisable to read from or write to the underlying stream as it + /// will most likely corrupt the SSL session. + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } + + /// Returns a reference to the Ssl. + pub fn ssl(&self) -> &Ssl { + &self.ssl + } +} + +#[cfg(unix)] +impl<S: Read+Write+::std::os::unix::io::AsRawFd> NonblockingSslStream<S> { + /// Create a new nonblocking client ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.connect(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } + + /// Create a new nonblocking server ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.accept(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } +} + +#[cfg(unix)] +impl<S: ::std::os::unix::io::AsRawFd> ::std::os::unix::io::AsRawFd for NonblockingSslStream<S> { + fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd { + self.stream.as_raw_fd() + } +} + +#[cfg(windows)] +impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S> { + /// Create a new nonblocking client ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_socket() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.connect(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } + + /// Create a new nonblocking server ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_socket() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.accept(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } +} + +impl<S: Read+Write> NonblockingSslStream<S> { + /// Read bytes from the SSL stream into `buf`. + /// + /// Given the SSL state machine, this method may return either `WantWrite` + /// or `WantRead` to indicate that your event loop should respectively wait + /// for write or read readiness on the underlying stream. Upon readiness, + /// repeat your `read()` call with the same arguments each time until you + /// receive an `Ok(count)`. + /// + /// An `SslError` return value, is terminal; do not re-attempt your read. + /// + /// As expected of a nonblocking API, this method will never block your + /// thread on I/O. + /// + /// On a return value of `Ok(count)`, count is the number of decrypted + /// plaintext bytes copied into the `buf` slice. + pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> { + let ret = self.ssl.read(buf); + if ret >= 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } + + /// Write bytes from `buf` to the SSL stream. + /// + /// Given the SSL state machine, this method may return either `WantWrite` + /// or `WantRead` to indicate that your event loop should respectively wait + /// for write or read readiness on the underlying stream. Upon readiness, + /// repeat your `write()` call with the same arguments each time until you + /// receive an `Ok(count)`. + /// + /// An `SslError` return value, is terminal; do not re-attempt your write. + /// + /// As expected of a nonblocking API, this method will never block your + /// thread on I/O. + /// + /// Given a return value of `Ok(count)`, count is the number of plaintext bytes + /// from the `buf` slice that were encrypted and written onto the stream. + pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> { + let ret = self.ssl.write(buf); + if ret > 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } +} diff --git a/openssl/src/ssl/tests.rs b/openssl/src/ssl/tests.rs index 033a3b86..8335bc53 100644 --- a/openssl/src/ssl/tests.rs +++ b/openssl/src/ssl/tests.rs @@ -819,3 +819,135 @@ fn test_sslv2_connect_failure() { let (_s, tcp) = Server::new_tcp(&["-no_ssl2", "-www"]); SslStream::connect_generic(&SslContext::new(Sslv2).unwrap(), tcp).err().unwrap(); } + +#[cfg(target_os = "linux")] +mod nonblocking_tests { + extern crate nix; + + use std::io::Write; + use std::net::TcpStream; + use std::os::unix::io::AsRawFd; + + use super::Server; + use self::nix::sys::epoll; + use self::nix::fcntl; + use ssl; + use ssl::error::NonblockingSslError; + use ssl::SslMethod; + use ssl::SslMethod::Sslv23; + use ssl::{SslContext, NonblockingSslStream}; + + fn wait_io(stream: &NonblockingSslStream<TcpStream>, read: bool, timeout_ms: isize) -> bool { + let fd = stream.as_raw_fd(); + let ep = epoll::epoll_create().unwrap(); + let event = if read { + epoll::EpollEvent { + events: epoll::EPOLLIN | epoll::EPOLLERR, + data: 0, + } + } else { + epoll::EpollEvent { + events: epoll::EPOLLOUT, + data: 0, + } + }; + epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlAdd, fd, &event).unwrap(); + let mut events = [event]; + let count = epoll::epoll_wait(ep, &mut events, timeout_ms).unwrap(); + epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlDel, fd, &event).unwrap(); + assert!(count <= 1); + count == 1 + } + + fn make_nonblocking(stream: &TcpStream) { + let fd = stream.as_raw_fd(); + fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(fcntl::O_NONBLOCK)).unwrap(); + } + + #[test] + fn test_write_nonblocking() { + let (_s, stream) = Server::new(); + make_nonblocking(&stream); + let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); + + let mut iterations = 0; + loop { + iterations += 1; + if iterations > 7 { + // Probably a safe assumption for the foreseeable future of openssl. + panic!("Too many read/write round trips in handshake!!"); + } + let result = stream.write("hello".as_bytes()); + match result { + Ok(_) => { + break; + }, + Err(NonblockingSslError::WantRead) => { + assert!(wait_io(&stream, true, 1000)); + }, + Err(NonblockingSslError::WantWrite) => { + assert!(wait_io(&stream, false, 1000)); + }, + Err(other) => { + panic!("Unexpected SSL Error: {:?}", other); + }, + } + } + + // Second write should succeed immediately--plenty of space in kernel buffer, + // and handshake just completed. + stream.write(" there".as_bytes()).unwrap(); + } + + #[test] + fn test_read_nonblocking() { + let (_s, stream) = Server::new(); + make_nonblocking(&stream); + let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); + + let mut iterations = 0; + loop { + iterations += 1; + if iterations > 7 { + // Probably a safe assumption for the foreseeable future of openssl. + panic!("Too many read/write round trips in handshake!!"); + } + let result = stream.write("GET /\r\n\r\n".as_bytes()); + match result { + Ok(n) => { + assert_eq!(n, 9); + break; + }, + Err(NonblockingSslError::WantRead) => { + assert!(wait_io(&stream, true, 1000)); + }, + Err(NonblockingSslError::WantWrite) => { + assert!(wait_io(&stream, false, 1000)); + }, + Err(other) => { + panic!("Unexpected SSL Error: {:?}", other); + }, + } + } + let mut input_buffer = [0u8; 1500]; + let result = stream.read(&mut input_buffer); + let bytes_read = match result { + Ok(n) => { + // This branch is unlikely, but on an overloaded VM with + // unlucky context switching, the response could actually + // be in the receive buffer before we issue the read() syscall... + n + }, + Err(NonblockingSslError::WantRead) => { + assert!(wait_io(&stream, true, 3000)); + // Second read should return application data. + stream.read(&mut input_buffer).unwrap() + }, + Err(other) => { + panic!("Unexpected SSL Error: {:?}", other); + }, + }; + assert!(bytes_read >= 5); + assert_eq!(&input_buffer[..5], b"HTTP/"); + } +} |