aboutsummaryrefslogtreecommitdiff
path: root/openssl/src/ssl/mod.rs
diff options
context:
space:
mode:
authorSteven Fackler <[email protected]>2016-07-31 16:20:10 -0700
committerGitHub <[email protected]>2016-07-31 16:20:10 -0700
commit2574bff52d379a2655e69e1e6498d4ff148558e6 (patch)
treee84f3a70249af20408eecd46868a15070be5ee18 /openssl/src/ssl/mod.rs
parentFix appveyor (diff)
parentAdd MidHandshakeSslStream (diff)
downloadrust-openssl-2574bff52d379a2655e69e1e6498d4ff148558e6.tar.xz
rust-openssl-2574bff52d379a2655e69e1e6498d4ff148558e6.zip
Merge pull request #432 from alexcrichton/mid-handshake
Add MidHandshakeSslStream
Diffstat (limited to 'openssl/src/ssl/mod.rs')
-rw-r--r--openssl/src/ssl/mod.rs120
1 files changed, 112 insertions, 8 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs
index aba809fd..3d1ec6e5 100644
--- a/openssl/src/ssl/mod.rs
+++ b/openssl/src/ssl/mod.rs
@@ -5,6 +5,7 @@ use std::ffi::{CStr, CString};
use std::fmt;
use std::io;
use std::io::prelude::*;
+use std::error as stderror;
use std::mem;
use std::str;
use std::path::Path;
@@ -832,6 +833,10 @@ impl Ssl {
unsafe { ffi::SSL_accept(self.ssl) }
}
+ fn handshake(&self) -> c_int {
+ unsafe { ffi::SSL_do_handshake(self.ssl) }
+ }
+
fn read(&self, buf: &mut [u8]) -> c_int {
let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int;
unsafe { ffi::SSL_read(self.ssl, buf.as_ptr() as *mut c_void, len) }
@@ -1081,31 +1086,49 @@ impl<S: Read + Write> SslStream<S> {
}
/// Creates an SSL/TLS client operating over the provided stream.
- pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> {
- let ssl = try!(ssl.into_ssl());
+ pub fn connect<T: IntoSsl>(ssl: T, stream: S)
+ -> Result<Self, HandshakeError<S>>{
+ let ssl = try!(ssl.into_ssl().map_err(|e| {
+ HandshakeError::Failure(Error::Ssl(e))
+ }));
let mut stream = Self::new_base(ssl, stream);
let ret = stream.ssl.connect();
if ret > 0 {
Ok(stream)
} else {
match stream.make_error(ret) {
- Error::WantRead(..) | Error::WantWrite(..) => Ok(stream),
- err => Err(err)
+ e @ Error::WantWrite(_) |
+ e @ Error::WantRead(_) => {
+ Err(HandshakeError::Interrupted(MidHandshakeSslStream {
+ stream: stream,
+ error: e,
+ }))
+ }
+ err => Err(HandshakeError::Failure(err)),
}
}
}
/// Creates an SSL/TLS server operating over the provided stream.
- pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> {
- let ssl = try!(ssl.into_ssl());
+ pub fn accept<T: IntoSsl>(ssl: T, stream: S)
+ -> Result<Self, HandshakeError<S>> {
+ let ssl = try!(ssl.into_ssl().map_err(|e| {
+ HandshakeError::Failure(Error::Ssl(e))
+ }));
let mut stream = Self::new_base(ssl, stream);
let ret = stream.ssl.accept();
if ret > 0 {
Ok(stream)
} else {
match stream.make_error(ret) {
- Error::WantRead(..) | Error::WantWrite(..) => Ok(stream),
- err => Err(err)
+ e @ Error::WantWrite(_) |
+ e @ Error::WantRead(_) => {
+ Err(HandshakeError::Interrupted(MidHandshakeSslStream {
+ stream: stream,
+ error: e,
+ }))
+ }
+ err => Err(HandshakeError::Failure(err)),
}
}
}
@@ -1137,6 +1160,87 @@ impl<S: Read + Write> SslStream<S> {
}
}
+/// An error or intermediate state after a TLS handshake attempt.
+#[derive(Debug)]
+pub enum HandshakeError<S> {
+ /// The handshake failed.
+ Failure(Error),
+ /// The handshake was interrupted midway through.
+ Interrupted(MidHandshakeSslStream<S>),
+}
+
+impl<S: Any + fmt::Debug> stderror::Error for HandshakeError<S> {
+ fn description(&self) -> &str {
+ match *self {
+ HandshakeError::Failure(ref e) => e.description(),
+ HandshakeError::Interrupted(ref e) => e.error.description(),
+ }
+ }
+
+ fn cause(&self) -> Option<&stderror::Error> {
+ match *self {
+ HandshakeError::Failure(ref e) => Some(e),
+ HandshakeError::Interrupted(ref e) => Some(&e.error),
+ }
+ }
+}
+
+impl<S: Any + fmt::Debug> fmt::Display for HandshakeError<S> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ try!(f.write_str(stderror::Error::description(self)));
+ if let Some(e) = stderror::Error::cause(self) {
+ try!(write!(f, ": {}", e));
+ }
+ Ok(())
+ }
+}
+
+/// An SSL stream midway through the handshake process.
+#[derive(Debug)]
+pub struct MidHandshakeSslStream<S> {
+ stream: SslStream<S>,
+ error: Error,
+}
+
+impl<S> MidHandshakeSslStream<S> {
+ /// Returns a shared reference to the inner stream.
+ pub fn get_ref(&self) -> &S {
+ self.stream.get_ref()
+ }
+
+ /// Returns a mutable reference to the inner stream.
+ pub fn get_mut(&mut self) -> &mut S {
+ self.stream.get_mut()
+ }
+
+ /// Returns a shared reference to the `SslContext` of the stream.
+ pub fn ssl(&self) -> &Ssl {
+ self.stream.ssl()
+ }
+
+ /// Returns the underlying error which interrupted this handshake.
+ pub fn error(&self) -> &Error {
+ &self.error
+ }
+
+ /// Restarts the handshake process.
+ pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> {
+ let ret = self.stream.ssl.handshake();
+ if ret > 0 {
+ Ok(self.stream)
+ } else {
+ match self.stream.make_error(ret) {
+ e @ Error::WantWrite(_) |
+ e @ Error::WantRead(_) => {
+ self.error = e;
+ Err(HandshakeError::Interrupted(self))
+ }
+ err => Err(HandshakeError::Failure(err)),
+ }
+ }
+ }
+}
+
impl<S> SslStream<S> {
fn make_error(&mut self, ret: c_int) -> Error {
self.check_panic();