diff options
| author | Steven Fackler <[email protected]> | 2016-07-31 16:20:10 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2016-07-31 16:20:10 -0700 |
| commit | 2574bff52d379a2655e69e1e6498d4ff148558e6 (patch) | |
| tree | e84f3a70249af20408eecd46868a15070be5ee18 /openssl/src/ssl/mod.rs | |
| parent | Fix appveyor (diff) | |
| parent | Add MidHandshakeSslStream (diff) | |
| download | rust-openssl-2574bff52d379a2655e69e1e6498d4ff148558e6.tar.xz rust-openssl-2574bff52d379a2655e69e1e6498d4ff148558e6.zip | |
Merge pull request #432 from alexcrichton/mid-handshake
Add MidHandshakeSslStream
Diffstat (limited to 'openssl/src/ssl/mod.rs')
| -rw-r--r-- | openssl/src/ssl/mod.rs | 120 |
1 files changed, 112 insertions, 8 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index aba809fd..3d1ec6e5 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -5,6 +5,7 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; +use std::error as stderror; use std::mem; use std::str; use std::path::Path; @@ -832,6 +833,10 @@ impl Ssl { unsafe { ffi::SSL_accept(self.ssl) } } + fn handshake(&self) -> c_int { + unsafe { ffi::SSL_do_handshake(self.ssl) } + } + fn read(&self, buf: &mut [u8]) -> c_int { let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; unsafe { ffi::SSL_read(self.ssl, buf.as_ptr() as *mut c_void, len) } @@ -1081,31 +1086,49 @@ impl<S: Read + Write> SslStream<S> { } /// Creates an SSL/TLS client operating over the provided stream. - pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { - let ssl = try!(ssl.into_ssl()); + pub fn connect<T: IntoSsl>(ssl: T, stream: S) + -> Result<Self, HandshakeError<S>>{ + let ssl = try!(ssl.into_ssl().map_err(|e| { + HandshakeError::Failure(Error::Ssl(e)) + })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.connect(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { - Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), - err => Err(err) + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + Err(HandshakeError::Interrupted(MidHandshakeSslStream { + stream: stream, + error: e, + })) + } + err => Err(HandshakeError::Failure(err)), } } } /// Creates an SSL/TLS server operating over the provided stream. - pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { - let ssl = try!(ssl.into_ssl()); + pub fn accept<T: IntoSsl>(ssl: T, stream: S) + -> Result<Self, HandshakeError<S>> { + let ssl = try!(ssl.into_ssl().map_err(|e| { + HandshakeError::Failure(Error::Ssl(e)) + })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.accept(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { - Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), - err => Err(err) + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + Err(HandshakeError::Interrupted(MidHandshakeSslStream { + stream: stream, + error: e, + })) + } + err => Err(HandshakeError::Failure(err)), } } } @@ -1137,6 +1160,87 @@ impl<S: Read + Write> SslStream<S> { } } +/// An error or intermediate state after a TLS handshake attempt. +#[derive(Debug)] +pub enum HandshakeError<S> { + /// The handshake failed. + Failure(Error), + /// The handshake was interrupted midway through. + Interrupted(MidHandshakeSslStream<S>), +} + +impl<S: Any + fmt::Debug> stderror::Error for HandshakeError<S> { + fn description(&self) -> &str { + match *self { + HandshakeError::Failure(ref e) => e.description(), + HandshakeError::Interrupted(ref e) => e.error.description(), + } + } + + fn cause(&self) -> Option<&stderror::Error> { + match *self { + HandshakeError::Failure(ref e) => Some(e), + HandshakeError::Interrupted(ref e) => Some(&e.error), + } + } +} + +impl<S: Any + fmt::Debug> fmt::Display for HandshakeError<S> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + try!(f.write_str(stderror::Error::description(self))); + if let Some(e) = stderror::Error::cause(self) { + try!(write!(f, ": {}", e)); + } + Ok(()) + } +} + +/// An SSL stream midway through the handshake process. +#[derive(Debug)] +pub struct MidHandshakeSslStream<S> { + stream: SslStream<S>, + error: Error, +} + +impl<S> MidHandshakeSslStream<S> { + /// Returns a shared reference to the inner stream. + pub fn get_ref(&self) -> &S { + self.stream.get_ref() + } + + /// Returns a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut S { + self.stream.get_mut() + } + + /// Returns a shared reference to the `SslContext` of the stream. + pub fn ssl(&self) -> &Ssl { + self.stream.ssl() + } + + /// Returns the underlying error which interrupted this handshake. + pub fn error(&self) -> &Error { + &self.error + } + + /// Restarts the handshake process. + pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> { + let ret = self.stream.ssl.handshake(); + if ret > 0 { + Ok(self.stream) + } else { + match self.stream.make_error(ret) { + e @ Error::WantWrite(_) | + e @ Error::WantRead(_) => { + self.error = e; + Err(HandshakeError::Interrupted(self)) + } + err => Err(HandshakeError::Failure(err)), + } + } + } +} + impl<S> SslStream<S> { fn make_error(&mut self, ret: c_int) -> Error { self.check_panic(); |