aboutsummaryrefslogtreecommitdiff
path: root/openssl/src
diff options
context:
space:
mode:
authorSteven Fackler <[email protected]>2018-03-31 11:28:03 -0700
committerGitHub <[email protected]>2018-03-31 11:28:03 -0700
commite423da2d124de888187355d6aada527b1eaeaaa5 (patch)
treeaa713d6937fdfd022efcefe52509384e75b58c5c /openssl/src
parentMerge pull request #891 from sfackler/fix-vcpkg (diff)
parentAdd test for stateless connection (diff)
downloadrust-openssl-e423da2d124de888187355d6aada527b1eaeaaa5.tar.xz
rust-openssl-e423da2d124de888187355d6aada527b1eaeaaa5.zip
Merge pull request #858 from Ralith/stateless-api
Introduce SslStreamBuilder
Diffstat (limited to 'openssl/src')
-rw-r--r--openssl/src/ssl/callbacks.rs49
-rw-r--r--openssl/src/ssl/mod.rs202
-rw-r--r--openssl/src/ssl/test.rs121
3 files changed, 331 insertions, 41 deletions
diff --git a/openssl/src/ssl/callbacks.rs b/openssl/src/ssl/callbacks.rs
index 5b95ed02..bff71022 100644
--- a/openssl/src/ssl/callbacks.rs
+++ b/openssl/src/ssl/callbacks.rs
@@ -366,6 +366,55 @@ where
callback(ssl, line);
}
+#[cfg(ossl111)]
+pub extern "C" fn raw_stateless_cookie_generate<F>(
+ ssl: *mut ffi::SSL,
+ cookie: *mut c_uchar,
+ cookie_len: *mut size_t,
+) -> c_int
+where
+ F: Fn(&mut SslRef, &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send,
+{
+ unsafe {
+ let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl as *const _);
+ let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_callback_idx::<F>());
+ let ssl = SslRef::from_ptr_mut(ssl);
+ let callback = &*(callback as *mut F);
+ let slice =
+ slice::from_raw_parts_mut(cookie as *mut u8, ffi::SSL_COOKIE_LENGTH as usize);
+ match callback(ssl, slice) {
+ Ok(len) => {
+ *cookie_len = len as size_t;
+ 1
+ }
+ Err(e) => {
+ e.put();
+ 0
+ }
+ }
+ }
+}
+
+#[cfg(ossl111)]
+pub extern "C" fn raw_stateless_cookie_verify<F>(
+ ssl: *mut ffi::SSL,
+ cookie: *const c_uchar,
+ cookie_len: size_t,
+) -> c_int
+where
+ F: Fn(&mut SslRef, &[u8]) -> bool + 'static + Sync + Send,
+{
+ unsafe {
+ let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl as *const _);
+ let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_callback_idx::<F>());
+ let ssl = SslRef::from_ptr_mut(ssl);
+ let callback = &*(callback as *mut F);
+ let slice =
+ slice::from_raw_parts(cookie as *const c_uchar as *const u8, cookie_len as usize);
+ callback(ssl, slice) as c_int
+ }
+}
+
pub extern "C" fn raw_cookie_generate<F>(
ssl: *mut ffi::SSL,
cookie: *mut c_uchar,
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs
index c1021b8b..0f9e8935 100644
--- a/openssl/src/ssl/mod.rs
+++ b/openssl/src/ssl/mod.rs
@@ -1432,13 +1432,15 @@ impl SslContextBuilder {
}
}
- /// Sets the callback for generating an application cookie for stateless handshakes.
+ /// Sets the callback for generating an application cookie for TLS1.3
+ /// stateless handshakes.
///
/// The callback will be called with the SSL context and a slice into which the cookie
/// should be written. The callback should return the number of bytes written.
///
- /// This corresponds to `SSL_CTX_set_cookie_generate_cb`.
- pub fn set_cookie_generate_cb<F>(&mut self, callback: F)
+ /// This corresponds to `SSL_CTX_set_stateless_cookie_generate_cb`.
+ #[cfg(ossl111)]
+ pub fn set_stateless_cookie_generate_cb<F>(&mut self, callback: F)
where
F: Fn(&mut SslRef, &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send,
{
@@ -1447,13 +1449,14 @@ impl SslContextBuilder {
ffi::SSL_CTX_set_ex_data(
self.as_ptr(),
get_callback_idx::<F>(),
- mem::transmute(callback),
+ Box::into_raw(callback) as *mut _,
);
- ffi::SSL_CTX_set_cookie_generate_cb(self.as_ptr(), Some(raw_cookie_generate::<F>))
+ ffi::SSL_CTX_set_stateless_cookie_generate_cb(self.as_ptr(), Some(raw_stateless_cookie_generate::<F>))
}
}
- /// Sets the callback for verifying an application cookie for stateless handshakes.
+ /// Sets the callback for verifying an application cookie for TLS1.3
+ /// stateless handshakes.
///
/// The callback will be called with the SSL context and the cookie supplied by the
/// client. It should return true if and only if the cookie is valid.
@@ -1461,6 +1464,49 @@ impl SslContextBuilder {
/// Note that the OpenSSL implementation independently verifies the integrity of
/// application cookies using an HMAC before invoking the supplied callback.
///
+ /// This corresponds to `SSL_CTX_set_stateless_cookie_verify_cb`.
+ #[cfg(ossl111)]
+ pub fn set_stateless_cookie_verify_cb<F>(&mut self, callback: F)
+ where
+ F: Fn(&mut SslRef, &[u8]) -> bool + 'static + Sync + Send,
+ {
+ unsafe {
+ let callback = Box::new(callback);
+ ffi::SSL_CTX_set_ex_data(
+ self.as_ptr(),
+ get_callback_idx::<F>(),
+ Box::into_raw(callback) as *mut _,
+ );
+ ffi::SSL_CTX_set_stateless_cookie_verify_cb(self.as_ptr(), Some(raw_stateless_cookie_verify::<F>))
+ }
+ }
+
+ /// Sets the callback for generating a DTLSv1 cookie
+ ///
+ /// The callback will be called with the SSL context and a slice into which the cookie
+ /// should be written. The callback should return the number of bytes written.
+ ///
+ /// This corresponds to `SSL_CTX_set_cookie_generate_cb`.
+ pub fn set_cookie_generate_cb<F>(&mut self, callback: F)
+ where
+ F: Fn(&mut SslRef, &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send,
+ {
+ unsafe {
+ let callback = Box::new(callback);
+ ffi::SSL_CTX_set_ex_data(
+ self.as_ptr(),
+ get_callback_idx::<F>(),
+ Box::into_raw(callback) as *mut _,
+ );
+ ffi::SSL_CTX_set_cookie_generate_cb(self.as_ptr(), Some(raw_cookie_generate::<F>))
+ }
+ }
+
+ /// Sets the callback for verifying a DTLSv1 cookie
+ ///
+ /// The callback will be called with the SSL context and the cookie supplied by the
+ /// client. It should return true if and only if the cookie is valid.
+ ///
/// This corresponds to `SSL_CTX_set_cookie_verify_cb`.
pub fn set_cookie_verify_cb<F>(&mut self, callback: F)
where
@@ -1471,7 +1517,7 @@ impl SslContextBuilder {
ffi::SSL_CTX_set_ex_data(
self.as_ptr(),
get_callback_idx::<F>(),
- mem::transmute(callback),
+ Box::into_raw(callback) as *mut _,
);
ffi::SSL_CTX_set_cookie_verify_cb(self.as_ptr(), Some(raw_cookie_verify::<F>))
}
@@ -2590,22 +2636,7 @@ impl Ssl {
where
S: Read + Write,
{
- let mut stream = SslStream::new_base(self, stream);
- let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) };
- if ret > 0 {
- Ok(stream)
- } else {
- let error = stream.make_error(ret);
- match error.code() {
- ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
- MidHandshakeSslStream { stream, error },
- )),
- _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
- stream,
- error,
- })),
- }
- }
+ SslStreamBuilder::new(self, stream).connect()
}
/// Initiates a server-side TLS handshake.
@@ -2622,22 +2653,7 @@ impl Ssl {
where
S: Read + Write,
{
- let mut stream = SslStream::new_base(self, stream);
- let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) };
- if ret > 0 {
- Ok(stream)
- } else {
- let error = stream.make_error(ret);
- match error.code() {
- ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
- MidHandshakeSslStream { stream, error },
- )),
- _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
- stream,
- error,
- })),
- }
- }
+ SslStreamBuilder::new(self, stream).accept()
}
}
@@ -2909,6 +2925,114 @@ impl<S: Read + Write> Write for SslStream<S> {
}
}
+/// A partially constructed `SslStream`, useful for unusual handshakes.
+pub struct SslStreamBuilder<S> {
+ inner: SslStream<S>
+}
+
+impl<S> SslStreamBuilder<S>
+ where S: Read + Write
+{
+ /// Begin creating an `SslStream` atop `stream`
+ pub fn new(ssl: Ssl, stream: S) -> Self {
+ Self {
+ inner: SslStream::new_base(ssl, stream),
+ }
+ }
+
+ /// Perform a stateless server-side handshake
+ ///
+ /// Requires that cookie generation and verification callbacks were
+ /// set on the SSL context.
+ ///
+ /// Returns `Ok(true)` if a complete ClientHello containing a valid cookie
+ /// was read, in which case the handshake should be continued via
+ /// `accept`. If a HelloRetryRequest containing a fresh cookie was
+ /// transmitted, `Ok(false)` is returned instead. If the handshake cannot
+ /// proceed at all, `Err` is returned.
+ ///
+ /// This corresponds to [`SSL_stateless`]
+ ///
+ /// [`SSL_stateless`]: https://www.openssl.org/docs/manmaster/man3/SSL_stateless.html
+ #[cfg(ossl111)]
+ pub fn stateless(&mut self) -> Result<bool, ErrorStack> {
+ match unsafe { ffi::SSL_stateless(self.inner.ssl.as_ptr()) } {
+ 1 => Ok(true),
+ 0 => Ok(false),
+ -1 => Err(ErrorStack::get()),
+ _ => unreachable!(),
+ }
+ }
+
+ /// See `Ssl::connect`
+ pub fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
+ let mut stream = self.inner;
+ let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) };
+ if ret > 0 {
+ Ok(stream)
+ } else {
+ let error = stream.make_error(ret);
+ match error.code() {
+ ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
+ MidHandshakeSslStream { stream, error },
+ )),
+ _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
+ stream,
+ error,
+ })),
+ }
+ }
+ }
+
+ /// See `Ssl::accept`
+ pub fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
+ let mut stream = self.inner;
+ let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) };
+ if ret > 0 {
+ Ok(stream)
+ } else {
+ let error = stream.make_error(ret);
+ match error.code() {
+ ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
+ MidHandshakeSslStream { stream, error },
+ )),
+ _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
+ stream,
+ error,
+ })),
+ }
+ }
+ }
+
+ // Future work: early IO methods
+}
+
+impl<S> SslStreamBuilder<S> {
+ /// Returns a shared reference to the underlying stream.
+ pub fn get_ref(&self) -> &S {
+ unsafe {
+ let bio = self.inner.ssl.get_raw_rbio();
+ bio::get_ref(bio)
+ }
+ }
+
+ /// Returns a mutable reference to the underlying stream.
+ ///
+ /// # Warning
+ ///
+ /// It is inadvisable to read from or write to the underlying stream as it
+ /// will most likely corrupt the SSL session.
+ pub fn get_mut(&mut self) -> &mut S {
+ unsafe {
+ let bio = self.inner.ssl.get_raw_rbio();
+ bio::get_mut(bio)
+ }
+ }
+
+ /// Returns a shared reference to the `Ssl` object associated with this builder.
+ pub fn ssl(&self) -> &SslRef { &self.inner.ssl }
+}
+
/// The result of a shutdown request.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ShutdownResult {
diff --git a/openssl/src/ssl/test.rs b/openssl/src/ssl/test.rs
index c732f3fc..ddcb49ff 100644
--- a/openssl/src/ssl/test.rs
+++ b/openssl/src/ssl/test.rs
@@ -19,8 +19,9 @@ use hash::MessageDigest;
use ocsp::{OcspResponse, OcspResponseStatus};
use ssl;
use ssl::{Error, HandshakeError, ShutdownResult, Ssl, SslAcceptor, SslConnector, SslContext,
- SslFiletype, SslMethod, SslSessionCacheMode, SslStream, SslVerifyMode, StatusType};
-#[cfg(any(ossl110))]
+ SslFiletype, SslMethod, SslSessionCacheMode, SslStream, MidHandshakeSslStream,
+ SslVerifyMode, StatusType};
+#[cfg(any(ossl110, ossl111))]
use ssl::SslVersion;
use x509::{X509, X509Name, X509StoreContext, X509VerifyResult};
#[cfg(any(ossl102, ossl110))]
@@ -1389,3 +1390,119 @@ fn _check_kinds() {
is_send::<SslStream<TcpStream>>();
is_sync::<SslStream<TcpStream>>();
}
+
+#[derive(Debug)]
+struct MemoryStream {
+ incoming: io::Cursor<Vec<u8>>,
+ outgoing: Vec<u8>,
+}
+
+impl MemoryStream {
+ pub fn new() -> Self { Self {
+ incoming: io::Cursor::new(Vec::new()),
+ outgoing: Vec::new(),
+ }}
+
+ pub fn extend_incoming(&mut self, data: &[u8]) {
+ self.incoming.get_mut().extend_from_slice(data);
+ }
+
+ pub fn take_outgoing(&mut self) -> Outgoing { Outgoing(&mut self.outgoing) }
+}
+
+impl Read for MemoryStream {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ let n = self.incoming.read(buf)?;
+ if self.incoming.position() == self.incoming.get_ref().len() as u64 {
+ self.incoming.set_position(0);
+ self.incoming.get_mut().clear();
+ }
+ if n == 0 {
+ return Err(io::Error::new(io::ErrorKind::WouldBlock, "no data available"));
+ }
+ Ok(n)
+ }
+}
+
+impl Write for MemoryStream {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ self.outgoing.write(buf)
+ }
+
+ fn flush(&mut self) -> io::Result<()> { Ok(()) }
+}
+
+pub struct Outgoing<'a>(&'a mut Vec<u8>);
+
+impl<'a> Drop for Outgoing<'a> {
+ fn drop(&mut self) {
+ self.0.clear();
+ }
+}
+
+impl<'a> ::std::ops::Deref for Outgoing<'a> {
+ type Target = [u8];
+ fn deref(&self) -> &[u8] { &self.0 }
+}
+
+impl<'a> AsRef<[u8]> for Outgoing<'a> {
+ fn as_ref(&self) -> &[u8] { &self.0 }
+}
+
+fn send(from: &mut MemoryStream, to: &mut MemoryStream) {
+ to.extend_incoming(&from.take_outgoing());
+}
+
+#[test]
+#[cfg(ossl111)]
+fn stateless() {
+ use super::SslOptions;
+
+ fn hs<S: ::std::fmt::Debug>(stream: Result<SslStream<S>, HandshakeError<S>>) -> Result<SslStream<S>, MidHandshakeSslStream<S>> {
+ match stream {
+ Ok(stream) => Ok(stream),
+ Err(HandshakeError::WouldBlock(stream)) => Err(stream),
+ Err(e) => panic!("unexpected error: {:?}", e),
+ }
+ }
+
+ //
+ // Setup
+ //
+
+ let mut client_ctx = SslContext::builder(SslMethod::tls()).unwrap();
+ client_ctx.clear_options(SslOptions::ENABLE_MIDDLEBOX_COMPAT);
+ let client_stream = Ssl::new(&client_ctx.build()).unwrap();
+
+ let mut server_ctx = SslContext::builder(SslMethod::tls()).unwrap();
+ server_ctx.set_certificate_file(&Path::new("test/cert.pem"), SslFiletype::PEM)
+ .unwrap();
+ server_ctx.set_private_key_file(&Path::new("test/key.pem"), SslFiletype::PEM)
+ .unwrap();
+ const COOKIE: &[u8] = b"chocolate chip";
+ server_ctx.set_stateless_cookie_generate_cb(|_tls, buf| { buf[0..COOKIE.len()].copy_from_slice(COOKIE); Ok(COOKIE.len()) });
+ server_ctx.set_stateless_cookie_verify_cb(|_tls, buf| buf == COOKIE);
+ let mut server_stream = ssl::SslStreamBuilder::new(Ssl::new(&server_ctx.build()).unwrap(), MemoryStream::new());
+
+ //
+ // Handshake
+ //
+
+ // Initial ClientHello
+ let mut client_stream = hs(client_stream.connect(MemoryStream::new())).unwrap_err();
+ send(client_stream.get_mut(), server_stream.get_mut());
+ // HelloRetryRequest
+ assert!(!server_stream.stateless().unwrap());
+ send(server_stream.get_mut(), client_stream.get_mut());
+ // Second ClientHello
+ let mut client_stream = hs(client_stream.handshake()).unwrap_err();
+ send(client_stream.get_mut(), server_stream.get_mut());
+ // ServerHello
+ assert!(server_stream.stateless().unwrap());
+ let mut server_stream = hs(server_stream.accept()).unwrap_err();
+ send(server_stream.get_mut(), client_stream.get_mut());
+ // Finished
+ let mut client_stream = hs(client_stream.handshake()).unwrap();
+ send(client_stream.get_mut(), server_stream.get_mut());
+ hs(server_stream.handshake()).unwrap();
+}