aboutsummaryrefslogtreecommitdiff
path: root/openssl/src/ssl
diff options
context:
space:
mode:
authorSteven Fackler <[email protected]>2015-06-27 22:37:10 -0700
committerSteven Fackler <[email protected]>2015-06-27 22:37:10 -0700
commit9b235a7b9121613780810b0bc7b4d1f30dc861c9 (patch)
tree04188887225d5b764cd4ccdbf3fcb01d75389d53 /openssl/src/ssl
parentDocs tweak (diff)
downloadrust-openssl-9b235a7b9121613780810b0bc7b4d1f30dc861c9.tar.xz
rust-openssl-9b235a7b9121613780810b0bc7b4d1f30dc861c9.zip
Prepare for direct stream support
Diffstat (limited to 'openssl/src/ssl')
-rw-r--r--openssl/src/ssl/mod.rs243
1 files changed, 165 insertions, 78 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs
index cb4448b8..18acf7f8 100644
--- a/openssl/src/ssl/mod.rs
+++ b/openssl/src/ssl/mod.rs
@@ -5,7 +5,6 @@ use std::ffi::{CStr, CString};
use std::fmt;
use std::io;
use std::io::prelude::*;
-use std::iter;
use std::mem;
use std::net;
use std::path::Path;
@@ -740,52 +739,181 @@ make_LibSslError! {
ErrorWantAccept = SSL_ERROR_WANT_ACCEPT
}
+struct IndirectStream<S> {
+ stream: S,
+ ssl: Arc<Ssl>,
+ // Max TLS record size is 16k
+ buf: Box<[u8; 16 * 1024]>,
+}
+
+impl<S: Clone> Clone for IndirectStream<S> {
+ fn clone(&self) -> IndirectStream<S> {
+ IndirectStream {
+ stream: self.stream.clone(),
+ ssl: self.ssl.clone(),
+ buf: Box::new(*self.buf)
+ }
+ }
+}
+
+impl IndirectStream<net::TcpStream> {
+ fn try_clone(&self) -> io::Result<IndirectStream<net::TcpStream>> {
+ Ok(IndirectStream {
+ stream: try!(self.stream.try_clone()),
+ ssl: self.ssl.clone(),
+ buf: Box::new(*self.buf)
+ })
+ }
+}
+
+impl<S: Read+Write> IndirectStream<S> {
+ fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
+ let ssl = try!(ssl.into_ssl());
+ Ok(IndirectStream {
+ stream: stream,
+ ssl: Arc::new(ssl),
+ buf: Box::new([0; 16 * 1024]),
+ })
+ }
+
+ fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
+ let mut ssl = try!(IndirectStream::new_base(ssl, stream));
+ try!(ssl.in_retry_wrapper(|ssl| ssl.connect()));
+ Ok(ssl)
+ }
+
+ fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
+ let mut ssl = try!(IndirectStream::new_base(ssl, stream));
+ try!(ssl.in_retry_wrapper(|ssl| ssl.accept()));
+ Ok(ssl)
+ }
+
+ fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError>
+ where F: FnMut(&Ssl) -> c_int {
+ loop {
+ let ret = blk(&self.ssl);
+ if ret > 0 {
+ return Ok(ret);
+ }
+
+ let e = self.ssl.get_error(ret);
+ match e {
+ LibSslError::ErrorWantRead => {
+ try_ssl_stream!(self.flush());
+ let len = try_ssl_stream!(self.stream.read(&mut self.buf[..]));
+ if len == 0 {
+ self.ssl.get_rbio().set_eof(true);
+ } else {
+ try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len]));
+ }
+ }
+ LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) }
+ LibSslError::ErrorZeroReturn => return Err(SslSessionClosed),
+ LibSslError::ErrorSsl => return Err(SslError::get()),
+ LibSslError::ErrorSyscall if ret == 0 => return Ok(0),
+ err => panic!("unexpected error {:?} with ret {}", err, ret),
+ }
+ }
+ }
+
+ fn write_through(&mut self) -> io::Result<()> {
+ io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ())
+ }
+}
+
+impl<S: Read+Write> Read for IndirectStream<S> {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
+ Ok(len) => Ok(len as usize),
+ Err(SslSessionClosed) => Ok(0),
+ Err(StreamError(e)) => Err(e),
+ Err(e @ OpenSslErrors(_)) => {
+ Err(io::Error::new(io::ErrorKind::Other, e))
+ }
+ }
+ }
+}
+
+impl<S: Read+Write> Write for IndirectStream<S> {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
+ Ok(len) => len as usize,
+ Err(SslSessionClosed) => 0,
+ Err(StreamError(e)) => return Err(e),
+ Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
+ };
+ try!(self.write_through());
+ Ok(count)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ try!(self.write_through());
+ self.stream.flush()
+ }
+}
+
+#[derive(Clone)]
+enum StreamKind<S> {
+ Indirect(IndirectStream<S>),
+}
+
+impl<S> StreamKind<S> {
+ fn stream(&self) -> &S {
+ match *self {
+ StreamKind::Indirect(ref s) => &s.stream
+ }
+ }
+
+ fn mut_stream(&mut self) -> &mut S {
+ match *self {
+ StreamKind::Indirect(ref mut s) => &mut s.stream
+ }
+ }
+
+ fn ssl(&self) -> &Ssl {
+ match *self {
+ StreamKind::Indirect(ref s) => &s.ssl
+ }
+ }
+}
+
/// A stream wrapper which handles SSL encryption for an underlying stream.
#[derive(Clone)]
pub struct SslStream<S> {
- stream: S,
- ssl: Arc<Ssl>,
- buf: Vec<u8>
+ kind: StreamKind<S>,
}
impl SslStream<net::TcpStream> {
/// Create a new independently owned handle to the underlying socket.
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
+ let kind = match self.kind {
+ StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone()))
+ };
Ok(SslStream {
- stream: try!(self.stream.try_clone()),
- ssl: self.ssl.clone(),
- buf: self.buf.clone(),
+ kind: kind
})
}
}
impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
- write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl)
+ write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl())
}
}
impl<S: Read+Write> SslStream<S> {
- fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
- let ssl = try!(ssl.into_ssl());
+ pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
+ let stream = try!(IndirectStream::new_client(ssl, stream));
Ok(SslStream {
- stream: stream,
- ssl: Arc::new(ssl),
- // Maximum TLS record size is 16k
- buf: iter::repeat(0).take(16 * 1024).collect(),
+ kind: StreamKind::Indirect(stream)
})
}
- pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
- let mut ssl = try!(SslStream::new_base(ssl, stream));
- try!(ssl.in_retry_wrapper(|ssl| ssl.connect()));
- Ok(ssl)
- }
-
pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
- let mut ssl = try!(SslStream::new_base(ssl, stream));
- try!(ssl.in_retry_wrapper(|ssl| ssl.accept()));
- Ok(ssl)
+ let stream = try!(IndirectStream::new_server(ssl, stream));
+ Ok(SslStream {
+ kind: StreamKind::Indirect(stream)
+ })
}
/// # Deprecated
@@ -811,12 +939,12 @@ impl<S: Read+Write> SslStream<S> {
/// Returns a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
- &self.stream
+ self.kind.stream()
}
/// Return the certificate of the peer
pub fn get_peer_certificate(&self) -> Option<X509> {
- self.ssl.get_peer_certificate()
+ self.kind.ssl().get_peer_certificate()
}
/// Returns a mutable reference to the underlying stream.
@@ -826,46 +954,14 @@ impl<S: Read+Write> SslStream<S> {
/// It is inadvisable to read from or write to the underlying stream as it
/// will most likely corrupt the SSL session.
pub fn get_mut(&mut self) -> &mut S {
- &mut self.stream
- }
-
- fn in_retry_wrapper<F>(&mut self, mut blk: F)
- -> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int {
- loop {
- let ret = blk(&self.ssl);
- if ret > 0 {
- return Ok(ret);
- }
-
- let e = self.ssl.get_error(ret);
- match e {
- LibSslError::ErrorWantRead => {
- try_ssl_stream!(self.flush());
- let len = try_ssl_stream!(self.stream.read(&mut self.buf[..]));
- if len == 0 {
- self.ssl.get_rbio().set_eof(true);
- } else {
- try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len]));
- }
- }
- LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) }
- LibSslError::ErrorZeroReturn => return Err(SslSessionClosed),
- LibSslError::ErrorSsl => return Err(SslError::get()),
- LibSslError::ErrorSyscall if ret == 0 => return Ok(0),
- err => panic!("unexpected error {:?} with ret {}", err, ret),
- }
- }
- }
-
- fn write_through(&mut self) -> io::Result<()> {
- io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ())
+ self.kind.mut_stream()
}
/// Get the compression currently in use. The result will be
/// either None, indicating no compression is in use, or a string
/// with the compression name.
pub fn get_compression(&self) -> Option<String> {
- let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) };
+ let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) };
if ptr == ptr::null() {
return None;
}
@@ -886,43 +982,34 @@ impl<S: Read+Write> SslStream<S> {
/// This method needs the `npn` feature.
#[cfg(feature = "npn")]
pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> {
- self.ssl.get_selected_npn_protocol()
+ self.kind.ssl().get_selected_npn_protocol()
}
/// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any).
pub fn pending(&self) -> usize {
- self.ssl.pending()
+ self.kind.ssl().pending()
}
}
impl<S: Read+Write> Read for SslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
- match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
- Ok(len) => Ok(len as usize),
- Err(SslSessionClosed) => Ok(0),
- Err(StreamError(e)) => Err(e),
- Err(e @ OpenSslErrors(_)) => {
- Err(io::Error::new(io::ErrorKind::Other, e))
- }
+ match self.kind {
+ StreamKind::Indirect(ref mut s) => s.read(buf)
}
}
}
impl<S: Read+Write> Write for SslStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
- let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
- Ok(len) => len as usize,
- Err(SslSessionClosed) => 0,
- Err(StreamError(e)) => return Err(e),
- Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
- };
- try!(self.write_through());
- Ok(count)
+ match self.kind {
+ StreamKind::Indirect(ref mut s) => s.write(buf)
+ }
}
fn flush(&mut self) -> io::Result<()> {
- try!(self.write_through());
- self.stream.flush()
+ match self.kind {
+ StreamKind::Indirect(ref mut s) => s.flush()
+ }
}
}