diff options
Diffstat (limited to 'openssl/src/ssl/mod.rs')
| -rw-r--r-- | openssl/src/ssl/mod.rs | 691 |
1 files changed, 530 insertions, 161 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index a0f97b17..88ba9af4 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -13,9 +13,9 @@ use std::sync::{Once, ONCE_INIT, Arc, Mutex}; use std::ops::{Deref, DerefMut}; use std::cmp; use std::any::Any; -#[cfg(feature = "npn")] +#[cfg(any(feature = "npn", feature = "alpn"))] use libc::{c_uchar, c_uint}; -#[cfg(feature = "npn")] +#[cfg(any(feature = "npn", feature = "alpn"))] use std::slice; use bio::{MemBio}; @@ -170,49 +170,37 @@ lazy_static! { // Registers a destructor for the data which will be called // when context is freed fn get_verify_data_idx<T: Any + 'static>() -> c_int { - extern fn free_data_box<T>(_parent: *mut c_void, ptr: *mut c_void, - _ad: *mut ffi::CRYPTO_EX_DATA, _idx: c_int, - _argl: c_long, _argp: *mut c_void) { - if ptr != 0 as *mut _ { - let _: Box<T> = unsafe { mem::transmute(ptr) }; - } - } - *INDEXES.lock().unwrap().entry(TypeId::of::<T>()).or_insert_with(|| { - unsafe { - let f: ffi::CRYPTO_EX_free = free_data_box::<T>; - let idx = ffi::SSL_CTX_get_ex_new_index(0, ptr::null(), None, None, Some(f)); - assert!(idx >= 0); - idx - } + get_new_idx::<T>() }) } -/// Creates a static index for the list of NPN protocols. -/// Registers a destructor for the data which will be called -/// when the context is freed. #[cfg(feature = "npn")] -fn get_npn_protos_idx() -> c_int { - static mut NPN_PROTOS_IDX: c_int = -1; - static mut INIT: Once = ONCE_INIT; +lazy_static! { + static ref NPN_PROTOS_IDX: c_int = get_new_idx::<Vec<u8>>(); +} +#[cfg(feature = "alpn")] +lazy_static! { + static ref ALPN_PROTOS_IDX: c_int = get_new_idx::<Vec<u8>>(); +} - extern fn free_data_box(_parent: *mut c_void, ptr: *mut c_void, +/// Determine a new index to use for SSL CTX ex data. +/// Registers a destruct for the data which will be called by openssl when the context is freed. +fn get_new_idx<T>() -> c_int { + extern fn free_data_box<T>(_parent: *mut c_void, ptr: *mut c_void, _ad: *mut ffi::CRYPTO_EX_DATA, _idx: c_int, _argl: c_long, _argp: *mut c_void) { if !ptr.is_null() { - let _: Box<Vec<u8>> = unsafe { mem::transmute(ptr) }; + let _: Box<T> = unsafe { mem::transmute(ptr) }; } } unsafe { - INIT.call_once(|| { - let f: ffi::CRYPTO_EX_free = free_data_box; - let idx = ffi::SSL_CTX_get_ex_new_index(0, ptr::null(), None, - None, Some(f)); - assert!(idx >= 0); - NPN_PROTOS_IDX = idx; - }); - NPN_PROTOS_IDX + let f: ffi::CRYPTO_EX_free = free_data_box::<T>; + let idx = ffi::SSL_CTX_get_ex_new_index(0, ptr::null(), None, + None, Some(f)); + assert!(idx >= 0); + idx } } @@ -264,6 +252,26 @@ extern fn raw_verify_with_data<T>(preverify_ok: c_int, } } +#[cfg(any(feature = "npn", feature = "alpn"))] +unsafe fn select_proto_using(ssl: *mut ffi::SSL, + out: *mut *mut c_uchar, outlen: *mut c_uchar, + inbuf: *const c_uchar, inlen: c_uint, + ex_data: c_int) -> c_int { + + // First, get the list of protocols (that the client should support) saved in the context + // extra data. + let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); + let protocols = ffi::SSL_CTX_get_ex_data(ssl_ctx, ex_data); + let protocols: &Vec<u8> = mem::transmute(protocols); + // Prepare the client list parameters to be passed to the OpenSSL function... + let client = protocols.as_ptr(); + let client_len = protocols.len() as c_uint; + // Finally, let OpenSSL find a protocol to be used, by matching the given server and + // client lists. + ffi::SSL_select_next_proto(out, outlen, inbuf, inlen, client, client_len); + ffi::SSL_TLSEXT_ERR_OK +} + /// The function is given as the callback to `SSL_CTX_set_next_proto_select_cb`. /// /// It chooses the protocol that the client wishes to use, out of the given list of protocols @@ -276,20 +284,18 @@ extern fn raw_next_proto_select_cb(ssl: *mut ffi::SSL, inbuf: *const c_uchar, inlen: c_uint, _arg: *mut c_void) -> c_int { unsafe { - // First, get the list of protocols (that the client should support) saved in the context - // extra data. - let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); - let protocols = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_npn_protos_idx()); - let protocols: &Vec<u8> = mem::transmute(protocols); - // Prepare the client list parameters to be passed to the OpenSSL function... - let client = protocols.as_ptr(); - let client_len = protocols.len() as c_uint; - // Finally, let OpenSSL find a protocol to be used, by matching the given server and - // client lists. - ffi::SSL_select_next_proto(out, outlen, inbuf, inlen, client, client_len); + select_proto_using(ssl, out, outlen, inbuf, inlen, *NPN_PROTOS_IDX) } +} - ffi::SSL_TLSEXT_ERR_OK +#[cfg(feature = "alpn")] +extern fn raw_alpn_select_cb(ssl: *mut ffi::SSL, + out: *mut *mut c_uchar, outlen: *mut c_uchar, + inbuf: *const c_uchar, inlen: c_uint, + _arg: *mut c_void) -> c_int { + unsafe { + select_proto_using(ssl, out, outlen, inbuf, inlen, *ALPN_PROTOS_IDX) + } } /// The function is given as the callback to `SSL_CTX_set_next_protos_advertised_cb`. @@ -306,7 +312,7 @@ extern fn raw_next_protos_advertise_cb(ssl: *mut ffi::SSL, unsafe { // First, get the list of (supported) protocols saved in the context extra data. let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); - let protocols = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_npn_protos_idx()); + let protocols = ffi::SSL_CTX_get_ex_data(ssl_ctx, *NPN_PROTOS_IDX); if protocols.is_null() { *out = b"".as_ptr(); *outlen = 0; @@ -322,6 +328,24 @@ extern fn raw_next_protos_advertise_cb(ssl: *mut ffi::SSL, ffi::SSL_TLSEXT_ERR_OK } +/// Convert a set of byte slices into a series of byte strings encoded for SSL. Encoding is a byte +/// containing the length followed by the string. +#[cfg(any(feature = "npn", feature = "alpn"))] +fn ssl_encode_byte_strings(strings: &[&[u8]]) -> Vec<u8> +{ + let mut enc = Vec::new(); + for string in strings { + let len = string.len() as u8; + if len as usize != string.len() { + // If the item does not fit, discard it + continue; + } + enc.push(len); + enc.extend(string[..len as usize].to_vec()); + } + enc +} + /// The signature of functions that can be used to manually verify certificates pub type VerifyCallback = fn(preverify_ok: bool, x509_ctx: &X509StoreContext) -> bool; @@ -495,7 +519,7 @@ impl SslContext { pub fn set_cipher_list(&mut self, cipher_list: &str) -> Result<(),SslError> { wrap_ssl_result( unsafe { - let cipher_list = CString::new(cipher_list.as_bytes()).unwrap(); + let cipher_list = CString::new(cipher_list).unwrap(); ffi::SSL_CTX_set_cipher_list(self.ctx, cipher_list.as_ptr()) }) } @@ -531,19 +555,12 @@ impl SslContext { pub fn set_npn_protocols(&mut self, protocols: &[&[u8]]) { // Firstly, convert the list of protocols to a byte-array that can be passed to OpenSSL // APIs -- a list of length-prefixed strings. - let mut npn_protocols = Vec::new(); - for protocol in protocols { - let len = protocol.len() as u8; - npn_protocols.push(len); - // If the length is greater than the max `u8`, this truncates the protocol name. - npn_protocols.extend(protocol[..len as usize].to_vec()); - } - let protocols: Box<Vec<u8>> = Box::new(npn_protocols); + let protocols: Box<Vec<u8>> = Box::new(ssl_encode_byte_strings(protocols)); unsafe { // Attach the protocol list to the OpenSSL context structure, // so that we can refer to it within the callback. - ffi::SSL_CTX_set_ex_data(self.ctx, get_npn_protos_idx(), + ffi::SSL_CTX_set_ex_data(self.ctx, *NPN_PROTOS_IDX, mem::transmute(protocols)); // Now register the callback that performs the default protocol // matching based on the client-supported list of protocols that @@ -554,6 +571,35 @@ impl SslContext { ffi::SSL_CTX_set_next_protos_advertised_cb(self.ctx, raw_next_protos_advertise_cb, ptr::null_mut()); } } + + /// Set the protocols to be used during ALPN (application layer protocol negotiation). + /// If this is a server, these are the protocols we report to the client. + /// If this is a client, these are the protocols we try to match with those reported by the + /// server. + /// + /// Note that ordering of the protocols controls the priority with which they are chosen. + /// + /// This method needs the `alpn` feature. + #[cfg(feature = "alpn")] + pub fn set_alpn_protocols(&mut self, protocols: &[&[u8]]) { + let protocols: Box<Vec<u8>> = Box::new(ssl_encode_byte_strings(protocols)); + unsafe { + // Set the context's internal protocol list for use if we are a server + ffi::SSL_CTX_set_alpn_protos(self.ctx, protocols.as_ptr(), protocols.len() as c_uint); + + // Rather than use the argument to the callback to contain our data, store it in the + // ssl ctx's ex_data so that we can configure a function to free it later. In the + // future, it might make sense to pull this into our internal struct Ssl instead of + // leaning on openssl and using function pointers. + ffi::SSL_CTX_set_ex_data(self.ctx, *ALPN_PROTOS_IDX, + mem::transmute(protocols)); + + // Now register the callback that performs the default protocol + // matching based on the client-supported list of protocols that + // has been saved. + ffi::SSL_CTX_set_alpn_select_cb(self.ctx, raw_alpn_select_cb, ptr::null_mut()); + } + } } #[allow(dead_code)] @@ -603,11 +649,6 @@ impl Ssl { return Err(SslError::get()); } let ssl = Ssl { ssl: ssl }; - - let rbio = try!(MemBio::new()); - let wbio = try!(MemBio::new()); - - unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) } Ok(ssl) } @@ -655,16 +696,8 @@ impl Ssl { /// Set the host name to be used with SNI (Server Name Indication). pub fn set_hostname(&self, hostname: &str) -> Result<(), SslError> { - let ret = unsafe { - // This is defined as a macro: - // #define SSL_set_tlsext_host_name(s,name) \ - // SSL_ctrl(s,SSL_CTRL_SET_TLSEXT_HOSTNAME,TLSEXT_NAMETYPE_host_name,(char *)name) - - let hostname = CString::new(hostname.as_bytes()).unwrap(); - ffi::SSL_ctrl(self.ssl, ffi::SSL_CTRL_SET_TLSEXT_HOSTNAME, - ffi::TLSEXT_NAMETYPE_host_name, - hostname.as_ptr() as *mut c_void) - }; + let cstr = CString::new(hostname).unwrap(); + let ret = unsafe { ffi::SSL_set_tlsext_host_name(self.ssl, cstr.as_ptr()) }; // For this case, 0 indicates failure. if ret == 0 { @@ -708,6 +741,29 @@ impl Ssl { } } + /// Returns the protocol selected by performing ALPN, if any. + /// + /// The protocol's name is returned is an opaque sequence of bytes. It is up to the client + /// to interpret it. + /// + /// This method needs the `alpn` feature. + #[cfg(feature = "alpn")] + pub fn get_selected_alpn_protocol(&self) -> Option<&[u8]> { + unsafe { + let mut data: *const c_uchar = ptr::null(); + let mut len: c_uint = 0; + // Get the negotiated protocol from the SSL instance. + // `data` will point at a `c_uchar` array; `len` will contain the length of this array. + ffi::SSL_get0_alpn_selected(self.ssl, &mut data, &mut len); + + if data.is_null() { + None + } else { + Some(slice::from_raw_parts(data, len as usize)) + } + } + } + /// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any). pub fn pending(&self) -> usize { unsafe { @@ -747,71 +803,395 @@ make_LibSslError! { ErrorWantAccept = SSL_ERROR_WANT_ACCEPT } -/// A stream wrapper which handles SSL encryption for an underlying stream. +struct IndirectStream<S> { + stream: S, + ssl: Arc<Ssl>, + // Max TLS record size is 16k + buf: Box<[u8; 16 * 1024]>, +} + +impl<S: Clone> Clone for IndirectStream<S> { + fn clone(&self) -> IndirectStream<S> { + IndirectStream { + stream: self.stream.clone(), + ssl: self.ssl.clone(), + buf: Box::new(*self.buf) + } + } +} + +impl IndirectStream<net::TcpStream> { + fn try_clone(&self) -> io::Result<IndirectStream<net::TcpStream>> { + Ok(IndirectStream { + stream: try!(self.stream.try_clone()), + ssl: self.ssl.clone(), + buf: Box::new(*self.buf) + }) + } +} + +impl<S: Read+Write> IndirectStream<S> { + fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + + let rbio = try!(MemBio::new()); + let wbio = try!(MemBio::new()); + + unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) } + + Ok(IndirectStream { + stream: stream, + ssl: Arc::new(ssl), + buf: Box::new([0; 16 * 1024]), + }) + } + + fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { + let mut ssl = try!(IndirectStream::new_base(ssl, stream)); + try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); + Ok(ssl) + } + + fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { + let mut ssl = try!(IndirectStream::new_base(ssl, stream)); + try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); + Ok(ssl) + } + + fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError> + where F: FnMut(&Ssl) -> c_int { + loop { + let ret = blk(&self.ssl); + if ret > 0 { + return Ok(ret); + } + + let e = self.ssl.get_error(ret); + match e { + LibSslError::ErrorWantRead => { + try_ssl_stream!(self.flush()); + let len = try_ssl_stream!(self.stream.read(&mut self.buf[..])); + if len == 0 { + self.ssl.get_rbio().set_eof(true); + } else { + try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len])); + } + } + LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) } + LibSslError::ErrorZeroReturn => return Err(SslSessionClosed), + LibSslError::ErrorSsl => return Err(SslError::get()), + LibSslError::ErrorSyscall if ret == 0 => return Ok(0), + err => panic!("unexpected error {:?} with ret {}", err, ret), + } + } + } + + fn write_through(&mut self) -> io::Result<()> { + io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) + } +} + +impl<S: Read+Write> Read for IndirectStream<S> { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { + Ok(len) => Ok(len as usize), + Err(SslSessionClosed) => Ok(0), + Err(StreamError(e)) => Err(e), + Err(e @ OpenSslErrors(_)) => { + Err(io::Error::new(io::ErrorKind::Other, e)) + } + } + } +} + +impl<S: Read+Write> Write for IndirectStream<S> { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { + Ok(len) => len as usize, + Err(SslSessionClosed) => 0, + Err(StreamError(e)) => return Err(e), + Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), + }; + try!(self.write_through()); + Ok(count) + } + + fn flush(&mut self) -> io::Result<()> { + try!(self.write_through()); + self.stream.flush() + } +} + #[derive(Clone)] -pub struct SslStream<S> { +struct DirectStream<S> { stream: S, ssl: Arc<Ssl>, - buf: Vec<u8> +} + +impl DirectStream<net::TcpStream> { + fn try_clone(&self) -> io::Result<DirectStream<net::TcpStream>> { + Ok(DirectStream { + stream: try!(self.stream.try_clone()), + ssl: self.ssl.clone(), + }) + } +} + +impl<S> DirectStream<S> { + fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { + unsafe { + let bio = ffi::BIO_new_socket(sock, 0); + if bio == ptr::null_mut() { + return Err(SslError::get()); + } + ffi::SSL_set_bio(ssl.ssl, bio, bio); + } + + Ok(DirectStream { + stream: stream, + ssl: Arc::new(ssl), + }) + } + + fn connect(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { + let ssl = try!(DirectStream::new_base(ssl, stream, sock)); + let ret = ssl.ssl.connect(); + if ret > 0 { + Ok(ssl) + } else { + Err(ssl.make_error(ret)) + } + } + + fn accept(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { + let ssl = try!(DirectStream::new_base(ssl, stream, sock)); + let ret = ssl.ssl.accept(); + if ret > 0 { + Ok(ssl) + } else { + Err(ssl.make_error(ret)) + } + } + + fn make_error(&self, ret: c_int) -> SslError { + match self.ssl.get_error(ret) { + LibSslError::ErrorSsl => SslError::get(), + LibSslError::ErrorSyscall => { + let err = SslError::get(); + let count = match err { + SslError::OpenSslErrors(ref v) => v.len(), + _ => unreachable!(), + }; + if count == 0 { + if ret == 0 { + SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) + } else { + SslError::StreamError(io::Error::last_os_error()) + } + } else { + err + } + } + err => panic!("unexpected error {:?} with ret {}", err, ret), + } + } +} + +impl<S> Read for DirectStream<S> { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + let ret = self.ssl.read(buf); + if ret >= 0 { + return Ok(ret as usize); + } + + match self.make_error(ret) { + SslError::StreamError(e) => Err(e), + e => Err(io::Error::new(io::ErrorKind::Other, e)), + } + } +} + +impl<S: Write> Write for DirectStream<S> { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let ret = self.ssl.write(buf); + if ret > 0 { + return Ok(ret as usize); + } + + match self.make_error(ret) { + SslError::StreamError(e) => Err(e), + e => Err(io::Error::new(io::ErrorKind::Other, e)), + } + } + + fn flush(&mut self) -> io::Result<()> { + self.stream.flush() + } +} + +#[derive(Clone)] +enum StreamKind<S> { + Indirect(IndirectStream<S>), + Direct(DirectStream<S>), +} + +impl<S> StreamKind<S> { + fn stream(&self) -> &S { + match *self { + StreamKind::Indirect(ref s) => &s.stream, + StreamKind::Direct(ref s) => &s.stream, + } + } + + fn mut_stream(&mut self) -> &mut S { + match *self { + StreamKind::Indirect(ref mut s) => &mut s.stream, + StreamKind::Direct(ref mut s) => &mut s.stream, + } + } + + fn ssl(&self) -> &Ssl { + match *self { + StreamKind::Indirect(ref s) => &s.ssl, + StreamKind::Direct(ref s) => &s.ssl, + } + } +} + +/// A stream wrapper which handles SSL encryption for an underlying stream. +#[derive(Clone)] +pub struct SslStream<S> { + kind: StreamKind<S>, } impl SslStream<net::TcpStream> { /// Create a new independently owned handle to the underlying socket. pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> { + let kind = match self.kind { + StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())), + StreamKind::Direct(ref s) => StreamKind::Direct(try!(s.try_clone())) + }; Ok(SslStream { - stream: try!(self.stream.try_clone()), - ssl: self.ssl.clone(), - buf: self.buf.clone(), + kind: kind }) } } impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl) + write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl()) + } +} + +#[cfg(unix)] +impl<S: Read+Write+::std::os::unix::io::AsRawFd> SslStream<S> { + /// Creates an SSL/TLS client operating over the provided stream. + /// + /// Streams passed to this method must implement `AsRawFd` on Unixy + /// platforms and `AsRawSocket` on Windows. Use `connect_generic` for + /// streams that do not. + pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let stream = try!(DirectStream::connect(ssl, stream, fd)); + Ok(SslStream { + kind: StreamKind::Direct(stream) + }) + } + + /// Creates an SSL/TLS server operating over the provided stream. + /// + /// Streams passed to this method must implement `AsRawFd` on Unixy + /// platforms and `AsRawSocket` on Windows. Use `accept_generic` for + /// streams that do not. + pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let stream = try!(DirectStream::accept(ssl, stream, fd)); + Ok(SslStream { + kind: StreamKind::Direct(stream) + }) + } +} + +#[cfg(windows)] +impl<S: Read+Write+::std::os::windows::io::AsRawSocket> SslStream<S> { + /// Creates an SSL/TLS client operating over the provided stream. + /// + /// Streams passed to this method must implement `AsRawFd` on Unixy + /// platforms and `AsRawSocket` on Windows. Use `connect_generic` for + /// streams that do not. + pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_socket() as c_int; + let stream = try!(DirectStream::connect(ssl, stream, fd)); + Ok(SslStream { + kind: StreamKind::Direct(stream) + }) + } + + /// Creates an SSL/TLS server operating over the provided stream. + /// + /// Streams passed to this method must implement `AsRawFd` on Unixy + /// platforms and `AsRawSocket` on Windows. Use `accept_generic` for + /// streams that do not. + pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_socket() as c_int; + let stream = try!(DirectStream::accept(ssl, stream, fd)); + Ok(SslStream { + kind: StreamKind::Direct(stream) + }) } } impl<S: Read+Write> SslStream<S> { - fn new_base(ssl:Ssl, stream: S) -> SslStream<S> { - SslStream { - stream: stream, - ssl: Arc::new(ssl), - // Maximum TLS record size is 16k - // We're just using this as a buffer, so there's no reason to pay - // to memset it - buf: { - const CAP: usize = 16 * 1024; - let mut v = Vec::with_capacity(CAP); - unsafe { v.set_len(CAP); } - v - } - } + /// Creates an SSL/TLS client operating over the provided stream. + /// + /// `SslStream`s returned by this method will be less efficient than ones + /// returned by `connect`, so this method should only be used for streams + /// that do not implement `AsRawFd` and `AsRawSocket`. + pub fn connect_generic<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { + let stream = try!(IndirectStream::connect(ssl, stream)); + Ok(SslStream { + kind: StreamKind::Indirect(stream) + }) } + /// Creates an SSL/TLS server operating over the provided stream. + /// + /// `SslStream`s returned by this method will be less efficient than ones + /// returned by `accept`, so this method should only be used for streams + /// that do not implement `AsRawFd` and `AsRawSocket`. + pub fn accept_generic<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { + let stream = try!(IndirectStream::accept(ssl, stream)); + Ok(SslStream { + kind: StreamKind::Indirect(stream) + }) + } + + /// # Deprecated + pub fn new_server(ssl: &SslContext, stream: S) -> Result<SslStream<S>, SslError> { + SslStream::accept_generic(ssl, stream) + } + + /// # Deprecated pub fn new_server_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> { - let mut ssl = SslStream::new_base(ssl, stream); - ssl.in_retry_wrapper(|ssl| { ssl.accept() }).and(Ok(ssl)) + SslStream::accept_generic(ssl, stream) } - /// Attempts to create a new SSL stream from a given `Ssl` instance. + /// # Deprecated pub fn new_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> { - let mut ssl = SslStream::new_base(ssl, stream); - ssl.in_retry_wrapper(|ssl| { ssl.connect() }).and(Ok(ssl)) + SslStream::connect_generic(ssl, stream) } - /// Creates a new SSL stream + /// # Deprecated pub fn new(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> { - let ssl = try!(Ssl::new(ctx)); - SslStream::new_from(ssl, stream) - } - - /// Creates a new SSL server stream - pub fn new_server(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> { - let ssl = try!(Ssl::new(ctx)); - SslStream::new_server_from(ssl, stream) + SslStream::connect_generic(ctx, stream) } + /// # Deprecated #[doc(hidden)] pub fn get_inner(&mut self) -> &mut S { self.get_mut() @@ -819,12 +1199,12 @@ impl<S: Read+Write> SslStream<S> { /// Returns a reference to the underlying stream. pub fn get_ref(&self) -> &S { - &self.stream + self.kind.stream() } /// Return the certificate of the peer pub fn get_peer_certificate(&self) -> Option<X509> { - self.ssl.get_peer_certificate() + self.kind.ssl().get_peer_certificate() } /// Returns a mutable reference to the underlying stream. @@ -832,48 +1212,16 @@ impl<S: Read+Write> SslStream<S> { /// ## Warning /// /// It is inadvisable to read from or write to the underlying stream as it - /// will most likely desynchronize the SSL session. + /// will most likely corrupt the SSL session. pub fn get_mut(&mut self) -> &mut S { - &mut self.stream - } - - fn in_retry_wrapper<F>(&mut self, mut blk: F) - -> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int { - loop { - let ret = blk(&self.ssl); - if ret > 0 { - return Ok(ret); - } - - let e = self.ssl.get_error(ret); - match e { - LibSslError::ErrorWantRead => { - try_ssl_stream!(self.flush()); - let len = try_ssl_stream!(self.stream.read(&mut self.buf[..])); - if len == 0 { - self.ssl.get_rbio().set_eof(true); - } else { - try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len])); - } - } - LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) } - LibSslError::ErrorZeroReturn => return Err(SslSessionClosed), - LibSslError::ErrorSsl => return Err(SslError::get()), - LibSslError::ErrorSyscall if ret == 0 => return Ok(0), - err => panic!("unexpected error {:?} with ret {}", err, ret), - } - } - } - - fn write_through(&mut self) -> io::Result<()> { - io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) + self.kind.mut_stream() } /// Get the compression currently in use. The result will be /// either None, indicating no compression is in use, or a string /// with the compression name. pub fn get_compression(&self) -> Option<String> { - let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) }; + let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) }; if ptr == ptr::null() { return None; } @@ -894,43 +1242,64 @@ impl<S: Read+Write> SslStream<S> { /// This method needs the `npn` feature. #[cfg(feature = "npn")] pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> { - self.ssl.get_selected_npn_protocol() + self.kind.ssl().get_selected_npn_protocol() + } + + /// Returns the protocol selected by performing ALPN, if any. + /// + /// The protocol's name is returned is an opaque sequence of bytes. It is up to the client + /// to interpret it. + /// + /// This method needs the `alpn` feature. + #[cfg(feature = "alpn")] + pub fn get_selected_alpn_protocol(&self) -> Option<&[u8]> { + self.kind.ssl().get_selected_alpn_protocol() } /// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any). pub fn pending(&self) -> usize { - self.ssl.pending() + self.kind.ssl().pending() } } impl<S: Read+Write> Read for SslStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { - Ok(len) => Ok(len as usize), - Err(SslSessionClosed) => Ok(0), - Err(StreamError(e)) => Err(e), - Err(e @ OpenSslErrors(_)) => { - Err(io::Error::new(io::ErrorKind::Other, e)) - } + match self.kind { + StreamKind::Indirect(ref mut s) => s.read(buf), + StreamKind::Direct(ref mut s) => s.read(buf), } } } impl<S: Read+Write> Write for SslStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { - Ok(len) => len as usize, - Err(SslSessionClosed) => 0, - Err(StreamError(e)) => return Err(e), - Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), - }; - try!(self.write_through()); - Ok(count) + match self.kind { + StreamKind::Indirect(ref mut s) => s.write(buf), + StreamKind::Direct(ref mut s) => s.write(buf), + } } fn flush(&mut self) -> io::Result<()> { - try!(self.write_through()); - self.stream.flush() + match self.kind { + StreamKind::Indirect(ref mut s) => s.flush(), + StreamKind::Direct(ref mut s) => s.flush(), + } + } +} + +pub trait IntoSsl { + fn into_ssl(self) -> Result<Ssl, SslError>; +} + +impl IntoSsl for Ssl { + fn into_ssl(self) -> Result<Ssl, SslError> { + Ok(self) + } +} + +impl<'a> IntoSsl for &'a SslContext { + fn into_ssl(self) -> Result<Ssl, SslError> { + Ssl::new(self) } } |