aboutsummaryrefslogtreecommitdiff
path: root/openssl/src
diff options
context:
space:
mode:
authorBenjamin Saunders <[email protected]>2018-03-04 22:05:09 -0800
committerBenjamin Saunders <[email protected]>2018-03-28 18:14:48 -0700
commitf99c10155993ba2d34aa36f74d3458872787316b (patch)
tree0823428b7eced887ebf4596c5fef2bca7810174b /openssl/src
parentIntroduce SslStreamBuilder (diff)
downloadrust-openssl-f99c10155993ba2d34aa36f74d3458872787316b.tar.xz
rust-openssl-f99c10155993ba2d34aa36f74d3458872787316b.zip
Add test for stateless connection
Diffstat (limited to 'openssl/src')
-rw-r--r--openssl/src/ssl/test.rs121
1 files changed, 119 insertions, 2 deletions
diff --git a/openssl/src/ssl/test.rs b/openssl/src/ssl/test.rs
index c732f3fc..ddcb49ff 100644
--- a/openssl/src/ssl/test.rs
+++ b/openssl/src/ssl/test.rs
@@ -19,8 +19,9 @@ use hash::MessageDigest;
use ocsp::{OcspResponse, OcspResponseStatus};
use ssl;
use ssl::{Error, HandshakeError, ShutdownResult, Ssl, SslAcceptor, SslConnector, SslContext,
- SslFiletype, SslMethod, SslSessionCacheMode, SslStream, SslVerifyMode, StatusType};
-#[cfg(any(ossl110))]
+ SslFiletype, SslMethod, SslSessionCacheMode, SslStream, MidHandshakeSslStream,
+ SslVerifyMode, StatusType};
+#[cfg(any(ossl110, ossl111))]
use ssl::SslVersion;
use x509::{X509, X509Name, X509StoreContext, X509VerifyResult};
#[cfg(any(ossl102, ossl110))]
@@ -1389,3 +1390,119 @@ fn _check_kinds() {
is_send::<SslStream<TcpStream>>();
is_sync::<SslStream<TcpStream>>();
}
+
+#[derive(Debug)]
+struct MemoryStream {
+ incoming: io::Cursor<Vec<u8>>,
+ outgoing: Vec<u8>,
+}
+
+impl MemoryStream {
+ pub fn new() -> Self { Self {
+ incoming: io::Cursor::new(Vec::new()),
+ outgoing: Vec::new(),
+ }}
+
+ pub fn extend_incoming(&mut self, data: &[u8]) {
+ self.incoming.get_mut().extend_from_slice(data);
+ }
+
+ pub fn take_outgoing(&mut self) -> Outgoing { Outgoing(&mut self.outgoing) }
+}
+
+impl Read for MemoryStream {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ let n = self.incoming.read(buf)?;
+ if self.incoming.position() == self.incoming.get_ref().len() as u64 {
+ self.incoming.set_position(0);
+ self.incoming.get_mut().clear();
+ }
+ if n == 0 {
+ return Err(io::Error::new(io::ErrorKind::WouldBlock, "no data available"));
+ }
+ Ok(n)
+ }
+}
+
+impl Write for MemoryStream {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ self.outgoing.write(buf)
+ }
+
+ fn flush(&mut self) -> io::Result<()> { Ok(()) }
+}
+
+pub struct Outgoing<'a>(&'a mut Vec<u8>);
+
+impl<'a> Drop for Outgoing<'a> {
+ fn drop(&mut self) {
+ self.0.clear();
+ }
+}
+
+impl<'a> ::std::ops::Deref for Outgoing<'a> {
+ type Target = [u8];
+ fn deref(&self) -> &[u8] { &self.0 }
+}
+
+impl<'a> AsRef<[u8]> for Outgoing<'a> {
+ fn as_ref(&self) -> &[u8] { &self.0 }
+}
+
+fn send(from: &mut MemoryStream, to: &mut MemoryStream) {
+ to.extend_incoming(&from.take_outgoing());
+}
+
+#[test]
+#[cfg(ossl111)]
+fn stateless() {
+ use super::SslOptions;
+
+ fn hs<S: ::std::fmt::Debug>(stream: Result<SslStream<S>, HandshakeError<S>>) -> Result<SslStream<S>, MidHandshakeSslStream<S>> {
+ match stream {
+ Ok(stream) => Ok(stream),
+ Err(HandshakeError::WouldBlock(stream)) => Err(stream),
+ Err(e) => panic!("unexpected error: {:?}", e),
+ }
+ }
+
+ //
+ // Setup
+ //
+
+ let mut client_ctx = SslContext::builder(SslMethod::tls()).unwrap();
+ client_ctx.clear_options(SslOptions::ENABLE_MIDDLEBOX_COMPAT);
+ let client_stream = Ssl::new(&client_ctx.build()).unwrap();
+
+ let mut server_ctx = SslContext::builder(SslMethod::tls()).unwrap();
+ server_ctx.set_certificate_file(&Path::new("test/cert.pem"), SslFiletype::PEM)
+ .unwrap();
+ server_ctx.set_private_key_file(&Path::new("test/key.pem"), SslFiletype::PEM)
+ .unwrap();
+ const COOKIE: &[u8] = b"chocolate chip";
+ server_ctx.set_stateless_cookie_generate_cb(|_tls, buf| { buf[0..COOKIE.len()].copy_from_slice(COOKIE); Ok(COOKIE.len()) });
+ server_ctx.set_stateless_cookie_verify_cb(|_tls, buf| buf == COOKIE);
+ let mut server_stream = ssl::SslStreamBuilder::new(Ssl::new(&server_ctx.build()).unwrap(), MemoryStream::new());
+
+ //
+ // Handshake
+ //
+
+ // Initial ClientHello
+ let mut client_stream = hs(client_stream.connect(MemoryStream::new())).unwrap_err();
+ send(client_stream.get_mut(), server_stream.get_mut());
+ // HelloRetryRequest
+ assert!(!server_stream.stateless().unwrap());
+ send(server_stream.get_mut(), client_stream.get_mut());
+ // Second ClientHello
+ let mut client_stream = hs(client_stream.handshake()).unwrap_err();
+ send(client_stream.get_mut(), server_stream.get_mut());
+ // ServerHello
+ assert!(server_stream.stateless().unwrap());
+ let mut server_stream = hs(server_stream.accept()).unwrap_err();
+ send(server_stream.get_mut(), client_stream.get_mut());
+ // Finished
+ let mut client_stream = hs(client_stream.handshake()).unwrap();
+ send(client_stream.get_mut(), server_stream.get_mut());
+ hs(server_stream.handshake()).unwrap();
+}