diff options
| author | Manuel Schölling <[email protected]> | 2015-03-04 22:32:16 +0100 |
|---|---|---|
| committer | Manuel Schölling <[email protected]> | 2015-04-06 12:14:36 +0200 |
| commit | 5408b641ddbddd9f40ec203901dd7cb1a7afa3c0 (patch) | |
| tree | fc37e6d0da4a424178f1b605b7ebb9b1e2da8da2 /openssl/src/ssl/connected_socket.rs | |
| parent | Release v0.6.0 (diff) | |
| download | rust-openssl-5408b641ddbddd9f40ec203901dd7cb1a7afa3c0.tar.xz rust-openssl-5408b641ddbddd9f40ec203901dd7cb1a7afa3c0.zip | |
Add connect() support for UDP sockets
Diffstat (limited to 'openssl/src/ssl/connected_socket.rs')
| -rw-r--r-- | openssl/src/ssl/connected_socket.rs | 301 |
1 files changed, 301 insertions, 0 deletions
diff --git a/openssl/src/ssl/connected_socket.rs b/openssl/src/ssl/connected_socket.rs new file mode 100644 index 00000000..1ae5fc8d --- /dev/null +++ b/openssl/src/ssl/connected_socket.rs @@ -0,0 +1,301 @@ +use libc::funcs::bsd43::connect; +use std::os; +use std::os::unix::AsRawFd; +use std::os::unix::Fd; +use std::net::UdpSocket; +use std::net::ToSocketAddrs; +use std::net::SocketAddr; +use std::io::Error; +use std::io::ErrorKind; +use std::io::Read; +use std::io::Write; +use std::mem; +use std::time::duration::Duration; +use libc::types::os::common::bsd44::socklen_t; +use libc::types::os::common::bsd44::sockaddr_in; +use libc::types::os::common::bsd44::sockaddr_in6; +use libc::types::os::common::bsd44::in_addr; +use libc::types::os::common::bsd44::in6_addr; +use libc::types::os::common::posix01::timeval; +use libc::funcs::bsd43::setsockopt; +use libc::consts::os::bsd44::SOL_SOCKET; +use libc::consts::os::bsd44::AF_INET; +use libc::consts::os::bsd44::AF_INET6; +use libc::consts::os::posix88::EAGAIN; +use std::net::IpAddr; +use libc::types::os::arch::c95::c_int; +use libc::types::os::arch::c95::c_char; +use libc::types::common::c95::c_void; +use libc::funcs::bsd43::send; +use libc::funcs::bsd43::recv; +use std::num::Int; +use std::os::errno; +use std::ffi::CString; + +const SO_RCVTIMEO:c_int = 20; + +extern { + fn inet_pton(family: c_int, src: *const c_char, dst: *mut c_void) -> c_int; +} + +pub struct ConnectedSocket<S: ?Sized> { + sock: S +} + +impl<S: AsRawFd+?Sized> AsRawFd for ConnectedSocket<S> { + fn as_raw_fd(&self) -> Fd { + self.sock.as_raw_fd() + } +} + +enum SockaddrIn { + V4(sockaddr_in), + V6(sockaddr_in6), +} + +trait IntoSockaddrIn { + fn into_sockaddr_in(self) -> Result<SockaddrIn, Error>; +} + +impl IntoSockaddrIn for SocketAddr { + fn into_sockaddr_in(self) -> Result<SockaddrIn, Error> { + let ip = format!("{}", self.ip()); + + match self.ip() { + IpAddr::V4(_) => { + let mut addr = sockaddr_in { + sin_zero: [0; 8], + sin_family: AF_INET as u16, + sin_port: Int::to_be(self.port()), + sin_addr: in_addr { + s_addr: 0 + } + }; + let cstr = CString::new(ip.clone()).unwrap(); + let res = unsafe { + inet_pton(addr.sin_family as c_int, + cstr.as_ptr() as *const i8, + mem::transmute(&mut addr.sin_addr)) + }; + + if res == 1 { + Ok(SockaddrIn::V4(addr)) + } else { + warn!("inet_pton() failed for IPv4: {}", ip); + Err(Error::new(ErrorKind::Other, + "calling inet_pton() for ipv4", None)) + } + }, + + IpAddr::V6(_) => { + let mut addr = sockaddr_in6 { + sin6_family: AF_INET6 as u16, + sin6_port: Int::to_be(self.port()), + sin6_flowinfo: 0, + sin6_scope_id: 0, + sin6_addr: in6_addr { + s6_addr: [0; 8], + } + }; + let cstr = CString::new(ip.clone()).unwrap(); + let res = unsafe { + inet_pton(addr.sin6_family as c_int, + cstr.as_ptr() as *const i8, + mem::transmute(&mut addr.sin6_addr)) + }; + + if res > 0 { + Ok(SockaddrIn::V6(addr)) + } else { + Err(Error::new(ErrorKind::Other, + "calling inet_pton() for ipv6", None)) + } + } + } + } +} + +pub trait Connect { + fn connect<A: ToSocketAddrs + ?Sized>(self, addr: &A) -> Result<ConnectedSocket<Self>,Error>; +} + +impl Connect for UdpSocket { + fn connect<A: ToSocketAddrs + ?Sized>(self, address: &A) -> Result<ConnectedSocket<Self>,Error> { + let fd = self.as_raw_fd(); + + let addr = try!(address.to_socket_addrs()).next(); + if addr.is_none() { + return Err(Error::new(ErrorKind::InvalidInput, + "no addresses to connect to", None)); + } + + let saddr = try!(addr.unwrap().into_sockaddr_in()); + + let res = match saddr { + SockaddrIn::V4(s) => unsafe { + let len = mem::size_of_val(&s) as socklen_t; + let addrp = Box::new(s); + connect(fd, mem::transmute(&*addrp), len) + }, + SockaddrIn::V6(s) => unsafe { + let len = mem::size_of_val(&s) as socklen_t; + let addrp = Box::new(s); + connect(fd, mem::transmute(&*addrp), len) + }, + }; + + if res == 0 { + Ok(ConnectedSocket { sock: self }) + } else { + Err(Error::new(ErrorKind::Other, + "error calling connect()", None)) + } + } +} + +impl<S: AsRawFd+?Sized> Read for ConnectedSocket<S> { + fn read(&mut self, buf: &mut [u8]) -> Result<usize,Error> { + let flags = 0; + let ptr = buf.as_mut_ptr() as *mut c_void; + + debug!("recv'ing..."); + let len = unsafe { + recv(self.as_raw_fd(), ptr, buf.len() as u64, flags) + }; + + debug!("recv'ed len={:?}", len); + match len { + -1 => { + match errno() { + EAGAIN => Err(Error::new(ErrorKind::Interrupted, "EAGAIN", None)), + _ => Err(Error::new(ErrorKind::Other, + "recv() returned -1", None)), + } + }, + 0 => Err(Error::new(ErrorKind::Other, + "connection is closed", None)), + _ => Ok(len as usize), + } + } +} + +impl<S: AsRawFd+?Sized> Write for ConnectedSocket<S> { + fn write(&mut self, buf: &[u8]) -> Result<usize,Error> { + let flags = 0; + let ptr = buf.as_ptr() as *const c_void; + + debug!("sending {:?}", buf.len()); + let res = unsafe { + send(self.as_raw_fd(), ptr, buf.len() as u64, flags) + }; + if res == (buf.len() as i64) { + Ok(res as usize) + } else { + warn!("send() found {}, expected {}", res, buf.len()); + Err(Error::new(ErrorKind::Other, "send() failed", Some(os::error_string(os::errno() as i32)))) + } + } + + fn flush(&mut self) -> Result<(),Error> { + Ok(()) + } +} + +pub trait SetTimeout { + fn set_timeout(&self, timeout: Duration); +} + +impl<S:AsRawFd> SetTimeout for S { + fn set_timeout(&self, timeout: Duration) { + let tv = timeval { + tv_sec: timeout.num_seconds(), + tv_usec: 0, + }; + + unsafe { + setsockopt(self.as_raw_fd(), SOL_SOCKET, SO_RCVTIMEO, + mem::transmute(&tv), mem::size_of_val(&tv) as u32) + }; + } +} + +#[test] +fn connect4_works() { + let socket1 = UdpSocket::bind("127.0.0.1:34200").unwrap(); + let socket2 = UdpSocket::bind("127.0.0.1:34201").unwrap(); + let conn1 = socket1.connect("127.0.0.1:34200").unwrap(); + let conn2 = socket2.connect("127.0.0.1:34201").unwrap(); +} + +#[test] +fn sendrecv_works() { + let socket1 = UdpSocket::bind("127.0.0.1:34200").unwrap(); + let socket2 = UdpSocket::bind("127.0.0.1:34201").unwrap(); + let mut conn1 = socket1.connect("127.0.0.1:34201").unwrap(); + let mut conn2 = socket2.connect("127.0.0.1:34200").unwrap(); + + let send1 = [0,1,2,3]; + let send2 = [9,8,7,6]; + conn1.write(&send1).unwrap(); + conn2.write(&send2).unwrap(); + + let mut recv1 = [0;4]; + let mut recv2 = [0;4]; + conn1.read(&mut recv1).unwrap(); + conn2.read(&mut recv2).unwrap(); + + assert_eq!(send1, recv2); + assert_eq!(send2, recv1); +} + +#[test] +fn sendrecv_respects_packet_borders() { + let socket1 = UdpSocket::bind("127.0.0.1:34202").unwrap(); + let socket2 = UdpSocket::bind("127.0.0.1:34203").unwrap(); + let mut conn1 = socket1.connect("127.0.0.1:34203").unwrap(); + let mut conn2 = socket2.connect("127.0.0.1:34202").unwrap(); + + let send1 = [0,1,2,3]; + let send2 = [9,8,7,6]; + conn1.write(&send1).unwrap(); + conn1.write(&send2).unwrap(); + + let mut recv1 = [0;3]; + let mut recv2 = [0;3]; + conn2.read(&mut recv1).unwrap(); + conn2.read(&mut recv2).unwrap(); + + assert!(send1[0..3] == recv1[0..3]); + assert!(send2[0..3] == recv2[0..3]); +} + +#[test] +fn connect6_works() { + let socket1 = UdpSocket::bind("::1:34200").unwrap(); + let socket2 = UdpSocket::bind("::1:34201").unwrap(); + let conn1 = socket1.connect("::1:34200").unwrap(); + let conn2 = socket2.connect("::1:34201").unwrap(); +} + +#[test] +#[should_fail] +fn detect_invalid_ipv4() { + let s = UdpSocket::bind("127.0.0.1:34300").unwrap(); + s.connect("254.254.254.254:34200").unwrap(); +} + +#[test] +#[should_fail] +fn detect_invalid_ipv6() { + let s = UdpSocket::bind("::1:34300").unwrap(); + s.connect("1200::AB00:1234::2552:7777:1313:34300").unwrap(); +} + +#[test] +#[should_fail] +fn double_bind() { + let socket1 = UdpSocket::bind("127.0.0.1:34301").unwrap(); + let socket2 = UdpSocket::bind("127.0.0.1:34301").unwrap(); + drop(socket1); + drop(socket2); +} |