diff options
| author | Alex Crichton <[email protected]> | 2016-07-27 14:18:02 -0700 |
|---|---|---|
| committer | Alex Crichton <[email protected]> | 2016-07-31 16:01:06 -0700 |
| commit | 3539be33660dc136e0b585400f70447e31ccb62a (patch) | |
| tree | 020de98557d936d6ab06c58291eb270c55e5c578 /openssl/src | |
| parent | Merge remote-tracking branch 'origin/master' into breaks (diff) | |
| download | rust-openssl-3539be33660dc136e0b585400f70447e31ccb62a.tar.xz rust-openssl-3539be33660dc136e0b585400f70447e31ccb62a.zip | |
Add MidHandshakeSslStream
Allows recognizing when a stream is still in handshake mode and can gracefully
transition when ready. The blocking usage of the API should still be the same,
just helps nonblocking implementations!
Diffstat (limited to 'openssl/src')
| -rw-r--r-- | openssl/src/ssl/mod.rs | 120 | ||||
| -rw-r--r-- | openssl/src/ssl/tests/mod.rs | 35 |
2 files changed, 136 insertions, 19 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index b18df7b3..9215a928 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -5,6 +5,7 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; +use std::error as stderror; use std::mem; use std::str; use std::path::Path; @@ -832,6 +833,10 @@ impl Ssl { unsafe { ffi::SSL_accept(self.ssl) } } + fn handshake(&self) -> c_int { + unsafe { ffi::SSL_do_handshake(self.ssl) } + } + fn read(&self, buf: &mut [u8]) -> c_int { let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; unsafe { ffi::SSL_read(self.ssl, buf.as_ptr() as *mut c_void, len) } @@ -1081,31 +1086,49 @@ impl<S: Read + Write> SslStream<S> { } /// Creates an SSL/TLS client operating over the provided stream. - pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { - let ssl = try!(ssl.into_ssl()); + pub fn connect<T: IntoSsl>(ssl: T, stream: S) + -> Result<Self, HandshakeError<S>>{ + let ssl = try!(ssl.into_ssl().map_err(|e| { + HandshakeError::Failure(Error::Ssl(e)) + })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.connect(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { - Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), - err => Err(err) + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + Err(HandshakeError::Interrupted(MidHandshakeSslStream { + stream: stream, + error: e, + })) + } + err => Err(HandshakeError::Failure(err)), } } } /// Creates an SSL/TLS server operating over the provided stream. - pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { - let ssl = try!(ssl.into_ssl()); + pub fn accept<T: IntoSsl>(ssl: T, stream: S) + -> Result<Self, HandshakeError<S>> { + let ssl = try!(ssl.into_ssl().map_err(|e| { + HandshakeError::Failure(Error::Ssl(e)) + })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.accept(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { - Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), - err => Err(err) + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + Err(HandshakeError::Interrupted(MidHandshakeSslStream { + stream: stream, + error: e, + })) + } + err => Err(HandshakeError::Failure(err)), } } } @@ -1137,6 +1160,87 @@ impl<S: Read + Write> SslStream<S> { } } +/// An error or intermediate state after a TLS handshake attempt. +#[derive(Debug)] +pub enum HandshakeError<S> { + /// The handshake failed. + Failure(Error), + /// The handshake was interrupted midway through. + Interrupted(MidHandshakeSslStream<S>), +} + +impl<S: Any + fmt::Debug> stderror::Error for HandshakeError<S> { + fn description(&self) -> &str { + match *self { + HandshakeError::Failure(ref e) => e.description(), + HandshakeError::Interrupted(ref e) => e.error.description(), + } + } + + fn cause(&self) -> Option<&stderror::Error> { + match *self { + HandshakeError::Failure(ref e) => Some(e), + HandshakeError::Interrupted(ref e) => Some(&e.error), + } + } +} + +impl<S: Any + fmt::Debug> fmt::Display for HandshakeError<S> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + try!(f.write_str(stderror::Error::description(self))); + if let Some(e) = stderror::Error::cause(self) { + try!(write!(f, ": {}", e)); + } + Ok(()) + } +} + +/// An SSL stream midway through the handshake process. +#[derive(Debug)] +pub struct MidHandshakeSslStream<S> { + stream: SslStream<S>, + error: Error, +} + +impl<S> MidHandshakeSslStream<S> { + /// Returns a shared reference to the inner stream. + pub fn get_ref(&self) -> &S { + self.stream.get_ref() + } + + /// Returns a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut S { + self.stream.get_mut() + } + + /// Returns a shared reference to the `SslContext` of the stream. + pub fn ssl(&self) -> &Ssl { + self.stream.ssl() + } + + /// Returns the underlying error which interrupted this handshake. + pub fn error(&self) -> &Error { + &self.error + } + + /// Restarts the handshake process. + pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> { + let ret = self.stream.ssl.handshake(); + if ret > 0 { + Ok(self.stream) + } else { + match self.stream.make_error(ret) { + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + self.error = e; + Err(HandshakeError::Interrupted(self)) + } + err => Err(HandshakeError::Failure(err)), + } + } + } +} + impl<S> SslStream<S> { fn make_error(&mut self, ret: c_int) -> Error { self.check_panic(); diff --git a/openssl/src/ssl/tests/mod.rs b/openssl/src/ssl/tests/mod.rs index 1fc0076b..0b638546 100644 --- a/openssl/src/ssl/tests/mod.rs +++ b/openssl/src/ssl/tests/mod.rs @@ -17,7 +17,7 @@ use crypto::hash::Type::SHA256; use ssl; use ssl::SSL_VERIFY_PEER; use ssl::SslMethod::Sslv23; -use ssl::SslMethod; +use ssl::{SslMethod, HandshakeError}; use ssl::error::Error; use ssl::{SslContext, SslStream}; use x509::X509StoreContext; @@ -133,6 +133,7 @@ impl Drop for Server { } #[cfg(feature = "dtlsv1")] +#[derive(Debug)] struct UdpConnected(UdpSocket); #[cfg(feature = "dtlsv1")] @@ -846,10 +847,10 @@ fn test_sslv2_connect_failure() { .unwrap(); } -fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool { +fn wait_io(stream: &TcpStream, read: bool, timeout_ms: u32) -> bool { unsafe { let mut set: select::fd_set = mem::zeroed(); - select::fd_set(&mut set, stream.get_ref()); + select::fd_set(&mut set, stream); let write = if read { 0 as *mut _ @@ -861,7 +862,19 @@ fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool { } else { &mut set as *mut _ }; - select::select(stream.get_ref(), read, write, 0 as *mut _, timeout_ms).unwrap() + select::select(stream, read, write, 0 as *mut _, timeout_ms).unwrap() + } +} + +fn handshake(res: Result<SslStream<TcpStream>, HandshakeError<TcpStream>>) + -> SslStream<TcpStream> { + match res { + Ok(s) => s, + Err(HandshakeError::Interrupted(s)) => { + wait_io(s.get_ref(), true, 1_000); + handshake(s.handshake()) + } + Err(err) => panic!("error on handshake {:?}", err), } } @@ -870,7 +883,7 @@ fn test_write_nonblocking() { let (_s, stream) = Server::new(); stream.set_nonblocking(true).unwrap(); let cx = SslContext::new(Sslv23).unwrap(); - let mut stream = SslStream::connect(&cx, stream).unwrap(); + let mut stream = handshake(SslStream::connect(&cx, stream)); let mut iterations = 0; loop { @@ -886,10 +899,10 @@ fn test_write_nonblocking() { break; } Err(Error::WantRead(_)) => { - assert!(wait_io(&stream, true, 1000)); + assert!(wait_io(stream.get_ref(), true, 1000)); } Err(Error::WantWrite(_)) => { - assert!(wait_io(&stream, false, 1000)); + assert!(wait_io(stream.get_ref(), false, 1000)); } Err(other) => { panic!("Unexpected SSL Error: {:?}", other); @@ -907,7 +920,7 @@ fn test_read_nonblocking() { let (_s, stream) = Server::new(); stream.set_nonblocking(true).unwrap(); let cx = SslContext::new(Sslv23).unwrap(); - let mut stream = SslStream::connect(&cx, stream).unwrap(); + let mut stream = handshake(SslStream::connect(&cx, stream)); let mut iterations = 0; loop { @@ -924,10 +937,10 @@ fn test_read_nonblocking() { break; } Err(Error::WantRead(..)) => { - assert!(wait_io(&stream, true, 1000)); + assert!(wait_io(stream.get_ref(), true, 1000)); } Err(Error::WantWrite(..)) => { - assert!(wait_io(&stream, false, 1000)); + assert!(wait_io(stream.get_ref(), false, 1000)); } Err(other) => { panic!("Unexpected SSL Error: {:?}", other); @@ -944,7 +957,7 @@ fn test_read_nonblocking() { n } Err(Error::WantRead(..)) => { - assert!(wait_io(&stream, true, 3000)); + assert!(wait_io(stream.get_ref(), true, 3000)); // Second read should return application data. stream.read(&mut input_buffer).unwrap() } |