aboutsummaryrefslogtreecommitdiff
path: root/openssl/src
diff options
context:
space:
mode:
authorJamie Turner <[email protected]>2015-09-19 20:50:06 -0700
committerJamie Turner <[email protected]>2015-10-20 23:14:26 -0700
commitc37767df8fc1775858cd573cbe4d5e3a17fbd370 (patch)
treea80fcebc4891737bc833a51b4cf5034dcb63ab6f /openssl/src
parentMerge pull request #290 from jimmycuadra/master (diff)
downloadrust-openssl-c37767df8fc1775858cd573cbe4d5e3a17fbd370.tar.xz
rust-openssl-c37767df8fc1775858cd573cbe4d5e3a17fbd370.zip
Nonblocking streams support.
Diffstat (limited to 'openssl/src')
-rw-r--r--openssl/src/ssl/error.rs44
-rw-r--r--openssl/src/ssl/mod.rs231
-rw-r--r--openssl/src/ssl/tests.rs132
3 files changed, 405 insertions, 2 deletions
diff --git a/openssl/src/ssl/error.rs b/openssl/src/ssl/error.rs
index 9ff6cae9..0126b277 100644
--- a/openssl/src/ssl/error.rs
+++ b/openssl/src/ssl/error.rs
@@ -17,7 +17,20 @@ pub enum SslError {
/// The SSL session has been closed by the other end
SslSessionClosed,
/// An error in the OpenSSL library
- OpenSslErrors(Vec<OpensslError>)
+ OpenSslErrors(Vec<OpensslError>),
+}
+
+/// An error on a nonblocking stream.
+#[derive(Debug)]
+pub enum NonblockingSslError {
+ /// A standard SSL error occurred.
+ SslError(SslError),
+ /// The OpenSSL library wants data from the remote socket;
+ /// the caller should wait for read readiness.
+ WantRead,
+ /// The OpenSSL library wants to send data to the remote socket;
+ /// the caller should wait for write readiness.
+ WantWrite,
}
impl fmt::Display for SslError {
@@ -59,6 +72,35 @@ impl error::Error for SslError {
}
}
+impl fmt::Display for NonblockingSslError {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ fmt.write_str(error::Error::description(self))
+ }
+}
+
+impl error::Error for NonblockingSslError {
+ fn description(&self) -> &str {
+ match *self {
+ NonblockingSslError::SslError(ref e) => e.description(),
+ NonblockingSslError::WantRead => "The OpenSSL library wants data from the remote socket",
+ NonblockingSslError::WantWrite => "The OpenSSL library want to send data to the remote socket",
+ }
+ }
+
+ fn cause(&self) -> Option<&error::Error> {
+ match *self {
+ NonblockingSslError::SslError(ref e) => e.cause(),
+ _ => None
+ }
+ }
+}
+
+impl From<SslError> for NonblockingSslError {
+ fn from(e: SslError) -> NonblockingSslError {
+ NonblockingSslError::SslError(e)
+ }
+}
+
/// An error from the OpenSSL library
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OpensslError {
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs
index e76529a5..62080056 100644
--- a/openssl/src/ssl/mod.rs
+++ b/openssl/src/ssl/mod.rs
@@ -22,7 +22,7 @@ use std::slice;
use bio::{MemBio};
use ffi;
use dh::DH;
-use ssl::error::{SslError, SslSessionClosed, StreamError, OpenSslErrors};
+use ssl::error::{NonblockingSslError, SslError, SslSessionClosed, StreamError, OpenSslErrors};
use x509::{X509StoreContext, X509FileType, X509};
use crypto::pkey::PKey;
@@ -1465,3 +1465,232 @@ impl<S> MaybeSslStream<S> where S: Read+Write {
}
}
}
+
+/// An SSL stream wrapping a nonblocking socket.
+#[derive(Clone)]
+pub struct NonblockingSslStream<S> {
+ stream: S,
+ ssl: Arc<Ssl>,
+}
+
+impl NonblockingSslStream<net::TcpStream> {
+ pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> {
+ Ok(NonblockingSslStream {
+ stream: try!(self.stream.try_clone()),
+ ssl: self.ssl.clone(),
+ })
+ }
+}
+
+impl<S> NonblockingSslStream<S> {
+ fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<NonblockingSslStream<S>, SslError> {
+ unsafe {
+ let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0));
+ ffi::BIO_set_nbio(bio, 1);
+ ffi::SSL_set_bio(ssl.ssl, bio, bio);
+ }
+
+ Ok(NonblockingSslStream {
+ stream: stream,
+ ssl: Arc::new(ssl),
+ })
+ }
+
+ fn make_error(&self, ret: c_int) -> NonblockingSslError {
+ match self.ssl.get_error(ret) {
+ LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()),
+ LibSslError::ErrorSyscall => {
+ let err = SslError::get();
+ let count = match err {
+ SslError::OpenSslErrors(ref v) => v.len(),
+ _ => unreachable!(),
+ };
+ let ssl_error = if count == 0 {
+ if ret == 0 {
+ SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
+ "unexpected EOF observed"))
+ } else {
+ SslError::StreamError(io::Error::last_os_error())
+ }
+ } else {
+ err
+ };
+ ssl_error.into()
+ },
+ LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite,
+ LibSslError::ErrorWantRead => NonblockingSslError::WantRead,
+ err => panic!("unexpected error {:?} with ret {}", err, ret),
+ }
+ }
+
+ /// Returns a reference to the underlying stream.
+ pub fn get_ref(&self) -> &S {
+ &self.stream
+ }
+
+ /// Returns a mutable reference to the underlying stream.
+ ///
+ /// ## Warning
+ ///
+ /// It is inadvisable to read from or write to the underlying stream as it
+ /// will most likely corrupt the SSL session.
+ pub fn get_mut(&mut self) -> &mut S {
+ &mut self.stream
+ }
+
+ /// Returns a reference to the Ssl.
+ pub fn ssl(&self) -> &Ssl {
+ &self.ssl
+ }
+}
+
+#[cfg(unix)]
+impl<S: Read+Write+::std::os::unix::io::AsRawFd> NonblockingSslStream<S> {
+ /// Create a new nonblocking client ssl connection on wrapped `stream`.
+ ///
+ /// Note that this method will most likely not actually complete the SSL
+ /// handshake because doing so requires several round trips; the handshake will
+ /// be completed in subsequent read/write calls managed by your event loop.
+ pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
+ let ssl = try!(ssl.into_ssl());
+ let fd = stream.as_raw_fd() as c_int;
+ let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
+ let ret = ssl.ssl.connect();
+ if ret > 0 {
+ Ok(ssl)
+ } else {
+ // WantRead/WantWrite is okay here; we'll finish the handshake in
+ // subsequent send/recv calls.
+ match ssl.make_error(ret) {
+ NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
+ NonblockingSslError::SslError(other) => Err(other),
+ }
+ }
+ }
+
+ /// Create a new nonblocking server ssl connection on wrapped `stream`.
+ ///
+ /// Note that this method will most likely not actually complete the SSL
+ /// handshake because doing so requires several round trips; the handshake will
+ /// be completed in subsequent read/write calls managed by your event loop.
+ pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
+ let ssl = try!(ssl.into_ssl());
+ let fd = stream.as_raw_fd() as c_int;
+ let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
+ let ret = ssl.ssl.accept();
+ if ret > 0 {
+ Ok(ssl)
+ } else {
+ // WantRead/WantWrite is okay here; we'll finish the handshake in
+ // subsequent send/recv calls.
+ match ssl.make_error(ret) {
+ NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
+ NonblockingSslError::SslError(other) => Err(other),
+ }
+ }
+ }
+}
+
+#[cfg(unix)]
+impl<S: ::std::os::unix::io::AsRawFd> ::std::os::unix::io::AsRawFd for NonblockingSslStream<S> {
+ fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd {
+ self.stream.as_raw_fd()
+ }
+}
+
+#[cfg(windows)]
+impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S> {
+ /// Create a new nonblocking client ssl connection on wrapped `stream`.
+ ///
+ /// Note that this method will most likely not actually complete the SSL
+ /// handshake because doing so requires several round trips; the handshake will
+ /// be completed in subsequent read/write calls managed by your event loop.
+ pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
+ let ssl = try!(ssl.into_ssl());
+ let fd = stream.as_raw_socket() as c_int;
+ let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
+ let ret = ssl.ssl.connect();
+ if ret > 0 {
+ Ok(ssl)
+ } else {
+ // WantRead/WantWrite is okay here; we'll finish the handshake in
+ // subsequent send/recv calls.
+ match ssl.make_error(ret) {
+ NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
+ NonblockingSslError::SslError(other) => Err(other),
+ }
+ }
+ }
+
+ /// Create a new nonblocking server ssl connection on wrapped `stream`.
+ ///
+ /// Note that this method will most likely not actually complete the SSL
+ /// handshake because doing so requires several round trips; the handshake will
+ /// be completed in subsequent read/write calls managed by your event loop.
+ pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
+ let ssl = try!(ssl.into_ssl());
+ let fd = stream.as_raw_socket() as c_int;
+ let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
+ let ret = ssl.ssl.accept();
+ if ret > 0 {
+ Ok(ssl)
+ } else {
+ // WantRead/WantWrite is okay here; we'll finish the handshake in
+ // subsequent send/recv calls.
+ match ssl.make_error(ret) {
+ NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
+ NonblockingSslError::SslError(other) => Err(other),
+ }
+ }
+ }
+}
+
+impl<S: Read+Write> NonblockingSslStream<S> {
+ /// Read bytes from the SSL stream into `buf`.
+ ///
+ /// Given the SSL state machine, this method may return either `WantWrite`
+ /// or `WantRead` to indicate that your event loop should respectively wait
+ /// for write or read readiness on the underlying stream. Upon readiness,
+ /// repeat your `read()` call with the same arguments each time until you
+ /// receive an `Ok(count)`.
+ ///
+ /// An `SslError` return value, is terminal; do not re-attempt your read.
+ ///
+ /// As expected of a nonblocking API, this method will never block your
+ /// thread on I/O.
+ ///
+ /// On a return value of `Ok(count)`, count is the number of decrypted
+ /// plaintext bytes copied into the `buf` slice.
+ pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> {
+ let ret = self.ssl.read(buf);
+ if ret >= 0 {
+ Ok(ret as usize)
+ } else {
+ Err(self.make_error(ret))
+ }
+ }
+
+ /// Write bytes from `buf` to the SSL stream.
+ ///
+ /// Given the SSL state machine, this method may return either `WantWrite`
+ /// or `WantRead` to indicate that your event loop should respectively wait
+ /// for write or read readiness on the underlying stream. Upon readiness,
+ /// repeat your `write()` call with the same arguments each time until you
+ /// receive an `Ok(count)`.
+ ///
+ /// An `SslError` return value, is terminal; do not re-attempt your write.
+ ///
+ /// As expected of a nonblocking API, this method will never block your
+ /// thread on I/O.
+ ///
+ /// Given a return value of `Ok(count)`, count is the number of plaintext bytes
+ /// from the `buf` slice that were encrypted and written onto the stream.
+ pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> {
+ let ret = self.ssl.write(buf);
+ if ret > 0 {
+ Ok(ret as usize)
+ } else {
+ Err(self.make_error(ret))
+ }
+ }
+}
diff --git a/openssl/src/ssl/tests.rs b/openssl/src/ssl/tests.rs
index 033a3b86..8335bc53 100644
--- a/openssl/src/ssl/tests.rs
+++ b/openssl/src/ssl/tests.rs
@@ -819,3 +819,135 @@ fn test_sslv2_connect_failure() {
let (_s, tcp) = Server::new_tcp(&["-no_ssl2", "-www"]);
SslStream::connect_generic(&SslContext::new(Sslv2).unwrap(), tcp).err().unwrap();
}
+
+#[cfg(target_os = "linux")]
+mod nonblocking_tests {
+ extern crate nix;
+
+ use std::io::Write;
+ use std::net::TcpStream;
+ use std::os::unix::io::AsRawFd;
+
+ use super::Server;
+ use self::nix::sys::epoll;
+ use self::nix::fcntl;
+ use ssl;
+ use ssl::error::NonblockingSslError;
+ use ssl::SslMethod;
+ use ssl::SslMethod::Sslv23;
+ use ssl::{SslContext, NonblockingSslStream};
+
+ fn wait_io(stream: &NonblockingSslStream<TcpStream>, read: bool, timeout_ms: isize) -> bool {
+ let fd = stream.as_raw_fd();
+ let ep = epoll::epoll_create().unwrap();
+ let event = if read {
+ epoll::EpollEvent {
+ events: epoll::EPOLLIN | epoll::EPOLLERR,
+ data: 0,
+ }
+ } else {
+ epoll::EpollEvent {
+ events: epoll::EPOLLOUT,
+ data: 0,
+ }
+ };
+ epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlAdd, fd, &event).unwrap();
+ let mut events = [event];
+ let count = epoll::epoll_wait(ep, &mut events, timeout_ms).unwrap();
+ epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlDel, fd, &event).unwrap();
+ assert!(count <= 1);
+ count == 1
+ }
+
+ fn make_nonblocking(stream: &TcpStream) {
+ let fd = stream.as_raw_fd();
+ fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(fcntl::O_NONBLOCK)).unwrap();
+ }
+
+ #[test]
+ fn test_write_nonblocking() {
+ let (_s, stream) = Server::new();
+ make_nonblocking(&stream);
+ let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
+
+ let mut iterations = 0;
+ loop {
+ iterations += 1;
+ if iterations > 7 {
+ // Probably a safe assumption for the foreseeable future of openssl.
+ panic!("Too many read/write round trips in handshake!!");
+ }
+ let result = stream.write("hello".as_bytes());
+ match result {
+ Ok(_) => {
+ break;
+ },
+ Err(NonblockingSslError::WantRead) => {
+ assert!(wait_io(&stream, true, 1000));
+ },
+ Err(NonblockingSslError::WantWrite) => {
+ assert!(wait_io(&stream, false, 1000));
+ },
+ Err(other) => {
+ panic!("Unexpected SSL Error: {:?}", other);
+ },
+ }
+ }
+
+ // Second write should succeed immediately--plenty of space in kernel buffer,
+ // and handshake just completed.
+ stream.write(" there".as_bytes()).unwrap();
+ }
+
+ #[test]
+ fn test_read_nonblocking() {
+ let (_s, stream) = Server::new();
+ make_nonblocking(&stream);
+ let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
+
+ let mut iterations = 0;
+ loop {
+ iterations += 1;
+ if iterations > 7 {
+ // Probably a safe assumption for the foreseeable future of openssl.
+ panic!("Too many read/write round trips in handshake!!");
+ }
+ let result = stream.write("GET /\r\n\r\n".as_bytes());
+ match result {
+ Ok(n) => {
+ assert_eq!(n, 9);
+ break;
+ },
+ Err(NonblockingSslError::WantRead) => {
+ assert!(wait_io(&stream, true, 1000));
+ },
+ Err(NonblockingSslError::WantWrite) => {
+ assert!(wait_io(&stream, false, 1000));
+ },
+ Err(other) => {
+ panic!("Unexpected SSL Error: {:?}", other);
+ },
+ }
+ }
+ let mut input_buffer = [0u8; 1500];
+ let result = stream.read(&mut input_buffer);
+ let bytes_read = match result {
+ Ok(n) => {
+ // This branch is unlikely, but on an overloaded VM with
+ // unlucky context switching, the response could actually
+ // be in the receive buffer before we issue the read() syscall...
+ n
+ },
+ Err(NonblockingSslError::WantRead) => {
+ assert!(wait_io(&stream, true, 3000));
+ // Second read should return application data.
+ stream.read(&mut input_buffer).unwrap()
+ },
+ Err(other) => {
+ panic!("Unexpected SSL Error: {:?}", other);
+ },
+ };
+ assert!(bytes_read >= 5);
+ assert_eq!(&input_buffer[..5], b"HTTP/");
+ }
+}