diff options
| author | Alex Crichton <[email protected]> | 2016-07-27 14:18:02 -0700 |
|---|---|---|
| committer | Alex Crichton <[email protected]> | 2016-07-31 16:01:06 -0700 |
| commit | 3539be33660dc136e0b585400f70447e31ccb62a (patch) | |
| tree | 020de98557d936d6ab06c58291eb270c55e5c578 /openssl/src/ssl/mod.rs | |
| parent | Merge remote-tracking branch 'origin/master' into breaks (diff) | |
| download | rust-openssl-3539be33660dc136e0b585400f70447e31ccb62a.tar.xz rust-openssl-3539be33660dc136e0b585400f70447e31ccb62a.zip | |
Add MidHandshakeSslStream
Allows recognizing when a stream is still in handshake mode and can gracefully
transition when ready. The blocking usage of the API should still be the same,
just helps nonblocking implementations!
Diffstat (limited to 'openssl/src/ssl/mod.rs')
| -rw-r--r-- | openssl/src/ssl/mod.rs | 120 |
1 files changed, 112 insertions, 8 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index b18df7b3..9215a928 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -5,6 +5,7 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; +use std::error as stderror; use std::mem; use std::str; use std::path::Path; @@ -832,6 +833,10 @@ impl Ssl { unsafe { ffi::SSL_accept(self.ssl) } } + fn handshake(&self) -> c_int { + unsafe { ffi::SSL_do_handshake(self.ssl) } + } + fn read(&self, buf: &mut [u8]) -> c_int { let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; unsafe { ffi::SSL_read(self.ssl, buf.as_ptr() as *mut c_void, len) } @@ -1081,31 +1086,49 @@ impl<S: Read + Write> SslStream<S> { } /// Creates an SSL/TLS client operating over the provided stream. - pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { - let ssl = try!(ssl.into_ssl()); + pub fn connect<T: IntoSsl>(ssl: T, stream: S) + -> Result<Self, HandshakeError<S>>{ + let ssl = try!(ssl.into_ssl().map_err(|e| { + HandshakeError::Failure(Error::Ssl(e)) + })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.connect(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { - Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), - err => Err(err) + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + Err(HandshakeError::Interrupted(MidHandshakeSslStream { + stream: stream, + error: e, + })) + } + err => Err(HandshakeError::Failure(err)), } } } /// Creates an SSL/TLS server operating over the provided stream. - pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { - let ssl = try!(ssl.into_ssl()); + pub fn accept<T: IntoSsl>(ssl: T, stream: S) + -> Result<Self, HandshakeError<S>> { + let ssl = try!(ssl.into_ssl().map_err(|e| { + HandshakeError::Failure(Error::Ssl(e)) + })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.accept(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { - Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), - err => Err(err) + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + Err(HandshakeError::Interrupted(MidHandshakeSslStream { + stream: stream, + error: e, + })) + } + err => Err(HandshakeError::Failure(err)), } } } @@ -1137,6 +1160,87 @@ impl<S: Read + Write> SslStream<S> { } } +/// An error or intermediate state after a TLS handshake attempt. +#[derive(Debug)] +pub enum HandshakeError<S> { + /// The handshake failed. + Failure(Error), + /// The handshake was interrupted midway through. + Interrupted(MidHandshakeSslStream<S>), +} + +impl<S: Any + fmt::Debug> stderror::Error for HandshakeError<S> { + fn description(&self) -> &str { + match *self { + HandshakeError::Failure(ref e) => e.description(), + HandshakeError::Interrupted(ref e) => e.error.description(), + } + } + + fn cause(&self) -> Option<&stderror::Error> { + match *self { + HandshakeError::Failure(ref e) => Some(e), + HandshakeError::Interrupted(ref e) => Some(&e.error), + } + } +} + +impl<S: Any + fmt::Debug> fmt::Display for HandshakeError<S> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + try!(f.write_str(stderror::Error::description(self))); + if let Some(e) = stderror::Error::cause(self) { + try!(write!(f, ": {}", e)); + } + Ok(()) + } +} + +/// An SSL stream midway through the handshake process. +#[derive(Debug)] +pub struct MidHandshakeSslStream<S> { + stream: SslStream<S>, + error: Error, +} + +impl<S> MidHandshakeSslStream<S> { + /// Returns a shared reference to the inner stream. + pub fn get_ref(&self) -> &S { + self.stream.get_ref() + } + + /// Returns a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut S { + self.stream.get_mut() + } + + /// Returns a shared reference to the `SslContext` of the stream. + pub fn ssl(&self) -> &Ssl { + self.stream.ssl() + } + + /// Returns the underlying error which interrupted this handshake. + pub fn error(&self) -> &Error { + &self.error + } + + /// Restarts the handshake process. + pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> { + let ret = self.stream.ssl.handshake(); + if ret > 0 { + Ok(self.stream) + } else { + match self.stream.make_error(ret) { + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + self.error = e; + Err(HandshakeError::Interrupted(self)) + } + err => Err(HandshakeError::Failure(err)), + } + } + } +} + impl<S> SslStream<S> { fn make_error(&mut self, ret: c_int) -> Error { self.check_panic(); |