aboutsummaryrefslogtreecommitdiff
path: root/openssl/src
diff options
context:
space:
mode:
authorSteven Fackler <[email protected]>2015-06-28 00:06:14 -0700
committerSteven Fackler <[email protected]>2015-06-28 00:06:14 -0700
commit1373a76ce12d6a856b6caae7457ceb3eb5ad4122 (patch)
tree67317b8480482532e79dc31adb466fb78c49e1f6 /openssl/src
parentPrepare for direct stream support (diff)
downloadrust-openssl-1373a76ce12d6a856b6caae7457ceb3eb5ad4122.tar.xz
rust-openssl-1373a76ce12d6a856b6caae7457ceb3eb5ad4122.zip
Implement direct IO support
Diffstat (limited to 'openssl/src')
-rw-r--r--openssl/src/ssl/mod.rs181
-rw-r--r--openssl/src/ssl/tests.rs19
2 files changed, 187 insertions, 13 deletions
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs
index 18acf7f8..0e1e5b30 100644
--- a/openssl/src/ssl/mod.rs
+++ b/openssl/src/ssl/mod.rs
@@ -603,11 +603,6 @@ impl Ssl {
return Err(SslError::get());
}
let ssl = Ssl { ssl: ssl };
-
- let rbio = try!(MemBio::new());
- let wbio = try!(MemBio::new());
-
- unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) }
Ok(ssl)
}
@@ -769,6 +764,12 @@ impl IndirectStream<net::TcpStream> {
impl<S: Read+Write> IndirectStream<S> {
fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
let ssl = try!(ssl.into_ssl());
+
+ let rbio = try!(MemBio::new());
+ let wbio = try!(MemBio::new());
+
+ unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) }
+
Ok(IndirectStream {
stream: stream,
ssl: Arc::new(ssl),
@@ -853,26 +854,138 @@ impl<S: Read+Write> Write for IndirectStream<S> {
}
#[derive(Clone)]
+struct DirectStream<S> {
+ stream: S,
+ ssl: Arc<Ssl>,
+}
+
+impl DirectStream<net::TcpStream> {
+ fn try_clone(&self) -> io::Result<DirectStream<net::TcpStream>> {
+ Ok(DirectStream {
+ stream: try!(self.stream.try_clone()),
+ ssl: self.ssl.clone(),
+ })
+ }
+}
+
+impl<S> DirectStream<S> {
+ fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
+ unsafe {
+ let bio = ffi::BIO_new_socket(sock, 0);
+ if bio == ptr::null_mut() {
+ return Err(SslError::get());
+ }
+ ffi::SSL_set_bio(ssl.ssl, bio, bio);
+ }
+
+ Ok(DirectStream {
+ stream: stream,
+ ssl: Arc::new(ssl),
+ })
+ }
+
+ fn new_client(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
+ let ssl = try!(DirectStream::new_base(ssl, stream, sock));
+ let ret = ssl.ssl.connect();
+ if ret > 0 {
+ Ok(ssl)
+ } else {
+ Err(ssl.make_error(ret))
+ }
+ }
+
+ fn new_server(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
+ let ssl = try!(DirectStream::new_base(ssl, stream, sock));
+ let ret = ssl.ssl.accept();
+ if ret > 0 {
+ Ok(ssl)
+ } else {
+ Err(ssl.make_error(ret))
+ }
+ }
+
+ fn make_error(&self, ret: c_int) -> SslError {
+ match self.ssl.get_error(ret) {
+ LibSslError::ErrorSsl => SslError::get(),
+ LibSslError::ErrorSyscall => {
+ let err = SslError::get();
+ let count = match err {
+ SslError::OpenSslErrors(ref v) => v.len(),
+ _ => unreachable!(),
+ };
+ if count == 0 {
+ if ret == 0 {
+ SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
+ "unexpected EOF observed"))
+ } else {
+ SslError::StreamError(io::Error::last_os_error())
+ }
+ } else {
+ err
+ }
+ }
+ err => panic!("unexpected error {:?} with ret {}", err, ret),
+ }
+ }
+}
+
+impl<S> Read for DirectStream<S> {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ let ret = self.ssl.read(buf);
+ if ret >= 0 {
+ return Ok(ret as usize);
+ }
+
+ match self.make_error(ret) {
+ SslError::StreamError(e) => Err(e),
+ e => Err(io::Error::new(io::ErrorKind::Other, e)),
+ }
+ }
+}
+
+impl<S: Write> Write for DirectStream<S> {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ let ret = self.ssl.write(buf);
+ if ret > 0 {
+ return Ok(ret as usize);
+ }
+
+ match self.make_error(ret) {
+ SslError::StreamError(e) => Err(e),
+ e => Err(io::Error::new(io::ErrorKind::Other, e)),
+ }
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ self.stream.flush()
+ }
+}
+
+#[derive(Clone)]
enum StreamKind<S> {
Indirect(IndirectStream<S>),
+ Direct(DirectStream<S>),
}
impl<S> StreamKind<S> {
fn stream(&self) -> &S {
match *self {
- StreamKind::Indirect(ref s) => &s.stream
+ StreamKind::Indirect(ref s) => &s.stream,
+ StreamKind::Direct(ref s) => &s.stream,
}
}
fn mut_stream(&mut self) -> &mut S {
match *self {
- StreamKind::Indirect(ref mut s) => &mut s.stream
+ StreamKind::Indirect(ref mut s) => &mut s.stream,
+ StreamKind::Direct(ref mut s) => &mut s.stream,
}
}
fn ssl(&self) -> &Ssl {
match *self {
- StreamKind::Indirect(ref s) => &s.ssl
+ StreamKind::Indirect(ref s) => &s.ssl,
+ StreamKind::Direct(ref s) => &s.ssl,
}
}
}
@@ -887,7 +1000,8 @@ impl SslStream<net::TcpStream> {
/// Create a new independently owned handle to the underlying socket.
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
let kind = match self.kind {
- StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone()))
+ StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())),
+ StreamKind::Direct(ref s) => StreamKind::Direct(try!(s.try_clone()))
};
Ok(SslStream {
kind: kind
@@ -901,6 +1015,46 @@ impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
}
}
+#[cfg(unix)]
+impl<S: ::std::os::unix::io::AsRawFd> SslStream<S> {
+ pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
+ let ssl = try!(ssl.into_ssl());
+ let fd = stream.as_raw_fd() as c_int;
+ let stream = try!(DirectStream::new_client(ssl, stream, fd));
+ Ok(SslStream {
+ kind: StreamKind::Direct(stream)
+ })
+ }
+
+ pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
+ let ssl = try!(ssl.into_ssl());
+ let fd = stream.as_raw_fd() as c_int;
+ let stream = try!(DirectStream::new_server(ssl, stream, fd));
+ Ok(SslStream {
+ kind: StreamKind::Direct(stream)
+ })
+ }
+}
+
+#[cfg(windows)]
+impl<S: ::std::os::windows::io::AsRawSocket> SslStream<S> {
+ pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
+ let fd = stream.as_raw_socket() as c_int;
+ let stream = try!(DirectStream::new_client(ssl, stream, fd));
+ Ok(SslStream {
+ kind: StreamKind::Direct(stream)
+ })
+ }
+
+ pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
+ let fd = stream.as_raw_socket() as c_int;
+ let stream = try!(DirectStream::new_server(ssl, stream, fd));
+ Ok(SslStream {
+ kind: StreamKind::Direct(stream)
+ })
+ }
+}
+
impl<S: Read+Write> SslStream<S> {
pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let stream = try!(IndirectStream::new_client(ssl, stream));
@@ -994,7 +1148,8 @@ impl<S: Read+Write> SslStream<S> {
impl<S: Read+Write> Read for SslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.kind {
- StreamKind::Indirect(ref mut s) => s.read(buf)
+ StreamKind::Indirect(ref mut s) => s.read(buf),
+ StreamKind::Direct(ref mut s) => s.read(buf),
}
}
}
@@ -1002,13 +1157,15 @@ impl<S: Read+Write> Read for SslStream<S> {
impl<S: Read+Write> Write for SslStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.kind {
- StreamKind::Indirect(ref mut s) => s.write(buf)
+ StreamKind::Indirect(ref mut s) => s.write(buf),
+ StreamKind::Direct(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self.kind {
- StreamKind::Indirect(ref mut s) => s.flush()
+ StreamKind::Indirect(ref mut s) => s.flush(),
+ StreamKind::Direct(ref mut s) => s.flush(),
}
}
}
diff --git a/openssl/src/ssl/tests.rs b/openssl/src/ssl/tests.rs
index a0e4a9d6..2ba940ab 100644
--- a/openssl/src/ssl/tests.rs
+++ b/openssl/src/ssl/tests.rs
@@ -317,8 +317,17 @@ fn test_write() {
stream.flush().unwrap();
}
+#[test]
+fn test_write_direct() {
+ let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
+ let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
+ stream.write_all("hello".as_bytes()).unwrap();
+ stream.flush().unwrap();
+ stream.write_all(" there".as_bytes()).unwrap();
+ stream.flush().unwrap();
+}
+
run_test!(get_peer_certificate, |method, stream| {
- //let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
let stream = SslStream::new_client(&SslContext::new(method).unwrap(), stream).unwrap();
let cert = stream.get_peer_certificate().unwrap();
let fingerprint = cert.fingerprint(SHA256).unwrap();
@@ -349,6 +358,14 @@ fn test_read() {
io::copy(&mut stream, &mut io::sink()).ok().expect("read error");
}
+#[test]
+fn test_read_direct() {
+ let tcp = TcpStream::connect("127.0.0.1:15418").unwrap();
+ let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), tcp).unwrap();
+ stream.write_all("GET /\r\n\r\n".as_bytes()).unwrap();
+ stream.flush().unwrap();
+ io::copy(&mut stream, &mut io::sink()).ok().expect("read error");
+}
#[test]
fn test_pending() {