aboutsummaryrefslogtreecommitdiff
path: root/openssl/src
diff options
context:
space:
mode:
authorManuel Schölling <[email protected]>2015-03-04 22:32:16 +0100
committerManuel Schölling <[email protected]>2015-04-06 12:14:36 +0200
commit5408b641ddbddd9f40ec203901dd7cb1a7afa3c0 (patch)
treefc37e6d0da4a424178f1b605b7ebb9b1e2da8da2 /openssl/src
parentRelease v0.6.0 (diff)
downloadrust-openssl-5408b641ddbddd9f40ec203901dd7cb1a7afa3c0.tar.xz
rust-openssl-5408b641ddbddd9f40ec203901dd7cb1a7afa3c0.zip
Add connect() support for UDP sockets
Diffstat (limited to 'openssl/src')
-rw-r--r--openssl/src/ssl/connected_socket.rs301
-rw-r--r--openssl/src/ssl/mod.rs8
-rw-r--r--openssl/src/ssl/tests.rs70
3 files changed, 360 insertions, 19 deletions
diff --git a/openssl/src/ssl/connected_socket.rs b/openssl/src/ssl/connected_socket.rs
new file mode 100644
index 00000000..1ae5fc8d
--- /dev/null
+++ b/openssl/src/ssl/connected_socket.rs
@@ -0,0 +1,301 @@
+use libc::funcs::bsd43::connect;
+use std::os;
+use std::os::unix::AsRawFd;
+use std::os::unix::Fd;
+use std::net::UdpSocket;
+use std::net::ToSocketAddrs;
+use std::net::SocketAddr;
+use std::io::Error;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::io::Write;
+use std::mem;
+use std::time::duration::Duration;
+use libc::types::os::common::bsd44::socklen_t;
+use libc::types::os::common::bsd44::sockaddr_in;
+use libc::types::os::common::bsd44::sockaddr_in6;
+use libc::types::os::common::bsd44::in_addr;
+use libc::types::os::common::bsd44::in6_addr;
+use libc::types::os::common::posix01::timeval;
+use libc::funcs::bsd43::setsockopt;
+use libc::consts::os::bsd44::SOL_SOCKET;
+use libc::consts::os::bsd44::AF_INET;
+use libc::consts::os::bsd44::AF_INET6;
+use libc::consts::os::posix88::EAGAIN;
+use std::net::IpAddr;
+use libc::types::os::arch::c95::c_int;
+use libc::types::os::arch::c95::c_char;
+use libc::types::common::c95::c_void;
+use libc::funcs::bsd43::send;
+use libc::funcs::bsd43::recv;
+use std::num::Int;
+use std::os::errno;
+use std::ffi::CString;
+
+const SO_RCVTIMEO:c_int = 20;
+
+extern {
+ fn inet_pton(family: c_int, src: *const c_char, dst: *mut c_void) -> c_int;
+}
+
+pub struct ConnectedSocket<S: ?Sized> {
+ sock: S
+}
+
+impl<S: AsRawFd+?Sized> AsRawFd for ConnectedSocket<S> {
+ fn as_raw_fd(&self) -> Fd {
+ self.sock.as_raw_fd()
+ }
+}
+
+enum SockaddrIn {
+ V4(sockaddr_in),
+ V6(sockaddr_in6),
+}
+
+trait IntoSockaddrIn {
+ fn into_sockaddr_in(self) -> Result<SockaddrIn, Error>;
+}
+
+impl IntoSockaddrIn for SocketAddr {
+ fn into_sockaddr_in(self) -> Result<SockaddrIn, Error> {
+ let ip = format!("{}", self.ip());
+
+ match self.ip() {
+ IpAddr::V4(_) => {
+ let mut addr = sockaddr_in {
+ sin_zero: [0; 8],
+ sin_family: AF_INET as u16,
+ sin_port: Int::to_be(self.port()),
+ sin_addr: in_addr {
+ s_addr: 0
+ }
+ };
+ let cstr = CString::new(ip.clone()).unwrap();
+ let res = unsafe {
+ inet_pton(addr.sin_family as c_int,
+ cstr.as_ptr() as *const i8,
+ mem::transmute(&mut addr.sin_addr))
+ };
+
+ if res == 1 {
+ Ok(SockaddrIn::V4(addr))
+ } else {
+ warn!("inet_pton() failed for IPv4: {}", ip);
+ Err(Error::new(ErrorKind::Other,
+ "calling inet_pton() for ipv4", None))
+ }
+ },
+
+ IpAddr::V6(_) => {
+ let mut addr = sockaddr_in6 {
+ sin6_family: AF_INET6 as u16,
+ sin6_port: Int::to_be(self.port()),
+ sin6_flowinfo: 0,
+ sin6_scope_id: 0,
+ sin6_addr: in6_addr {
+ s6_addr: [0; 8],
+ }
+ };
+ let cstr = CString::new(ip.clone()).unwrap();
+ let res = unsafe {
+ inet_pton(addr.sin6_family as c_int,
+ cstr.as_ptr() as *const i8,
+ mem::transmute(&mut addr.sin6_addr))
+ };
+
+ if res > 0 {
+ Ok(SockaddrIn::V6(addr))
+ } else {
+ Err(Error::new(ErrorKind::Other,
+ "calling inet_pton() for ipv6", None))
+ }
+ }
+ }
+ }
+}
+
+pub trait Connect {
+ fn connect<A: ToSocketAddrs + ?Sized>(self, addr: &A) -> Result<ConnectedSocket<Self>,Error>;
+}
+
+impl Connect for UdpSocket {
+ fn connect<A: ToSocketAddrs + ?Sized>(self, address: &A) -> Result<ConnectedSocket<Self>,Error> {
+ let fd = self.as_raw_fd();
+
+ let addr = try!(address.to_socket_addrs()).next();
+ if addr.is_none() {
+ return Err(Error::new(ErrorKind::InvalidInput,
+ "no addresses to connect to", None));
+ }
+
+ let saddr = try!(addr.unwrap().into_sockaddr_in());
+
+ let res = match saddr {
+ SockaddrIn::V4(s) => unsafe {
+ let len = mem::size_of_val(&s) as socklen_t;
+ let addrp = Box::new(s);
+ connect(fd, mem::transmute(&*addrp), len)
+ },
+ SockaddrIn::V6(s) => unsafe {
+ let len = mem::size_of_val(&s) as socklen_t;
+ let addrp = Box::new(s);
+ connect(fd, mem::transmute(&*addrp), len)
+ },
+ };
+
+ if res == 0 {
+ Ok(ConnectedSocket { sock: self })
+ } else {
+ Err(Error::new(ErrorKind::Other,
+ "error calling connect()", None))
+ }
+ }
+}
+
+impl<S: AsRawFd+?Sized> Read for ConnectedSocket<S> {
+ fn read(&mut self, buf: &mut [u8]) -> Result<usize,Error> {
+ let flags = 0;
+ let ptr = buf.as_mut_ptr() as *mut c_void;
+
+ debug!("recv'ing...");
+ let len = unsafe {
+ recv(self.as_raw_fd(), ptr, buf.len() as u64, flags)
+ };
+
+ debug!("recv'ed len={:?}", len);
+ match len {
+ -1 => {
+ match errno() {
+ EAGAIN => Err(Error::new(ErrorKind::Interrupted, "EAGAIN", None)),
+ _ => Err(Error::new(ErrorKind::Other,
+ "recv() returned -1", None)),
+ }
+ },
+ 0 => Err(Error::new(ErrorKind::Other,
+ "connection is closed", None)),
+ _ => Ok(len as usize),
+ }
+ }
+}
+
+impl<S: AsRawFd+?Sized> Write for ConnectedSocket<S> {
+ fn write(&mut self, buf: &[u8]) -> Result<usize,Error> {
+ let flags = 0;
+ let ptr = buf.as_ptr() as *const c_void;
+
+ debug!("sending {:?}", buf.len());
+ let res = unsafe {
+ send(self.as_raw_fd(), ptr, buf.len() as u64, flags)
+ };
+ if res == (buf.len() as i64) {
+ Ok(res as usize)
+ } else {
+ warn!("send() found {}, expected {}", res, buf.len());
+ Err(Error::new(ErrorKind::Other, "send() failed", Some(os::error_string(os::errno() as i32))))
+ }
+ }
+
+ fn flush(&mut self) -> Result<(),Error> {
+ Ok(())
+ }
+}
+
+pub trait SetTimeout {
+ fn set_timeout(&self, timeout: Duration);
+}
+
+impl<S:AsRawFd> SetTimeout for S {
+ fn set_timeout(&self, timeout: Duration) {
+ let tv = timeval {
+ tv_sec: timeout.num_seconds(),
+ tv_usec: 0,
+ };
+
+ unsafe {
+ setsockopt(self.as_raw_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ mem::transmute(&tv), mem::size_of_val(&tv) as u32)
+ };
+ }
+}
+
+#[test]
+fn connect4_works() {
+ let socket1 = UdpSocket::bind("127.0.0.1:34200").unwrap();
+ let socket2 = UdpSocket::bind("127.0.0.1:34201").unwrap();
+ let conn1 = socket1.connect("127.0.0.1:34200").unwrap();
+ let conn2 = socket2.connect("127.0.0.1:34201").unwrap();
+}
+
+#[test]
+fn sendrecv_works() {
+ let socket1 = UdpSocket::bind("127.0.0.1:34200").unwrap();
+ let socket2 = UdpSocket::bind("127.0.0.1:34201").unwrap();
+ let mut conn1 = socket1.connect("127.0.0.1:34201").unwrap();
+ let mut conn2 = socket2.connect("127.0.0.1:34200").unwrap();
+
+ let send1 = [0,1,2,3];
+ let send2 = [9,8,7,6];
+ conn1.write(&send1).unwrap();
+ conn2.write(&send2).unwrap();
+
+ let mut recv1 = [0;4];
+ let mut recv2 = [0;4];
+ conn1.read(&mut recv1).unwrap();
+ conn2.read(&mut recv2).unwrap();
+
+ assert_eq!(send1, recv2);
+ assert_eq!(send2, recv1);
+}
+
+#[test]
+fn sendrecv_respects_packet_borders() {
+ let socket1 = UdpSocket::bind("127.0.0.1:34202").unwrap();
+ let socket2 = UdpSocket::bind("127.0.0.1:34203").unwrap();
+ let mut conn1 = socket1.connect("127.0.0.1:34203").unwrap();
+ let mut conn2 = socket2.connect("127.0.0.1:34202").unwrap();
+
+ let send1 = [0,1,2,3];
+ let send2 = [9,8,7,6];
+ conn1.write(&send1).unwrap();
+ conn1.write(&send2).unwrap();
+
+ let mut recv1 = [0;3];
+ let mut recv2 = [0;3];
+ conn2.read(&mut recv1).unwrap();
+ conn2.read(&mut recv2).unwrap();
+
+ assert!(send1[0..3] == recv1[0..3]);
+ assert!(send2[0..3] == recv2[0..3]);
+}
+
+#[test]
+fn connect6_works() {
+ let socket1 = UdpSocket::bind("::1:34200").unwrap();
+ let socket2 = UdpSocket::bind("::1:34201").unwrap();
+ let conn1 = socket1.connect("::1:34200").unwrap();
+ let conn2 = socket2.connect("::1:34201").unwrap();
+}
+
+#[test]
+#[should_fail]
+fn detect_invalid_ipv4() {
+ let s = UdpSocket::bind("127.0.0.1:34300").unwrap();
+ s.connect("254.254.254.254:34200").unwrap();
+}
+
+#[test]
+#[should_fail]
+fn detect_invalid_ipv6() {
+ let s = UdpSocket::bind("::1:34300").unwrap();
+ s.connect("1200::AB00:1234::2552:7777:1313:34300").unwrap();
+}
+
+#[test]
+#[should_fail]
+fn double_bind() {
+ let socket1 = UdpSocket::bind("127.0.0.1:34301").unwrap();
+ let socket2 = UdpSocket::bind("127.0.0.1:34301").unwrap();
+ drop(socket1);
+ drop(socket2);
+}
diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs
index 4c0b13f1..710a287d 100644
--- a/openssl/src/ssl/mod.rs
+++ b/openssl/src/ssl/mod.rs
@@ -25,6 +25,7 @@ use x509::{X509StoreContext, X509FileType, X509};
use crypto::pkey::PKey;
pub mod error;
+pub mod connected_socket;
#[cfg(test)]
mod tests;
@@ -97,6 +98,9 @@ pub enum SslMethod {
#[cfg(feature = "tlsv1_2")]
/// Support TLSv1.2 protocol, requires the `tlsv1_2` feature.
Tlsv1_2,
+ #[cfg(feature = "dtlsv1")]
+ /// Support DTLSv1 protocol, requires the `dtlsv1` feature.
+ Dtlsv1,
}
impl SslMethod {
@@ -110,7 +114,9 @@ impl SslMethod {
#[cfg(feature = "tlsv1_1")]
SslMethod::Tlsv1_1 => ffi::TLSv1_1_method(),
#[cfg(feature = "tlsv1_2")]
- SslMethod::Tlsv1_2 => ffi::TLSv1_2_method()
+ SslMethod::Tlsv1_2 => ffi::TLSv1_2_method(),
+ #[cfg(feature = "dtlsv1")]
+ SslMethod::Dtlsv1 => ffi::TLSv1_method(),
}
}
}
diff --git a/openssl/src/ssl/tests.rs b/openssl/src/ssl/tests.rs
index 05c9fe79..1da42082 100644
--- a/openssl/src/ssl/tests.rs
+++ b/openssl/src/ssl/tests.rs
@@ -11,6 +11,7 @@ use std::fs::File;
use crypto::hash::Type::{SHA256};
use ssl;
+use ssl::SslMethod;
use ssl::SslMethod::Sslv23;
use ssl::{SslContext, SslStream, VerifyCallback};
use ssl::SSL_VERIFY_PEER;
@@ -20,21 +21,23 @@ use x509::X509FileType;
use x509::X509;
use crypto::pkey::PKey;
+const PROTOCOL:SslMethod = Sslv23;
+
#[test]
fn test_new_ctx() {
- SslContext::new(Sslv23).unwrap();
+ SslContext::new(PROTOCOL).unwrap();
}
#[test]
fn test_new_sslstream() {
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- SslStream::new(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
+ SslStream::new(&SslContext::new(PROTOCOL).unwrap(), stream).unwrap();
}
#[test]
fn test_verify_untrusted() {
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, None);
match SslStream::new(&ctx, stream) {
Ok(_) => panic!("expected failure"),
@@ -45,8 +48,9 @@ fn test_verify_untrusted() {
#[test]
fn test_verify_trusted() {
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, None);
+
match ctx.set_CA_file(&Path::new("test/cert.pem")) {
Ok(_) => {}
Err(err) => panic!("Unexpected error {:?}", err)
@@ -63,8 +67,9 @@ fn test_verify_untrusted_callback_override_ok() {
true
}
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback));
+
match SslStream::new(&ctx, stream) {
Ok(_) => (),
Err(err) => panic!("Expected success, got {:?}", err)
@@ -77,8 +82,9 @@ fn test_verify_untrusted_callback_override_bad() {
false
}
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback));
+
assert!(SslStream::new(&ctx, stream).is_err());
}
@@ -88,8 +94,9 @@ fn test_verify_trusted_callback_override_ok() {
true
}
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback));
+
match ctx.set_CA_file(&Path::new("test/cert.pem")) {
Ok(_) => {}
Err(err) => panic!("Unexpected error {:?}", err)
@@ -106,8 +113,9 @@ fn test_verify_trusted_callback_override_bad() {
false
}
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback));
+
match ctx.set_CA_file(&Path::new("test/cert.pem")) {
Ok(_) => {}
Err(err) => panic!("Unexpected error {:?}", err)
@@ -122,8 +130,9 @@ fn test_verify_callback_load_certs() {
true
}
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback));
+
assert!(SslStream::new(&ctx, stream).is_ok());
}
@@ -134,8 +143,9 @@ fn test_verify_trusted_get_error_ok() {
true
}
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback));
+
match ctx.set_CA_file(&Path::new("test/cert.pem")) {
Ok(_) => {}
Err(err) => panic!("Unexpected error {:?}", err)
@@ -150,8 +160,9 @@ fn test_verify_trusted_get_error_err() {
false
}
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback));
+
assert!(SslStream::new(&ctx, stream).is_err());
}
@@ -168,7 +179,7 @@ fn test_verify_callback_data() {
}
}
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut ctx = SslContext::new(Sslv23).unwrap();
+ let mut ctx = SslContext::new(PROTOCOL).unwrap();
// Node id was generated as SHA256 hash of certificate "test/cert.pem"
// in DER format.
@@ -234,7 +245,7 @@ fn test_clear_ctx_options() {
#[test]
fn test_write() {
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut stream = SslStream::new(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
+ let mut stream = SslStream::new(&SslContext::new(PROTOCOL).unwrap(), stream).unwrap();
stream.write_all("hello".as_bytes()).unwrap();
stream.flush().unwrap();
stream.write_all(" there".as_bytes()).unwrap();
@@ -244,7 +255,7 @@ fn test_write() {
#[test]
fn test_read() {
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
- let mut stream = SslStream::new(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
+ let mut stream = SslStream::new(&SslContext::new(PROTOCOL).unwrap(), stream).unwrap();
stream.write_all("GET /\r\n\r\n".as_bytes()).unwrap();
stream.flush().unwrap();
println!("written");
@@ -261,7 +272,7 @@ fn test_connect_with_unilateral_npn() {
ctx.set_verify(SSL_VERIFY_PEER, None);
ctx.set_npn_protocols(&[b"http/1.1", b"spdy/3.1"]);
match ctx.set_CA_file(&Path::new("test/cert.pem")) {
- Ok(_)=> {}
+ Ok(_) => {}
Err(err) => panic!("Unexpected error {:?}", err)
}
let stream = match SslStream::new(&ctx, stream) {
@@ -285,7 +296,7 @@ fn test_connect_with_npn_successful_multiple_matching() {
ctx.set_verify(SSL_VERIFY_PEER, None);
ctx.set_npn_protocols(&[b"spdy/3.1", b"http/1.1"]);
match ctx.set_CA_file(&Path::new("test/cert.pem")) {
- Ok(_)=> {}
+ Ok(_) => {}
Err(err) => panic!("Unexpected error {:?}", err)
}
let stream = match SslStream::new(&ctx, stream) {
@@ -310,7 +321,7 @@ fn test_connect_with_npn_successful_single_match() {
ctx.set_verify(SSL_VERIFY_PEER, None);
ctx.set_npn_protocols(&[b"spdy/3.1"]);
match ctx.set_CA_file(&Path::new("test/cert.pem")) {
- Ok(_)=> {}
+ Ok(_) => {}
Err(err) => panic!("Unexpected error {:?}", err)
}
let stream = match SslStream::new(&ctx, stream) {
@@ -350,7 +361,7 @@ fn test_npn_server_advertise_multiple() {
ctx.set_verify(SSL_VERIFY_PEER, None);
ctx.set_npn_protocols(&[b"spdy/3.1"]);
match ctx.set_CA_file(&Path::new("test/cert.pem")) {
- Ok(_)=> {}
+ Ok(_) => {}
Err(err) => panic!("Unexpected error {:?}", err)
}
// Now connect to the socket and make sure the protocol negotiation works...
@@ -362,3 +373,26 @@ fn test_npn_server_advertise_multiple() {
// SPDY is selected since that's the only thing the client supports.
assert_eq!(b"spdy/3.1", stream.get_selected_npn_protocol().unwrap());
}
+
+#[cfg(feature="dtlsv1")]
+#[cfg(test)]
+mod dtlsv1 {
+ use serialize::hex::FromHex;
+ use std::old_io::net::tcp::TcpStream;
+ use std::old_io::{Writer};
+ use std::thread;
+
+ use crypto::hash::Type::{SHA256};
+ use ssl::SslMethod;
+ use ssl::SslMethod::Dtlsv1;
+ use ssl::{SslContext, SslStream, VerifyCallback};
+ use ssl::SslVerifyMode::SSL_VERIFY_PEER;
+ use x509::{X509StoreContext};
+
+ const PROTOCOL:SslMethod = Dtlsv1;
+
+ #[test]
+ fn test_new_ctx() {
+ SslContext::new(PROTOCOL).unwrap();
+ }
+}