diff options
| author | Jamie Turner <[email protected]> | 2015-09-19 20:50:06 -0700 |
|---|---|---|
| committer | Jamie Turner <[email protected]> | 2015-10-20 23:14:26 -0700 |
| commit | c37767df8fc1775858cd573cbe4d5e3a17fbd370 (patch) | |
| tree | a80fcebc4891737bc833a51b4cf5034dcb63ab6f /openssl/src/ssl/mod.rs | |
| parent | Merge pull request #290 from jimmycuadra/master (diff) | |
| download | rust-openssl-c37767df8fc1775858cd573cbe4d5e3a17fbd370.tar.xz rust-openssl-c37767df8fc1775858cd573cbe4d5e3a17fbd370.zip | |
Nonblocking streams support.
Diffstat (limited to 'openssl/src/ssl/mod.rs')
| -rw-r--r-- | openssl/src/ssl/mod.rs | 231 |
1 files changed, 230 insertions, 1 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index e76529a5..62080056 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -22,7 +22,7 @@ use std::slice; use bio::{MemBio}; use ffi; use dh::DH; -use ssl::error::{SslError, SslSessionClosed, StreamError, OpenSslErrors}; +use ssl::error::{NonblockingSslError, SslError, SslSessionClosed, StreamError, OpenSslErrors}; use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; @@ -1465,3 +1465,232 @@ impl<S> MaybeSslStream<S> where S: Read+Write { } } } + +/// An SSL stream wrapping a nonblocking socket. +#[derive(Clone)] +pub struct NonblockingSslStream<S> { + stream: S, + ssl: Arc<Ssl>, +} + +impl NonblockingSslStream<net::TcpStream> { + pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> { + Ok(NonblockingSslStream { + stream: try!(self.stream.try_clone()), + ssl: self.ssl.clone(), + }) + } +} + +impl<S> NonblockingSslStream<S> { + fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<NonblockingSslStream<S>, SslError> { + unsafe { + let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0)); + ffi::BIO_set_nbio(bio, 1); + ffi::SSL_set_bio(ssl.ssl, bio, bio); + } + + Ok(NonblockingSslStream { + stream: stream, + ssl: Arc::new(ssl), + }) + } + + fn make_error(&self, ret: c_int) -> NonblockingSslError { + match self.ssl.get_error(ret) { + LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()), + LibSslError::ErrorSyscall => { + let err = SslError::get(); + let count = match err { + SslError::OpenSslErrors(ref v) => v.len(), + _ => unreachable!(), + }; + let ssl_error = if count == 0 { + if ret == 0 { + SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) + } else { + SslError::StreamError(io::Error::last_os_error()) + } + } else { + err + }; + ssl_error.into() + }, + LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite, + LibSslError::ErrorWantRead => NonblockingSslError::WantRead, + err => panic!("unexpected error {:?} with ret {}", err, ret), + } + } + + /// Returns a reference to the underlying stream. + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Returns a mutable reference to the underlying stream. + /// + /// ## Warning + /// + /// It is inadvisable to read from or write to the underlying stream as it + /// will most likely corrupt the SSL session. + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } + + /// Returns a reference to the Ssl. + pub fn ssl(&self) -> &Ssl { + &self.ssl + } +} + +#[cfg(unix)] +impl<S: Read+Write+::std::os::unix::io::AsRawFd> NonblockingSslStream<S> { + /// Create a new nonblocking client ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.connect(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } + + /// Create a new nonblocking server ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.accept(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } +} + +#[cfg(unix)] +impl<S: ::std::os::unix::io::AsRawFd> ::std::os::unix::io::AsRawFd for NonblockingSslStream<S> { + fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd { + self.stream.as_raw_fd() + } +} + +#[cfg(windows)] +impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S> { + /// Create a new nonblocking client ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_socket() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.connect(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } + + /// Create a new nonblocking server ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_socket() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.accept(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } +} + +impl<S: Read+Write> NonblockingSslStream<S> { + /// Read bytes from the SSL stream into `buf`. + /// + /// Given the SSL state machine, this method may return either `WantWrite` + /// or `WantRead` to indicate that your event loop should respectively wait + /// for write or read readiness on the underlying stream. Upon readiness, + /// repeat your `read()` call with the same arguments each time until you + /// receive an `Ok(count)`. + /// + /// An `SslError` return value, is terminal; do not re-attempt your read. + /// + /// As expected of a nonblocking API, this method will never block your + /// thread on I/O. + /// + /// On a return value of `Ok(count)`, count is the number of decrypted + /// plaintext bytes copied into the `buf` slice. + pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> { + let ret = self.ssl.read(buf); + if ret >= 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } + + /// Write bytes from `buf` to the SSL stream. + /// + /// Given the SSL state machine, this method may return either `WantWrite` + /// or `WantRead` to indicate that your event loop should respectively wait + /// for write or read readiness on the underlying stream. Upon readiness, + /// repeat your `write()` call with the same arguments each time until you + /// receive an `Ok(count)`. + /// + /// An `SslError` return value, is terminal; do not re-attempt your write. + /// + /// As expected of a nonblocking API, this method will never block your + /// thread on I/O. + /// + /// Given a return value of `Ok(count)`, count is the number of plaintext bytes + /// from the `buf` slice that were encrypted and written onto the stream. + pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> { + let ret = self.ssl.write(buf); + if ret > 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } +} |