diff options
| author | Steven Fackler <[email protected]> | 2014-03-09 17:01:37 -0700 |
|---|---|---|
| committer | Steven Fackler <[email protected]> | 2014-03-09 17:01:37 -0700 |
| commit | 49678805041cc2824ce54c9a0a8cf9fb5447838e (patch) | |
| tree | 7328211dde9e811519515a9a651c9949e1859c98 | |
| parent | Add SSLv2 support behind a cfg flag (diff) | |
| download | rust-openssl-49678805041cc2824ce54c9a0a8cf9fb5447838e.tar.xz rust-openssl-49678805041cc2824ce54c9a0a8cf9fb5447838e.zip | |
Properly propogate errors
| -rw-r--r-- | ssl/mod.rs | 30 | ||||
| -rw-r--r-- | ssl/tests.rs | 12 |
2 files changed, 24 insertions, 18 deletions
@@ -18,6 +18,15 @@ static mut INIT: Once = ONCE_INIT; static mut VERIFY_IDX: c_int = -1; static mut MUTEXES: *mut ~[NativeMutex] = 0 as *mut ~[NativeMutex]; +macro_rules! try_ssl( + ($e:expr) => ( + match $e { + Ok(ok) => ok, + Err(err) => return Err(StreamError(err)) + } + ) +) + fn init() { unsafe { INIT.doit(|| { @@ -480,14 +489,11 @@ impl<S: Stream> SslStream<S> { match self.ssl.get_error(ret) { ErrorWantRead => { - self.flush(); - match self.stream.read(self.buf) { - Ok(len) => - self.ssl.get_rbio().write(self.buf.slice_to(len)), - Err(err) => return Err(StreamError(err)) - } + try_ssl!(self.flush()); + let len = try_ssl!(self.stream.read(self.buf)); + self.ssl.get_rbio().write(self.buf.slice_to(len)); } - ErrorWantWrite => { self.flush(); } + ErrorWantWrite => { try_ssl!(self.flush()) } ErrorZeroReturn => return Err(SslSessionClosed), ErrorSsl => return Err(SslError::get()), _ => unreachable!() @@ -495,14 +501,14 @@ impl<S: Stream> SslStream<S> { } } - fn write_through(&mut self) { + fn write_through(&mut self) -> IoResult<()> { loop { - // TODO propogate errors match self.ssl.get_wbio().read(self.buf) { - Some(len) => self.stream.write(self.buf.slice_to(len)), + Some(len) => try!(self.stream.write(self.buf.slice_to(len))), None => break }; } + Ok(()) } } @@ -533,13 +539,13 @@ impl<S: Stream> Writer for SslStream<S> { Ok(len) => start += len as uint, _ => unreachable!() } - self.write_through(); + try!(self.write_through()); } Ok(()) } fn flush(&mut self) -> IoResult<()> { - self.write_through(); + try!(self.write_through()); self.stream.flush() } } diff --git a/ssl/tests.rs b/ssl/tests.rs index 751ca7ab..c7f738f5 100644 --- a/ssl/tests.rs +++ b/ssl/tests.rs @@ -144,18 +144,18 @@ fn test_verify_trusted_get_error_err() { fn test_write() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); - stream.write("hello".as_bytes()); - stream.flush(); - stream.write(" there".as_bytes()); - stream.flush(); + stream.write("hello".as_bytes()).unwrap(); + stream.flush().unwrap(); + stream.write(" there".as_bytes()).unwrap(); + stream.flush().unwrap(); } #[test] fn test_read() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); - stream.write("GET /\r\n\r\n".as_bytes()); - stream.flush(); + stream.write("GET /\r\n\r\n".as_bytes()).unwrap(); + stream.flush().unwrap(); let buf = stream.read_to_end().ok().expect("read error"); print!("{}", str::from_utf8(buf)); } |