aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSteven Fackler <[email protected]>2014-03-09 17:01:37 -0700
committerSteven Fackler <[email protected]>2014-03-09 17:01:37 -0700
commit49678805041cc2824ce54c9a0a8cf9fb5447838e (patch)
tree7328211dde9e811519515a9a651c9949e1859c98
parentAdd SSLv2 support behind a cfg flag (diff)
downloadrust-openssl-49678805041cc2824ce54c9a0a8cf9fb5447838e.tar.xz
rust-openssl-49678805041cc2824ce54c9a0a8cf9fb5447838e.zip
Properly propogate errors
-rw-r--r--ssl/mod.rs30
-rw-r--r--ssl/tests.rs12
2 files changed, 24 insertions, 18 deletions
diff --git a/ssl/mod.rs b/ssl/mod.rs
index ad6dd023..e9b4e78b 100644
--- a/ssl/mod.rs
+++ b/ssl/mod.rs
@@ -18,6 +18,15 @@ static mut INIT: Once = ONCE_INIT;
static mut VERIFY_IDX: c_int = -1;
static mut MUTEXES: *mut ~[NativeMutex] = 0 as *mut ~[NativeMutex];
+macro_rules! try_ssl(
+ ($e:expr) => (
+ match $e {
+ Ok(ok) => ok,
+ Err(err) => return Err(StreamError(err))
+ }
+ )
+)
+
fn init() {
unsafe {
INIT.doit(|| {
@@ -480,14 +489,11 @@ impl<S: Stream> SslStream<S> {
match self.ssl.get_error(ret) {
ErrorWantRead => {
- self.flush();
- match self.stream.read(self.buf) {
- Ok(len) =>
- self.ssl.get_rbio().write(self.buf.slice_to(len)),
- Err(err) => return Err(StreamError(err))
- }
+ try_ssl!(self.flush());
+ let len = try_ssl!(self.stream.read(self.buf));
+ self.ssl.get_rbio().write(self.buf.slice_to(len));
}
- ErrorWantWrite => { self.flush(); }
+ ErrorWantWrite => { try_ssl!(self.flush()) }
ErrorZeroReturn => return Err(SslSessionClosed),
ErrorSsl => return Err(SslError::get()),
_ => unreachable!()
@@ -495,14 +501,14 @@ impl<S: Stream> SslStream<S> {
}
}
- fn write_through(&mut self) {
+ fn write_through(&mut self) -> IoResult<()> {
loop {
- // TODO propogate errors
match self.ssl.get_wbio().read(self.buf) {
- Some(len) => self.stream.write(self.buf.slice_to(len)),
+ Some(len) => try!(self.stream.write(self.buf.slice_to(len))),
None => break
};
}
+ Ok(())
}
}
@@ -533,13 +539,13 @@ impl<S: Stream> Writer for SslStream<S> {
Ok(len) => start += len as uint,
_ => unreachable!()
}
- self.write_through();
+ try!(self.write_through());
}
Ok(())
}
fn flush(&mut self) -> IoResult<()> {
- self.write_through();
+ try!(self.write_through());
self.stream.flush()
}
}
diff --git a/ssl/tests.rs b/ssl/tests.rs
index 751ca7ab..c7f738f5 100644
--- a/ssl/tests.rs
+++ b/ssl/tests.rs
@@ -144,18 +144,18 @@ fn test_verify_trusted_get_error_err() {
fn test_write() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
- stream.write("hello".as_bytes());
- stream.flush();
- stream.write(" there".as_bytes());
- stream.flush();
+ stream.write("hello".as_bytes()).unwrap();
+ stream.flush().unwrap();
+ stream.write(" there".as_bytes()).unwrap();
+ stream.flush().unwrap();
}
#[test]
fn test_read() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
- stream.write("GET /\r\n\r\n".as_bytes());
- stream.flush();
+ stream.write("GET /\r\n\r\n".as_bytes()).unwrap();
+ stream.flush().unwrap();
let buf = stream.read_to_end().ok().expect("read error");
print!("{}", str::from_utf8(buf));
}