diff options
Diffstat (limited to 'openssl/src/ssl')
| -rw-r--r-- | openssl/src/ssl/bio.rs | 98 | ||||
| -rw-r--r-- | openssl/src/ssl/mod.rs | 172 | ||||
| -rw-r--r-- | openssl/src/ssl/tests/mod.rs | 106 |
3 files changed, 339 insertions, 37 deletions
diff --git a/openssl/src/ssl/bio.rs b/openssl/src/ssl/bio.rs index a361ae81..aa445562 100644 --- a/openssl/src/ssl/bio.rs +++ b/openssl/src/ssl/bio.rs @@ -1,11 +1,12 @@ use libc::{c_char, c_int, c_long, c_void, strlen}; use ffi::{BIO, BIO_METHOD, BIO_CTRL_FLUSH, BIO_TYPE_NONE, BIO_new}; use ffi_extras::{BIO_clear_retry_flags, BIO_set_retry_read, BIO_set_retry_write}; +use std::any::Any; use std::io; use std::io::prelude::*; use std::mem; -use std::slice; use std::ptr; +use std::slice; use std::sync::Arc; use ssl::error::SslError; @@ -16,6 +17,7 @@ const NAME: [c_char; 5] = [114, 117, 115, 116, 0]; pub struct StreamState<S> { pub stream: S, pub error: Option<io::Error>, + pub panic: Option<Box<Any + Send>>, } pub fn new<S: Read + Write>(stream: S) -> Result<(*mut BIO, Arc<BIO_METHOD>), SslError> { @@ -35,6 +37,7 @@ pub fn new<S: Read + Write>(stream: S) -> Result<(*mut BIO, Arc<BIO_METHOD>), Ss let state = Box::new(StreamState { stream: stream, error: None, + panic: None, }); unsafe { @@ -51,6 +54,12 @@ pub unsafe fn take_error<S>(bio: *mut BIO) -> Option<io::Error> { state.error.take() } +#[cfg_attr(not(feature = "nightly"), allow(dead_code))] +pub unsafe fn take_panic<S>(bio: *mut BIO) -> Option<Box<Any + Send>> { + let state = state::<S>(bio); + state.panic.take() +} + pub unsafe fn get_ref<'a, S: 'a>(bio: *mut BIO) -> &'a S { let state: &'a StreamState<S> = mem::transmute((*bio).ptr); &state.stream @@ -64,20 +73,69 @@ unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState<S> { mem::transmute((*bio).ptr) } +#[cfg(feature = "nightly")] +fn recover<F, T>(f: F) -> Result<T, Box<Any + Send>> where F: FnOnce() -> T + ::std::panic::RecoverSafe { + ::std::panic::recover(f) +} + +#[cfg(not(feature = "nightly"))] +fn recover<F, T>(f: F) -> Result<T, Box<Any + Send>> where F: FnOnce() -> T { + Ok(f()) +} + +#[cfg(feature = "nightly")] +use std::panic::AssertRecoverSafe; + +#[cfg(not(feature = "nightly"))] +struct AssertRecoverSafe<T>(T); + +#[cfg(not(feature = "nightly"))] +impl<T> AssertRecoverSafe<T> { + fn new(t: T) -> Self { + AssertRecoverSafe(t) + } +} + +#[cfg(not(feature = "nightly"))] +impl<T> ::std::ops::Deref for AssertRecoverSafe<T> { + type Target = T; + + fn deref(&self) -> &T { + &self.0 + } +} + +#[cfg(not(feature = "nightly"))] +impl<T> ::std::ops::DerefMut for AssertRecoverSafe<T> { + fn deref_mut(&mut self) -> &mut T { + &mut self.0 + } +} + unsafe extern "C" fn bwrite<S: Write>(bio: *mut BIO, buf: *const c_char, len: c_int) -> c_int { BIO_clear_retry_flags(bio); let state = state::<S>(bio); let buf = slice::from_raw_parts(buf as *const _, len as usize); - match state.stream.write(buf) { - Ok(len) => len as c_int, - Err(err) => { + + let result = { + let mut youre_not_my_supervisor = AssertRecoverSafe::new(&mut *state); + recover(move || youre_not_my_supervisor.stream.write(buf)) + }; + + match result { + Ok(Ok(len)) => len as c_int, + Ok(Err(err)) => { if retriable_error(&err) { BIO_set_retry_write(bio); } state.error = Some(err); -1 } + Err(err) => { + state.panic = Some(err); + -1 + } } } @@ -86,15 +144,26 @@ unsafe extern "C" fn bread<S: Read>(bio: *mut BIO, buf: *mut c_char, len: c_int) let state = state::<S>(bio); let buf = slice::from_raw_parts_mut(buf as *mut _, len as usize); - match state.stream.read(buf) { - Ok(len) => len as c_int, - Err(err) => { + + let result = { + let mut youre_not_my_supervisor = AssertRecoverSafe::new(&mut *state); + let mut fuuuu = AssertRecoverSafe::new(buf); + recover(move || youre_not_my_supervisor.stream.read(&mut *fuuuu)) + }; + + match result { + Ok(Ok(len)) => len as c_int, + Ok(Err(err)) => { if retriable_error(&err) { BIO_set_retry_read(bio); } state.error = Some(err); -1 } + Err(err) => { + state.panic = Some(err); + -1 + } } } @@ -116,12 +185,21 @@ unsafe extern "C" fn ctrl<S: Write>(bio: *mut BIO, -> c_long { if cmd == BIO_CTRL_FLUSH { let state = state::<S>(bio); - match state.stream.flush() { - Ok(()) => 1, - Err(err) => { + let result = { + let mut youre_not_my_supervisor = AssertRecoverSafe::new(&mut *state); + recover(move || youre_not_my_supervisor.stream.flush()) + }; + + match result { + Ok(Ok(())) => 1, + Ok(Err(err)) => { state.error = Some(err); 0 } + Err(err) => { + state.panic = Some(err); + 0 + } } } else { 0 diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 955f10fd..b4c73479 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -26,8 +26,7 @@ use std::os::windows::io::{AsRawSocket, RawSocket}; use ffi; use ffi_extras; use dh::DH; -use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError, - OpensslError}; +use ssl::error::{NonblockingSslError, SslError, OpenSslError, OpensslError}; use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; @@ -482,6 +481,8 @@ fn wrap_ssl_result(res: c_int) -> Result<(), SslError> { } /// An SSL context object +/// +/// Internally ref-counted, use `.clone()` in the same way as Rc and Arc. pub struct SslContext { ctx: *mut ffi::SSL_CTX, } @@ -489,6 +490,12 @@ pub struct SslContext { unsafe impl Send for SslContext {} unsafe impl Sync for SslContext {} +impl Clone for SslContext { + fn clone(&self) -> Self { + unsafe { SslContext::new_ref(self.ctx) } + } +} + // TODO: add useful info here impl fmt::Debug for SslContext { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { @@ -503,6 +510,12 @@ impl Drop for SslContext { } impl SslContext { + // Create a new SslContext given an existing ref, and incriment ref-count appropriately. + unsafe fn new_ref(ctx: *mut ffi::SSL_CTX) -> SslContext { + rust_SSL_CTX_clone(ctx); + SslContext { ctx: ctx } + } + /// Creates a new SSL context. pub fn new(method: SslMethod) -> Result<SslContext, SslError> { init(); @@ -756,6 +769,71 @@ impl SslContext { } } + +pub struct CipherBits { + /// The number of secret bits used for the cipher. + pub secret: i32, + /// The number of bits processed by the chosen algorithm, if not None. + pub algorithm: Option<i32>, +} + + +pub struct SslCipher<'a> { + cipher: *const ffi::SSL_CIPHER, + ph: PhantomData<&'a ()>, +} + +impl <'a> SslCipher<'a> { + /// Returns the name of cipher. + pub fn name(&self) -> &'static str { + let name = unsafe { + let ptr = ffi::SSL_CIPHER_get_name(self.cipher); + CStr::from_ptr(ptr as *const _) + }; + + str::from_utf8(name.to_bytes()).unwrap() + } + + /// Returns the SSL/TLS protocol version that first defined the cipher. + pub fn version(&self) -> &'static str { + let version = unsafe { + let ptr = ffi::SSL_CIPHER_get_version(self.cipher); + CStr::from_ptr(ptr as *const _) + }; + + str::from_utf8(version.to_bytes()).unwrap() + } + + /// Returns the number of bits used for the cipher. + pub fn bits(&self) -> CipherBits { + unsafe { + let algo_bits : *mut c_int = ptr::null_mut(); + let secret_bits = ffi::SSL_CIPHER_get_bits(self.cipher, algo_bits); + if !algo_bits.is_null() { + CipherBits { secret: secret_bits, algorithm: Some(*algo_bits) } + } else { + CipherBits { secret: secret_bits, algorithm: None } + } + } + } + + /// Returns a textual description of the cipher used + pub fn description(&self) -> Option<String> { + unsafe { + // SSL_CIPHER_description requires a buffer of at least 128 bytes. + let mut buf = [0i8; 128]; + let desc_ptr = ffi::SSL_CIPHER_description(self.cipher, &mut buf[0], 128); + + if !desc_ptr.is_null() { + String::from_utf8(CStr::from_ptr(desc_ptr).to_bytes().to_vec()).ok() + } else { + None + } + } + } +} + + pub struct Ssl { ssl: *mut ffi::SSL, } @@ -823,6 +901,18 @@ impl Ssl { } } + pub fn get_current_cipher<'a>(&'a self) -> Option<SslCipher<'a>> { + unsafe { + let ptr = ffi::SSL_get_current_cipher(self.ssl); + + if ptr.is_null() { + None + } else { + Some(SslCipher{ cipher: ptr, ph: PhantomData }) + } + } + } + pub fn state_string(&self) -> &'static str { let state = unsafe { let ptr = ffi::SSL_state_string(self.ssl); @@ -868,6 +958,16 @@ impl Ssl { } } + /// Returns the name of the protocol used for the connection, e.g. "TLSv1.2", "SSLv3", etc. + pub fn version(&self) -> &'static str { + let version = unsafe { + let ptr = ffi::SSL_get_version(self.ssl); + CStr::from_ptr(ptr as *const _) + }; + + str::from_utf8(version.to_bytes()).unwrap() + } + /// Returns the protocol selected by performing Next Protocol Negotiation, if any. /// /// The protocol's name is returned is an opaque sequence of bytes. It is up to the client @@ -956,16 +1056,27 @@ impl Ssl { } /// change the context corresponding to the current connection + /// + /// Returns a clone of the SslContext @ctx (ie: the new context). The old context is freed. pub fn set_ssl_context(&self, ctx: &SslContext) -> SslContext { - SslContext { ctx: unsafe { ffi::SSL_set_SSL_CTX(self.ssl, ctx.ctx) } } + // If duplication of @ctx's cert fails, this returns NULL. This _appears_ to only occur on + // allocation failures (meaning panicing is probably appropriate), but it might be nice to + // propogate the error. + assert!(unsafe { ffi::SSL_set_SSL_CTX(self.ssl, ctx.ctx) } != ptr::null_mut()); + + // FIXME: we return this reference here for compatibility, but it isn't actually required. + // This should be removed when a api-incompatabile version is to be released. + // + // ffi:SSL_set_SSL_CTX() returns copy of the ctx pointer passed to it, so it's easier for + // us to do the clone directly. + ctx.clone() } /// obtain the context corresponding to the current connection pub fn get_ssl_context(&self) -> SslContext { unsafe { let ssl_ctx = ffi::SSL_get_SSL_CTX(self.ssl); - rust_SSL_CTX_clone(ssl_ctx); - SslContext { ctx: ssl_ctx } + SslContext::new_ref(ssl_ctx) } } } @@ -1070,10 +1181,9 @@ impl<S: Read + Write> SslStream<S> { if ret > 0 { Ok(stream) } else { - match stream.make_error(ret) { - // This is fine - nonblocking sockets will finish the handshake in read/write - Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), - _ => Err(stream.make_old_error(ret)), + match stream.make_old_error(ret) { + Some(err) => Err(err), + None => Ok(stream), } } } @@ -1086,10 +1196,9 @@ impl<S: Read + Write> SslStream<S> { if ret > 0 { Ok(stream) } else { - match stream.make_error(ret) { - // This is fine - nonblocking sockets will finish the handshake in read/write - Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), - _ => Err(stream.make_old_error(ret)), + match stream.make_old_error(ret) { + Some(err) => Err(err), + None => Ok(stream), } } } @@ -1137,6 +1246,8 @@ impl<S: Read + Write> SslStream<S> { impl<S> SslStream<S> { fn make_error(&mut self, ret: c_int) -> Error { + self.check_panic(); + match self.ssl.get_error(ret) { LibSslError::ErrorSsl => Error::Ssl(OpenSslError::get_stack()), LibSslError::ErrorSyscall => { @@ -1162,9 +1273,11 @@ impl<S> SslStream<S> { } } - fn make_old_error(&mut self, ret: c_int) -> SslError { + fn make_old_error(&mut self, ret: c_int) -> Option<SslError> { + self.check_panic(); + match self.ssl.get_error(ret) { - LibSslError::ErrorSsl => SslError::get(), + LibSslError::ErrorSsl => Some(SslError::get()), LibSslError::ErrorSyscall => { let err = SslError::get(); let count = match err { @@ -1173,26 +1286,35 @@ impl<S> SslStream<S> { }; if count == 0 { if ret == 0 { - SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, - "unexpected EOF observed")) + Some(SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed"))) } else { - SslError::StreamError(self.get_bio_error()) + Some(SslError::StreamError(self.get_bio_error())) } } else { - err + Some(err) } } - LibSslError::ErrorZeroReturn => SslError::SslSessionClosed, - LibSslError::ErrorWantWrite | LibSslError::ErrorWantRead => { - SslError::StreamError(self.get_bio_error()) - } + LibSslError::ErrorZeroReturn => Some(SslError::SslSessionClosed), + LibSslError::ErrorWantWrite | LibSslError::ErrorWantRead => None, err => { - SslError::StreamError(io::Error::new(io::ErrorKind::Other, - format!("unexpected error {:?}", err))) + Some(SslError::StreamError(io::Error::new(io::ErrorKind::Other, + format!("unexpected error {:?}", err)))) } } } + #[cfg(feature = "nightly")] + fn check_panic(&mut self) { + if let Some(err) = unsafe { bio::take_panic::<S>(self.ssl.get_raw_rbio()) } { + ::std::panic::propagate(err) + } + } + + #[cfg(not(feature = "nightly"))] + fn check_panic(&mut self) { + } + fn get_bio_error(&mut self) -> io::Error { let error = unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) }; match error { diff --git a/openssl/src/ssl/tests/mod.rs b/openssl/src/ssl/tests/mod.rs index af3c005e..be35d7ef 100644 --- a/openssl/src/ssl/tests/mod.rs +++ b/openssl/src/ssl/tests/mod.rs @@ -9,6 +9,7 @@ use std::net::{TcpStream, TcpListener, SocketAddr}; use std::path::Path; use std::process::{Command, Child, Stdio, ChildStdin}; use std::thread; +use std::time::Duration; use net2::TcpStreamExt; @@ -79,7 +80,7 @@ impl Server { match TcpStream::connect(&addr) { Ok(s) => return (server, s), Err(ref e) if e.kind() == io::ErrorKind::ConnectionRefused => { - thread::sleep_ms(100); + thread::sleep(Duration::from_millis(100)); } Err(e) => panic!("wut: {}", e), } @@ -117,7 +118,7 @@ impl Server { // Need to wait for the UDP socket to get bound in our child process, // but don't currently have a great way to do that so just wait for a // bit. - thread::sleep_ms(100); + thread::sleep(Duration::from_millis(100)); let socket = UdpSocket::bind(next_addr()).unwrap(); socket.connect(&addr).unwrap(); (s, UdpConnected(socket)) @@ -957,3 +958,104 @@ fn broken_try_clone_doesnt_crash() { let stream1 = SslStream::connect(&context, inner).unwrap(); let _stream2 = stream1.try_clone().unwrap(); } + +#[test] +#[should_panic(expected = "blammo")] +#[cfg(feature = "nightly")] +fn write_panic() { + struct ExplodingStream(TcpStream); + + impl Read for ExplodingStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + self.0.read(buf) + } + } + + impl Write for ExplodingStream { + fn write(&mut self, _: &[u8]) -> io::Result<usize> { + panic!("blammo"); + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + } + + let (_s, stream) = Server::new(); + let stream = ExplodingStream(stream); + + let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); + let _ = SslStream::connect(&ctx, stream); +} + +#[test] +#[should_panic(expected = "blammo")] +#[cfg(feature = "nightly")] +fn read_panic() { + struct ExplodingStream(TcpStream); + + impl Read for ExplodingStream { + fn read(&mut self, _: &mut [u8]) -> io::Result<usize> { + panic!("blammo"); + } + } + + impl Write for ExplodingStream { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + } + + let (_s, stream) = Server::new(); + let stream = ExplodingStream(stream); + + let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); + let _ = SslStream::connect(&ctx, stream); +} + +#[test] +#[should_panic(expected = "blammo")] +#[cfg(feature = "nightly")] +fn flush_panic() { + struct ExplodingStream(TcpStream); + + impl Read for ExplodingStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + self.0.read(buf) + } + } + + impl Write for ExplodingStream { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + panic!("blammo"); + } + } + + let (_s, stream) = Server::new(); + let stream = ExplodingStream(stream); + + let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); + let mut stream = SslStream::connect(&ctx, stream).unwrap(); + let _ = stream.flush(); +} + +#[test] +fn refcount_ssl_context() { + let ssl = { + let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); + ssl::Ssl::new(&ctx).unwrap() + }; + + { + let new_ctx_a = SslContext::new(SslMethod::Sslv23).unwrap(); + let _new_ctx_b = ssl.set_ssl_context(&new_ctx_a); + } +} |