aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSteven Fackler <[email protected]>2013-10-21 22:51:18 -0700
committerSteven Fackler <[email protected]>2013-10-21 22:51:18 -0700
commit302590c2b5fc75da807157073e3bd393d89b385e (patch)
treef7762682860ba7129dd3f47deef5bdcecf43addf
parentFill out the context methods (diff)
downloadrust-openssl-302590c2b5fc75da807157073e3bd393d89b385e.tar.xz
rust-openssl-302590c2b5fc75da807157073e3bd393d89b385e.zip
Major rewrite for better error handling
-rw-r--r--error.rs18
-rw-r--r--ffi.rs3
-rw-r--r--lib.rs333
-rw-r--r--tests.rs23
4 files changed, 205 insertions, 172 deletions
diff --git a/error.rs b/error.rs
new file mode 100644
index 00000000..b5fe0b6b
--- /dev/null
+++ b/error.rs
@@ -0,0 +1,18 @@
+use std::libc::c_ulong;
+
+use super::ffi;
+
+pub enum SslError {
+ StreamEof,
+ SslSessionClosed,
+ UnknownError(c_ulong)
+}
+
+impl SslError {
+ pub fn get() -> Option<SslError> {
+ match unsafe { ffi::ERR_get_error() } {
+ 0 => None,
+ err => Some(UnknownError(err))
+ }
+ }
+}
diff --git a/ffi.rs b/ffi.rs
index 57adfaf4..7fec9ad0 100644
--- a/ffi.rs
+++ b/ffi.rs
@@ -45,6 +45,8 @@ externfn!(fn SSL_CTX_load_verify_locations(ctx: *SSL_CTX, CAfile: *c_char,
externfn!(fn SSL_new(ctx: *SSL_CTX) -> *SSL)
externfn!(fn SSL_free(ssl: *SSL))
externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO))
+externfn!(fn SSL_get_rbio(ssl: *SSL) -> *BIO)
+externfn!(fn SSL_get_wbio(ssl: *SSL) -> *BIO)
externfn!(fn SSL_set_connect_state(ssl: *SSL))
externfn!(fn SSL_connect(ssl: *SSL) -> c_int)
externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int)
@@ -54,5 +56,6 @@ externfn!(fn SSL_shutdown(ssl: *SSL) -> c_int)
externfn!(fn BIO_s_mem() -> *BIO_METHOD)
externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO)
+externfn!(fn BIO_free_all(a: *BIO))
externfn!(fn BIO_read(b: *BIO, buf: *c_void, len: c_int) -> c_int)
externfn!(fn BIO_write(b: *BIO, buf: *c_void, len: c_int) -> c_int)
diff --git a/lib.rs b/lib.rs
index 9c5f5845..ce3cc8db 100644
--- a/lib.rs
+++ b/lib.rs
@@ -1,15 +1,19 @@
-use std::rt::io::{Reader, Writer, Stream, Decorator};
-use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release};
-use std::task;
+use std::libc::{c_int, c_void};
use std::ptr;
+use std::task;
+use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release};
+use std::rt::io::{Stream, Reader, Writer, Decorator};
use std::vec;
-use std::libc::{c_int, c_void};
-mod ffi;
+use error::{SslError, SslSessionClosed, StreamEof};
+
+pub mod error;
#[cfg(test)]
mod tests;
+mod ffi;
+
static mut STARTED_INIT: AtomicBool = INIT_ATOMIC_BOOL;
static mut FINISHED_INIT: AtomicBool = INIT_ATOMIC_BOOL;
@@ -35,7 +39,7 @@ pub enum SslMethod {
}
impl SslMethod {
- unsafe fn to_fn(&self) -> *ffi::SSL_METHOD {
+ unsafe fn to_raw(&self) -> *ffi::SSL_METHOD {
match *self {
Sslv2 => ffi::SSLv2_method(),
Sslv3 => ffi::SSLv3_method(),
@@ -45,56 +49,60 @@ impl SslMethod {
}
}
-pub struct SslCtx {
+pub enum SslVerifyMode {
+ SslVerifyPeer = ffi::SSL_VERIFY_PEER,
+ SslVerifyNone = ffi::SSL_VERIFY_NONE
+}
+
+pub struct SslContext {
priv ctx: *ffi::SSL_CTX
}
-impl Drop for SslCtx {
+impl Drop for SslContext {
fn drop(&mut self) {
- unsafe { ffi::SSL_CTX_free(self.ctx); }
+ unsafe { ffi::SSL_CTX_free(self.ctx) }
}
}
-impl SslCtx {
- pub fn new(method: SslMethod) -> SslCtx {
+impl SslContext {
+ pub fn try_new(method: SslMethod) -> Result<SslContext, SslError> {
init();
- let ctx = unsafe { ffi::SSL_CTX_new(method.to_fn()) };
- assert!(ctx != ptr::null());
-
- SslCtx {
- ctx: ctx
+ let ctx = unsafe { ffi::SSL_CTX_new(method.to_raw()) };
+ if ctx == ptr::null() {
+ return Err(SslError::get().unwrap());
}
+
+ Ok(SslContext { ctx: ctx })
}
- pub fn set_verify(&mut self, mode: SslVerifyMode) {
- unsafe { ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None) }
+ pub fn new(method: SslMethod) -> SslContext {
+ match SslContext::try_new(method) {
+ Ok(ctx) => ctx,
+ Err(err) => fail!("Error creating SSL context: {:?}", err)
+ }
}
- pub fn set_verify_locations(&mut self, CAfile: &str) {
- do CAfile.with_c_str |CAfile| {
- unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, CAfile,
- ptr::null()); }
+ // TODO: support callback (see SSL_CTX_set_ex_data)
+ pub fn set_verify(&mut self, mode: SslVerifyMode) {
+ unsafe {
+ ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None);
}
}
-}
-pub enum SslVerifyMode {
- SslVerifyNone = ffi::SSL_VERIFY_NONE,
- SslVerifyPeer = ffi::SSL_VERIFY_PEER
-}
+ pub fn set_CA_file(&mut self, file: &str) -> Option<SslError> {
+ let ret = do file.with_c_str |file| {
+ unsafe {
+ ffi::SSL_CTX_load_verify_locations(self.ctx, file, ptr::null())
+ }
+ };
-#[deriving(Eq, FromPrimitive)]
-enum SslError {
- ErrorNone = ffi::SSL_ERROR_NONE,
- ErrorSsl = ffi::SSL_ERROR_SSL,
- ErrorWantRead = ffi::SSL_ERROR_WANT_READ,
- ErrorWantWrite = ffi::SSL_ERROR_WANT_WRITE,
- ErrorWantX509Lookup = ffi::SSL_ERROR_WANT_X509_LOOKUP,
- ErrorSyscall = ffi::SSL_ERROR_SYSCALL,
- ErrorZeroReturn = ffi::SSL_ERROR_ZERO_RETURN,
- ErrorWantConnect = ffi::SSL_ERROR_WANT_CONNECT,
- ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT,
+ if ret == 0 {
+ Some(SslError::get().unwrap())
+ } else {
+ None
+ }
+ }
}
struct Ssl {
@@ -103,138 +111,155 @@ struct Ssl {
impl Drop for Ssl {
fn drop(&mut self) {
- unsafe { ffi::SSL_free(self.ssl); }
+ unsafe { ffi::SSL_free(self.ssl) }
}
}
impl Ssl {
- fn new(ctx: &SslCtx) -> Ssl {
+ fn try_new(ctx: &SslContext) -> Result<Ssl, SslError> {
let ssl = unsafe { ffi::SSL_new(ctx.ctx) };
- assert!(ssl != ptr::null());
+ if ssl == ptr::null() {
+ return Err(SslError::get().unwrap());
+ }
+ let ssl = Ssl { ssl: ssl };
+
+ let rbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
+ if rbio == ptr::null() {
+ return Err(SslError::get().unwrap());
+ }
- Ssl { ssl: ssl }
+ let wbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
+ if wbio == ptr::null() {
+ unsafe { ffi::BIO_free_all(rbio) }
+ return Err(SslError::get().unwrap());
+ }
+
+ unsafe { ffi::SSL_set_bio(ssl.ssl, rbio, wbio) }
+ Ok(ssl)
}
- fn set_bio(&self, rbio: &MemBio, wbio: &MemBio) {
- unsafe { ffi::SSL_set_bio(self.ssl, rbio.bio, wbio.bio); }
+ fn get_rbio<'a>(&'a self) -> MemBio<'a> {
+ let bio = unsafe { ffi::SSL_get_rbio(self.ssl) };
+ assert!(bio != ptr::null());
+
+ MemBio {
+ ssl: self,
+ bio: bio
+ }
}
- fn set_connect_state(&self) {
- unsafe { ffi::SSL_set_connect_state(self.ssl); }
+ fn get_wbio<'a>(&'a self) -> MemBio<'a> {
+ let bio = unsafe { ffi::SSL_get_wbio(self.ssl) };
+ assert!(bio != ptr::null());
+
+ MemBio {
+ ssl: self,
+ bio: bio
+ }
}
- fn connect(&self) -> int {
- unsafe { ffi::SSL_connect(self.ssl) as int }
+ fn connect(&self) -> c_int {
+ unsafe { ffi::SSL_connect(self.ssl) }
}
- fn get_error(&self, ret: int) -> SslError {
- let err = unsafe { ffi::SSL_get_error(self.ssl, ret as c_int) };
- match FromPrimitive::from_int(err as int) {
- Some(err) => err,
- None => fail2!("Unknown error {}", err)
- }
+ fn read(&self, buf: &mut [u8]) -> c_int {
+ unsafe { ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void,
+ buf.len() as c_int) }
}
- fn read(&self, buf: &[u8]) -> int {
- unsafe {
- ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void,
- buf.len() as c_int) as int
- }
+ fn write(&self, buf: &[u8]) -> c_int {
+ unsafe { ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void,
+ buf.len() as c_int) }
}
- fn write(&self, buf: &[u8]) -> int {
- unsafe {
- ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void,
- buf.len() as c_int) as int
+ fn get_error(&self, ret: c_int) -> LibSslError {
+ let err = unsafe { ffi::SSL_get_error(self.ssl, ret) };
+ match FromPrimitive::from_int(err as int) {
+ Some(err) => err,
+ None => unreachable!()
}
}
+}
- fn shutdown(&self) -> int {
- unsafe { ffi::SSL_shutdown(self.ssl) as int }
- }
+#[deriving(FromPrimitive)]
+enum LibSslError {
+ ErrorNone = ffi::SSL_ERROR_NONE,
+ ErrorSsl = ffi::SSL_ERROR_SSL,
+ ErrorWantRead = ffi::SSL_ERROR_WANT_READ,
+ ErrorWantWrite = ffi::SSL_ERROR_WANT_WRITE,
+ ErrorWantX509Lookup = ffi::SSL_ERROR_WANT_X509_LOOKUP,
+ ErrorSyscall = ffi::SSL_ERROR_SYSCALL,
+ ErrorZeroReturn = ffi::SSL_ERROR_ZERO_RETURN,
+ ErrorWantConnect = ffi::SSL_ERROR_WANT_CONNECT,
+ ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT,
}
-// BIOs are freed by SSL_free
-struct MemBio {
+struct MemBio<'self> {
+ ssl: &'self Ssl,
bio: *ffi::BIO
}
-impl MemBio {
- fn new() -> MemBio {
- let bio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
- assert!(bio != ptr::null());
-
- MemBio { bio: bio }
- }
+impl<'self> MemBio<'self> {
+ fn read(&self, buf: &mut [u8]) -> Option<uint> {
+ let ret = unsafe {
+ ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void,
+ buf.len() as c_int)
+ };
- fn write(&self, buf: &[u8]) {
- unsafe {
- let ret = ffi::BIO_write(self.bio,
- vec::raw::to_ptr(buf) as *c_void,
- buf.len() as c_int);
- if ret < 0 {
- fail2!("write returned {}", ret);
- }
+ if ret < 0 {
+ None
+ } else {
+ Some(ret as uint)
}
}
- fn read(&self, buf: &[u8]) -> uint {
- unsafe {
- let ret = ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void,
- buf.len() as c_int);
- if ret < 0 {
- 0
- } else {
- ret as uint
- }
- }
+ fn write(&self, buf: &[u8]) {
+ let ret = unsafe {
+ ffi::BIO_write(self.bio, vec::raw::to_ptr(buf) as *c_void,
+ buf.len() as c_int)
+ };
+ assert_eq!(buf.len(), ret as uint);
}
}
pub struct SslStream<S> {
- priv ctx: SslCtx,
+ priv stream: S,
priv ssl: Ssl,
- priv buf: ~[u8],
- priv rbio: MemBio,
- priv wbio: MemBio,
- priv stream: S
+ priv buf: ~[u8]
}
impl<S: Stream> SslStream<S> {
- pub fn new(ctx: SslCtx, stream: S) -> Result<SslStream<S>, uint> {
- let ssl = Ssl::new(&ctx);
-
- let rbio = MemBio::new();
- let wbio = MemBio::new();
-
- ssl.set_bio(&rbio, &wbio);
- ssl.set_connect_state();
+ pub fn try_new(ctx: &SslContext, stream: S) -> Result<SslStream<S>,
+ SslError> {
+ let ssl = match Ssl::try_new(ctx) {
+ Ok(ssl) => ssl,
+ Err(err) => return Err(err)
+ };
- let mut stream = SslStream {
- ctx: ctx,
+ let mut ssl = SslStream {
+ stream: stream,
ssl: ssl,
- // Max record size for SSLv3/TLSv1 is 16k
- buf: vec::from_elem(16 * 1024, 0u8),
- rbio: rbio,
- wbio: wbio,
- stream: stream
+ // Maximum TLS record size is 16k
+ buf: vec::from_elem(16 * 1024, 0u8)
};
- let ret = do stream.in_retry_wrapper |ssl| {
- ssl.ssl.connect()
- };
+ match ssl.in_retry_wrapper(|ssl| { ssl.connect() }) {
+ Ok(_) => Ok(ssl),
+ Err(err) => Err(err)
+ }
+ }
- match ret {
- Ok(_) => Ok(stream),
- // FIXME
- Err(_err) => Err(unsafe { ffi::ERR_get_error() as uint })
+ pub fn new(ctx: &SslContext, stream: S) -> SslStream<S> {
+ match SslStream::try_new(ctx, stream) {
+ Ok(stream) => stream,
+ Err(err) => fail!("Error creating SSL stream: {:?}", err)
}
}
- fn in_retry_wrapper(&mut self, blk: &fn(&mut SslStream<S>) -> int)
- -> Result<int, SslError> {
+ fn in_retry_wrapper(&mut self, blk: &fn(&Ssl) -> c_int)
+ -> Result<c_int, SslError> {
loop {
- let ret = blk(self);
+ let ret = blk(&self.ssl);
if ret > 0 {
return Ok(ret);
}
@@ -243,34 +268,24 @@ impl<S: Stream> SslStream<S> {
ErrorWantRead => {
self.flush();
match self.stream.read(self.buf) {
- Some(len) => self.rbio.write(self.buf.slice_to(len)),
- None => return Err(ErrorZeroReturn) // FIXME
+ Some(len) =>
+ self.ssl.get_rbio().write(self.buf.slice_to(len)),
+ None => return Err(StreamEof)
}
}
ErrorWantWrite => self.flush(),
- err => return Err(err)
+ ErrorZeroReturn => return Err(SslSessionClosed),
+ ErrorSsl => return Err(SslError::get().unwrap()),
+ _ => unreachable!()
}
}
}
fn write_through(&mut self) {
loop {
- let len = self.wbio.read(self.buf);
- if len == 0 {
- return;
- }
- self.stream.write(self.buf.slice_to(len));
- }
- }
-
- pub fn shutdown(&mut self) {
- loop {
- let ret = do self.in_retry_wrapper |ssl| {
- ssl.ssl.shutdown()
- };
-
- if ret != Ok(0) {
- break;
+ match self.ssl.get_wbio().read(self.buf) {
+ Some(len) => self.stream.write(self.buf.slice_to(len)),
+ None => break
}
}
}
@@ -278,13 +293,10 @@ impl<S: Stream> SslStream<S> {
impl<S: Stream> Reader for SslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> Option<uint> {
- let ret = do self.in_retry_wrapper |ssl| {
- ssl.ssl.read(buf)
- };
-
- match ret {
- Ok(num) => Some(num as uint),
- Err(_) => None
+ match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
+ Ok(len) => Some(len as uint),
+ Err(StreamEof) | Err(SslSessionClosed) => None,
+ _ => unreachable!()
}
}
@@ -295,25 +307,26 @@ impl<S: Stream> Reader for SslStream<S> {
impl<S: Stream> Writer for SslStream<S> {
fn write(&mut self, buf: &[u8]) {
- let ret = do self.in_retry_wrapper |ssl| {
- ssl.ssl.write(buf)
- };
-
- match ret {
- Ok(_) => (),
- Err(err) => fail2!("Write error: {:?}", err)
+ let mut start = 0;
+ while start < buf.len() {
+ let ret = do self.in_retry_wrapper |ssl| {
+ ssl.write(buf.slice_from(start))
+ };
+ match ret {
+ Ok(len) => start += len as uint,
+ _ => unreachable!()
+ }
+ self.write_through();
}
-
- self.write_through();
}
fn flush(&mut self) {
self.write_through();
- self.stream.flush();
+ self.stream.flush()
}
}
-impl<S: Stream> Decorator<S> for SslStream<S> {
+impl<S> Decorator<S> for SslStream<S> {
fn inner(self) -> S {
self.stream
}
diff --git a/tests.rs b/tests.rs
index 639ce1b1..b167cda8 100644
--- a/tests.rs
+++ b/tests.rs
@@ -3,37 +3,37 @@ use std::rt::io::extensions::ReaderUtil;
use std::rt::io::net::tcp::TcpStream;
use std::str;
-use super::{Sslv23, SslCtx, SslStream, SslVerifyPeer};
+use super::{Sslv23, SslContext, SslStream, SslVerifyPeer};
#[test]
fn test_new_ctx() {
- SslCtx::new(Sslv23);
+ SslContext::new(Sslv23);
}
#[test]
fn test_new_sslstream() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
- SslStream::new(SslCtx::new(Sslv23), stream).unwrap();
+ SslStream::new(&SslContext::new(Sslv23), stream);
}
#[test]
fn test_verify_untrusted() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
- let mut ctx = SslCtx::new(Sslv23);
+ let mut ctx = SslContext::new(Sslv23);
ctx.set_verify(SslVerifyPeer);
- match SslStream::new(ctx, stream) {
+ match SslStream::try_new(&ctx, stream) {
Ok(_) => fail2!("expected failure"),
- Err(err) => println!("error {}", err)
+ Err(err) => println!("error {:?}", err)
}
}
#[test]
fn test_verify_trusted() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
- let mut ctx = SslCtx::new(Sslv23);
+ let mut ctx = SslContext::new(Sslv23);
ctx.set_verify(SslVerifyPeer);
- ctx.set_verify_locations("cert.pem");
- match SslStream::new(ctx, stream) {
+ assert!(ctx.set_CA_file("cert.pem").is_none());
+ match SslStream::try_new(&ctx, stream) {
Ok(_) => (),
Err(err) => fail2!("Expected success, got {:?}", err)
}
@@ -42,18 +42,17 @@ fn test_verify_trusted() {
#[test]
fn test_write() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
- let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap();
+ let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("hello".as_bytes());
stream.flush();
stream.write(" there".as_bytes());
stream.flush();
- stream.shutdown();
}
#[test]
fn test_read() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
- let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap();
+ let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("GET /\r\n\r\n".as_bytes());
stream.flush();
let buf = stream.read_to_end();