aboutsummaryrefslogtreecommitdiff
path: root/openssl/src
diff options
context:
space:
mode:
authorSteven Fackler <[email protected]>2015-12-12 15:46:17 -0800
committerSteven Fackler <[email protected]>2015-12-12 15:46:17 -0800
commitd6ce9afdf31faacaf435380feffcd13bf387255a (patch)
tree3a9251bb151561e1cad9ceba5f5c66e01edb8388 /openssl/src
parentBuild out a new error type (diff)
downloadrust-openssl-d6ce9afdf31faacaf435380feffcd13bf387255a.tar.xz
rust-openssl-d6ce9afdf31faacaf435380feffcd13bf387255a.zip
Have NonblockingSslStream delegate to SslStream
Diffstat (limited to 'openssl/src')
-rw-r--r--openssl/src/ssl/error.rs29
-rw-r--r--openssl/src/ssl/mod.rs213
2 files changed, 84 insertions, 158 deletions
diff --git a/openssl/src/ssl/error.rs b/openssl/src/ssl/error.rs
index 9a1a63b2..52ea6693 100644
--- a/openssl/src/ssl/error.rs
+++ b/openssl/src/ssl/error.rs
@@ -95,6 +95,11 @@ impl OpenSslError {
errs
}
+ /// Returns the raw OpenSSL error code for this error.
+ pub fn error_code(&self) -> c_ulong {
+ self.0
+ }
+
/// Returns the name of the library reporting the error.
pub fn library(&self) -> &'static str {
get_lib(self.0)
@@ -239,6 +244,17 @@ pub enum OpensslError {
}
}
+impl OpensslError {
+ pub fn from_error_code(err: c_ulong) -> OpensslError {
+ ffi::init();
+ UnknownError {
+ library: get_lib(err).to_owned(),
+ function: get_func(err).to_owned(),
+ reason: get_reason(err).to_owned()
+ }
+ }
+}
+
fn get_lib(err: c_ulong) -> &'static str {
unsafe {
let cstr = ffi::ERR_lib_error_string(err);
@@ -271,7 +287,7 @@ impl SslError {
loop {
match unsafe { ffi::ERR_get_error() } {
0 => break,
- err => errs.push(SslError::from_error_code(err))
+ err => errs.push(OpensslError::from_error_code(err))
}
}
OpenSslErrors(errs)
@@ -279,16 +295,7 @@ impl SslError {
/// Creates an `SslError` from the raw numeric error code.
pub fn from_error(err: c_ulong) -> SslError {
- OpenSslErrors(vec![SslError::from_error_code(err)])
- }
-
- fn from_error_code(err: c_ulong) -> OpensslError {
- ffi::init();
- UnknownError {
- library: get_lib(err).to_owned(),
- function: get_func(err).to_owned(),
- reason: get_reason(err).to_owned()
- }
+ OpenSslErrors(vec![OpensslError::from_error_code(err)])
}
}
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs
index 89c8bbfc..0ffa1120 100644
--- a/openssl/src/ssl/mod.rs
+++ b/openssl/src/ssl/mod.rs
@@ -10,7 +10,7 @@ use std::str;
use std::net;
use std::path::Path;
use std::ptr;
-use std::sync::{Once, ONCE_INIT, Arc, Mutex};
+use std::sync::{Once, ONCE_INIT, Mutex};
use std::cmp;
use std::any::Any;
#[cfg(any(feature = "npn", feature = "alpn"))]
@@ -18,11 +18,16 @@ use libc::{c_uchar, c_uint};
#[cfg(any(feature = "npn", feature = "alpn"))]
use std::slice;
use std::marker::PhantomData;
+#[cfg(unix)]
+use std::os::unix::io::{AsRawFd, RawFd};
+#[cfg(windows)]
+use std::os::windows::io::{AsRawSocket, RawSocket};
use ffi;
use ffi_extras;
use dh::DH;
-use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError};
+use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError,
+ OpensslError};
use x509::{X509StoreContext, X509FileType, X509};
use crypto::pkey::PKey;
@@ -935,6 +940,20 @@ impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
}
}
+#[cfg(unix)]
+impl<S: AsRawFd> AsRawFd for SslStream<S> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.get_ref().as_raw_fd()
+ }
+}
+
+#[cfg(windows)]
+impl<S: AsRawSocket> AsRawSocket for NonblockingSslStream<S> {
+ fn as_raw_fd(&self) -> RawSocket {
+ self.0.as_raw_socket()
+ }
+}
+
impl<S: Read+Write> SslStream<S> {
fn new_base(ssl: Ssl, stream: S) -> Self {
unsafe {
@@ -1247,65 +1266,38 @@ impl MaybeSslStream<net::TcpStream> {
/// # Deprecated
///
/// Use `SslStream` with `ssl_read` and `ssl_write`.
-#[derive(Clone)]
-pub struct NonblockingSslStream<S> {
- stream: S,
- ssl: Arc<Ssl>,
-}
+pub struct NonblockingSslStream<S>(SslStream<S>);
-impl NonblockingSslStream<net::TcpStream> {
- pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> {
- Ok(NonblockingSslStream {
- stream: try!(self.stream.try_clone()),
- ssl: self.ssl.clone(),
- })
+impl<S: Clone + Read + Write> Clone for NonblockingSslStream<S> {
+ fn clone(&self) -> Self {
+ NonblockingSslStream(self.0.clone())
}
}
-impl<S> NonblockingSslStream<S> {
- fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<NonblockingSslStream<S>, SslError> {
- unsafe {
- let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0));
- ffi_extras::BIO_set_nbio(bio, 1);
- ffi::SSL_set_bio(ssl.ssl, bio, bio);
- }
+#[cfg(unix)]
+impl<S: AsRawFd> AsRawFd for NonblockingSslStream<S> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.0.as_raw_fd()
+ }
+}
- Ok(NonblockingSslStream {
- stream: stream,
- ssl: Arc::new(ssl),
- })
+#[cfg(windows)]
+impl<S: AsRawSocket> AsRawSocket for NonblockingSslStream<S> {
+ fn as_raw_fd(&self) -> RawSocket {
+ self.0.as_raw_socket()
}
+}
- fn make_error(&self, ret: c_int) -> NonblockingSslError {
- match self.ssl.get_error(ret) {
- LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()),
- LibSslError::ErrorSyscall => {
- let err = SslError::get();
- let count = match err {
- SslError::OpenSslErrors(ref v) => v.len(),
- _ => unreachable!(),
- };
- let ssl_error = if count == 0 {
- if ret == 0 {
- SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
- "unexpected EOF observed"))
- } else {
- SslError::StreamError(io::Error::last_os_error())
- }
- } else {
- err
- };
- ssl_error.into()
- },
- LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite,
- LibSslError::ErrorWantRead => NonblockingSslError::WantRead,
- err => panic!("unexpected error {:?} with ret {}", err, ret),
- }
+impl NonblockingSslStream<net::TcpStream> {
+ pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> {
+ self.0.try_clone().map(NonblockingSslStream)
}
+}
+impl<S> NonblockingSslStream<S> {
/// Returns a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
- &self.stream
+ self.0.get_ref()
}
/// Returns a mutable reference to the underlying stream.
@@ -1315,91 +1307,23 @@ impl<S> NonblockingSslStream<S> {
/// It is inadvisable to read from or write to the underlying stream as it
/// will most likely corrupt the SSL session.
pub fn get_mut(&mut self) -> &mut S {
- &mut self.stream
+ self.0.get_mut()
}
/// Returns a reference to the Ssl.
pub fn ssl(&self) -> &Ssl {
- &self.ssl
- }
-}
-
-#[cfg(unix)]
-impl<S: Read+Write+::std::os::unix::io::AsRawFd> NonblockingSslStream<S> {
- /// Create a new nonblocking client ssl connection on wrapped `stream`.
- ///
- /// Note that this method will most likely not actually complete the SSL
- /// handshake because doing so requires several round trips; the handshake will
- /// be completed in subsequent read/write calls managed by your event loop.
- pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
- let ssl = try!(ssl.into_ssl());
- let fd = stream.as_raw_fd() as c_int;
- let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
- let ret = ssl.ssl.connect();
- if ret > 0 {
- Ok(ssl)
- } else {
- // WantRead/WantWrite is okay here; we'll finish the handshake in
- // subsequent send/recv calls.
- match ssl.make_error(ret) {
- NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
- NonblockingSslError::SslError(other) => Err(other),
- }
- }
- }
-
- /// Create a new nonblocking server ssl connection on wrapped `stream`.
- ///
- /// Note that this method will most likely not actually complete the SSL
- /// handshake because doing so requires several round trips; the handshake will
- /// be completed in subsequent read/write calls managed by your event loop.
- pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
- let ssl = try!(ssl.into_ssl());
- let fd = stream.as_raw_fd() as c_int;
- let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
- let ret = ssl.ssl.accept();
- if ret > 0 {
- Ok(ssl)
- } else {
- // WantRead/WantWrite is okay here; we'll finish the handshake in
- // subsequent send/recv calls.
- match ssl.make_error(ret) {
- NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
- NonblockingSslError::SslError(other) => Err(other),
- }
- }
- }
-}
-
-#[cfg(unix)]
-impl<S: ::std::os::unix::io::AsRawFd> ::std::os::unix::io::AsRawFd for NonblockingSslStream<S> {
- fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd {
- self.stream.as_raw_fd()
+ self.0.ssl()
}
}
-#[cfg(windows)]
-impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S> {
+impl<S: Read+Write> NonblockingSslStream<S> {
/// Create a new nonblocking client ssl connection on wrapped `stream`.
///
/// Note that this method will most likely not actually complete the SSL
/// handshake because doing so requires several round trips; the handshake will
/// be completed in subsequent read/write calls managed by your event loop.
pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
- let ssl = try!(ssl.into_ssl());
- let fd = stream.as_raw_socket() as c_int;
- let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
- let ret = ssl.ssl.connect();
- if ret > 0 {
- Ok(ssl)
- } else {
- // WantRead/WantWrite is okay here; we'll finish the handshake in
- // subsequent send/recv calls.
- match ssl.make_error(ret) {
- NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
- NonblockingSslError::SslError(other) => Err(other),
- }
- }
+ SslStream::connect(ssl, stream).map(NonblockingSslStream)
}
/// Create a new nonblocking server ssl connection on wrapped `stream`.
@@ -1408,24 +1332,25 @@ impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S>
/// handshake because doing so requires several round trips; the handshake will
/// be completed in subsequent read/write calls managed by your event loop.
pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
- let ssl = try!(ssl.into_ssl());
- let fd = stream.as_raw_socket() as c_int;
- let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
- let ret = ssl.ssl.accept();
- if ret > 0 {
- Ok(ssl)
- } else {
- // WantRead/WantWrite is okay here; we'll finish the handshake in
- // subsequent send/recv calls.
- match ssl.make_error(ret) {
- NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
- NonblockingSslError::SslError(other) => Err(other),
+ SslStream::accept(ssl, stream).map(NonblockingSslStream)
+ }
+
+ fn convert_err(&self, err: Error) -> NonblockingSslError {
+ match err {
+ Error::ZeroReturn => SslError::SslSessionClosed.into(),
+ Error::WantRead(_) => NonblockingSslError::WantRead,
+ Error::WantWrite(_) => NonblockingSslError::WantWrite,
+ Error::WantX509Lookup => unreachable!(),
+ Error::Stream(e) => SslError::StreamError(e).into(),
+ Error::Ssl(e) => {
+ SslError::OpenSslErrors(e.iter()
+ .map(|e| OpensslError::from_error_code(e.error_code()))
+ .collect())
+ .into()
}
}
}
-}
-impl<S: Read+Write> NonblockingSslStream<S> {
/// Read bytes from the SSL stream into `buf`.
///
/// Given the SSL state machine, this method may return either `WantWrite`
@@ -1442,11 +1367,10 @@ impl<S: Read+Write> NonblockingSslStream<S> {
/// On a return value of `Ok(count)`, count is the number of decrypted
/// plaintext bytes copied into the `buf` slice.
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> {
- let ret = self.ssl.read(buf);
- if ret >= 0 {
- Ok(ret as usize)
- } else {
- Err(self.make_error(ret))
+ match self.0.ssl_read(buf) {
+ Ok(n) => Ok(n),
+ Err(Error::ZeroReturn) => Ok(0),
+ Err(e) => Err(self.convert_err(e))
}
}
@@ -1466,11 +1390,6 @@ impl<S: Read+Write> NonblockingSslStream<S> {
/// Given a return value of `Ok(count)`, count is the number of plaintext bytes
/// from the `buf` slice that were encrypted and written onto the stream.
pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> {
- let ret = self.ssl.write(buf);
- if ret > 0 {
- Ok(ret as usize)
- } else {
- Err(self.make_error(ret))
- }
+ self.0.ssl_write(buf).map_err(|e| self.convert_err(e))
}
}