Fast UDP I/O for MacOS, others will be easy to add.

This commit is contained in:
Adam Ierymenko 2021-01-15 21:28:46 -05:00
parent a6c39a1952
commit 346bb7cf99
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
8 changed files with 240 additions and 123 deletions

View file

@ -808,17 +808,23 @@ unsigned int ZT_InetAddress_port(const ZT_InetAddress *ia)
int ZT_InetAddress_isNil(const ZT_InetAddress *ia)
{
return (int)( (ia == nullptr) || ((reinterpret_cast<const sockaddr_storage *>(ia))->ss_family == 0) );
if (!ia)
return 0;
return (int)((bool)(*reinterpret_cast<const ZeroTier::InetAddress *>(ia)));
}
int ZT_InetAddress_isV4(const ZT_InetAddress *ia)
{
return (int)( (ia != nullptr) && ((reinterpret_cast<const sockaddr_storage *>(ia))->ss_family == AF_INET) );
if (!ia)
return 0;
return (int)(reinterpret_cast<const ZeroTier::InetAddress *>(ia))->isV4();
}
int ZT_InetAddress_isV6(const ZT_InetAddress *ia)
{
return (int)( (ia != nullptr) && ((reinterpret_cast<const sockaddr_storage *>(ia))->ss_family == AF_INET6) );
if (!ia)
return 0;
return (int)(reinterpret_cast<const ZeroTier::InetAddress *>(ia))->isV6();
}
enum ZT_InetAddress_IpScope ZT_InetAddress_ipScope(const ZT_InetAddress *ia)

View file

@ -11,13 +11,11 @@
*/
/****/
use std::cell::Cell;
use std::ffi::CString;
use std::hash::{Hash, Hasher};
use std::mem::{MaybeUninit, zeroed};
use std::mem::zeroed;
use std::os::raw::{c_char, c_uint, c_void};
use std::ptr::{copy_nonoverlapping, null, null_mut};
use std::sync::Mutex;
use num_derive::{FromPrimitive, ToPrimitive};
use num_traits::FromPrimitive;
@ -455,6 +453,7 @@ pub struct CertificateSubject {
pub unique_id_proof_signature: Vec<u8>,
}
#[allow(unused)]
pub(crate) struct CertificateSubjectCAPIContainer {
pub(crate) subject: ztcore::ZT_Certificate_Subject,
subject_identities: Vec<ztcore::ZT_Certificate_Identity>,
@ -635,6 +634,7 @@ pub struct Certificate {
pub signature: Vec<u8>,
}
#[allow(unused)]
pub(crate) struct CertificateCAPIContainer {
pub(crate) certificate: ztcore::ZT_Certificate,
subject_container: CertificateSubjectCAPIContainer,

View file

@ -13,7 +13,7 @@
use std::ffi::CString;
use std::mem::MaybeUninit;
use std::os::raw::{c_char, c_int, c_void};
use std::os::raw::c_char;
use num_derive::{FromPrimitive, ToPrimitive};
use num_traits::FromPrimitive;

View file

@ -12,9 +12,8 @@
/****/
use std::ffi::CString;
use std::mem::{MaybeUninit, transmute, size_of};
use std::mem::{MaybeUninit, transmute};
use serde::{Deserialize, Serialize};
use num_derive::{FromPrimitive, ToPrimitive};
use num_traits::FromPrimitive;
@ -96,9 +95,7 @@ impl InetAddress {
/// The type S MUST have a size equal to the size of this type and the
/// OS's sockaddr_storage. If not, this may crash.
pub unsafe fn transmute_raw_sockaddr_storage<S>(ss: &S) -> &InetAddress {
unsafe {
transmute(ss)
}
transmute(ss)
}
/// Transmute a ZT_InetAddress from the core into a reference to a Rust
@ -163,7 +160,7 @@ impl InetAddress {
if !self.is_nil() {
unsafe {
if ztcore::ZT_InetAddress_isV4(self.as_capi_ptr()) != 0 {
return InetAddressFamily::IPv6;
return InetAddressFamily::IPv4;
}
if ztcore::ZT_InetAddress_isV6(self.as_capi_ptr()) != 0 {
return InetAddressFamily::IPv6;

View file

@ -12,14 +12,13 @@
/****/
use std::any::Any;
use std::cell::{Cell, RefCell};
use std::cell::Cell;
use std::collections::hash_map::HashMap;
use std::ffi::CStr;
use std::fs::copy;
use std::intrinsics::copy_nonoverlapping;
use std::mem::{MaybeUninit, transmute};
use std::os::raw::{c_int, c_uint, c_ulong, c_void};
use std::ptr::{null, null_mut, slice_from_raw_parts};
use std::ptr::{null_mut, slice_from_raw_parts};
use std::sync::*;
use std::sync::atomic::*;
use std::time::Duration;
@ -305,7 +304,7 @@ extern "C" fn zt_path_lookup_function<T: NodeEventHandler + 'static>(
if sock_addr.is_null() {
return 0;
}
let mut sock_family2: InetAddressFamily = InetAddressFamily::Nil;
let sock_family2: InetAddressFamily;
unsafe {
if sock_family == ztcore::ZT_AF_INET {
sock_family2 = InetAddressFamily::IPv4;

View file

@ -11,8 +11,6 @@
*/
/****/
use std::mem::{size_of, transmute, zeroed};
use serde::{Deserialize, Serialize};
use num_derive::{FromPrimitive, ToPrimitive};
use num_traits::FromPrimitive;

View file

@ -1,161 +1,189 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::atomic::{AtomicBool, Ordering};
use zerotier_core::{Buffer, InetAddress, InetAddressFamily};
use std::ffi::CString;
#[cfg(windows)]
pub type RawOsSocket = winapi::um::winsock2::SOCKET;
#[cfg(windows)]
type AfInet = winapi::um::winsock2::AF_INET;
#[cfg(windows)]
type AfInet6 = winapi::um::winsock2::AF_INET6;
pub type FastUDPRawOsSocket = winapi::um::winsock2::SOCKET;
#[cfg(unix)]
pub type RawOsSocket = std::os::raw::c_int;
#[cfg(unix)]
type AfInet = libc::AF_INET;
#[cfg(unix)]
type AfInet6 = libc::AF_INET6;
pub type FastUDPRawOsSocket = libc::c_int;
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#[cfg(target_os = "macos")]
unsafe fn bind_udp_socket(device_name: &CString, address: &InetAddress, af: libc::c_int) -> Option<RawOsSocket> {
let s = libc::socket(af, libc::SOCK_DGRAM, 0);
if s < 0 {
return None;
}
fn bind_udp_socket(_: &str, address: &InetAddress) -> Result<FastUDPRawOsSocket, &'static str> {
unsafe {
let af;
let sa_len;
match address.family() {
InetAddressFamily::IPv4 => {
af = libc::AF_INET;
sa_len = std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t;
},
InetAddressFamily::IPv6 => {
af = libc::AF_INET6;
sa_len = std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t;
},
_ => {
return Err("unrecognized address family");
}
};
let mut fl: libc::c_int;
let fl_size = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
let mut setsockopt_results: libc::c_int = 0;
fl = 1;
setsockopt_results |= libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_REUSEPORT, &mut fl, fl_size);
fl = 1;
setsockopt_results |= libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_REUSEADDR, &mut fl, fl_size);
fl = 1;
setsockopt_results |= libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_BROADCAST, &mut fl, fl_size);
if setsockopt_results != 0 {
libc::close(s);
return None;
}
fl = 1;
libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_NOSIGPIPE, &mut fl, fl_size);
if af == libc::AF_INET {
fl = 1;
libc::setsockopt(s, libc::IPPROTO_IP, 0x4000 /* IP_DF */, &mut fl, fl_size);
}
if af == libc::AF_INET6 {
fl = 1;
libc::setsockopt(s, libc::IPPROTO_IPV6, 62 /* IPV6_DONTFRAG */, &mut fl, fl_size);
fl = 1;
libc::setsockopt(s, libc::IPPROTO_IPV6, libc::IPV6_V6ONLY, &mut fl, fl_size);
}
fl = 1048576;
while fl >= 131072 {
if libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_RCVBUF, &mut fl, fl_size) == 0 {
break;
let s = libc::socket(af, libc::SOCK_DGRAM, 0);
if s < 0 {
return Err("unable to create socket");
}
fl -= 65536;
}
fl = 1048576;
while fl >= 131072 {
if libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_SNDBUF, &mut fl, fl_size) == 0 {
break;
let mut fl: libc::c_int;
let fl_size = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
let mut setsockopt_results: libc::c_int = 0;
// Set options that must succeed: reuse port for multithreading, enable broadcast, disable SIGPIPE, and
// for IPv6 sockets disable receipt of IPv4 packets.
fl = 1;
setsockopt_results |= libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_REUSEPORT, (&mut fl as *mut libc::c_int).cast(), fl_size);
//fl = 1;
//setsockopt_results |= libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_REUSEADDR, (&mut fl as *mut libc::c_int).cast(), fl_size);
fl = 1;
setsockopt_results |= libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_BROADCAST, (&mut fl as *mut libc::c_int).cast(), fl_size);
fl = 1;
setsockopt_results |= libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_NOSIGPIPE, (&mut fl as *mut libc::c_int).cast(), fl_size);
if af == libc::AF_INET6 {
fl = 1;
setsockopt_results |= libc::setsockopt(s, libc::IPPROTO_IPV6, libc::IPV6_V6ONLY, (&mut fl as *mut libc::c_int).cast(), fl_size);
}
if setsockopt_results != 0 {
libc::close(s);
return Err("setsockopt() failed");
}
fl -= 65536;
}
let namidx = libc::if_nametoindex(device_name.as_ptr()) as libc::c_int;
if namidx != 0 {
libc::setsockopt(s, libc::IPPROTO_IP, 25 /* IP_BOUND_IF */, &namidx, fl_size);
}
// Enable UDP fragmentation, which should never really be needed but might make this work if
// somebody finds themselves on a weird network. These are okay if they fail.
if af == libc::AF_INET {
fl = 0;
libc::setsockopt(s, libc::IPPROTO_IP, 0x4000 /* IP_DF */, (&mut fl as *mut libc::c_int).cast(), fl_size);
}
if af == libc::AF_INET6 {
fl = 0;
libc::setsockopt(s, libc::IPPROTO_IPV6, 62 /* IPV6_DONTFRAG */, (&mut fl as *mut libc::c_int).cast(), fl_size);
}
if libc::bind(s, (address as *const InetAddress).cast::<libc::sockaddr>(), std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t) != 0 {
libc::close(s);
return None;
}
// Set send and receive buffers to the largest acceptable value up to desired 1MiB.
fl = 1048576;
while fl >= 131072 {
if libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_RCVBUF, (&mut fl as *mut libc::c_int).cast(), fl_size) == 0 {
break;
}
fl -= 65536;
}
fl = 1048576;
while fl >= 131072 {
if libc::setsockopt(s, libc::SOL_SOCKET, libc::SO_SNDBUF, (&mut fl as *mut libc::c_int).cast(), fl_size) == 0 {
break;
}
fl -= 65536;
}
Some(s)
/*
// Bind socket directly to device to allow ZeroTier to work if it overrides the default route.
if device_name.as_bytes().len() > 0 {
let namidx = libc::if_nametoindex(device_name.as_ptr()) as libc::c_int;
if namidx != 0 {
if libc::setsockopt(s, libc::IPPROTO_IP, 25 /* IP_BOUND_IF */, (&namidx as *const libc::c_int).cast(), std::mem::size_of_val(&namidx) as libc::socklen_t) != 0 {
//libc::perror(std::ptr::null());
libc::close(s);
return Err("bind to interface failed");
}
}
}
*/
if libc::bind(s, (address as *const InetAddress).cast(), sa_len) != 0 {
//libc::perror(std::ptr::null());
libc::close(s);
return Err("bind to address failed");
}
Ok(s)
}
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
pub trait FastUDPSocketPacketHandler {
fn incoming_udp_packet(socket: &RawOsSocket, from_adddress: &InetAddress, mut data: Buffer);
fn incoming_udp_packet(&self, raw_socket: &FastUDPRawOsSocket, from_adddress: &InetAddress, data: Buffer);
}
/// A multi-threaded (or otherwise fast) UDP socket that binds to both IPv4 and IPv6 addresses.
pub struct FastUDPSocket<H: FastUDPSocketPacketHandler + 'static> {
pub struct FastUDPSocket<H: FastUDPSocketPacketHandler + Send + Sync + 'static> {
handler: Arc<H>,
threads: Vec<std::thread::JoinHandle<()>>,
thread_run: Arc<AtomicBool>,
sockets: Vec<RawOsSocket>,
sockets: Vec<FastUDPRawOsSocket>,
bind_address: InetAddress,
}
#[cfg(unix)]
#[inline(always)]
pub fn fast_udp_socket_send_buffer(socket: &RawOsSocket, to_address: &InetAddress, data: &[u8], packet_ttl: i32) {
pub fn fast_udp_socket_sendto(socket: &FastUDPRawOsSocket, to_address: &InetAddress, data: &[u8], packet_ttl: i32) {
unsafe {
if packet_ttl <= 0 {
libc::sendto(*socket, data.as_ptr(), data.len() as libc::size_t, 0, (to_address as *const InetAddress).cast::<libc::sockaddr>(), std::mem::size_of::<InetAddress>() as libc::socklen_t);
libc::sendto(*socket, data.as_ptr().cast(), data.len() as libc::size_t, 0, (to_address as *const InetAddress).cast(), std::mem::size_of::<InetAddress>() as libc::socklen_t);
} else {
let mut ttl = packet_ttl as libc::c_int;
libc::setsockopt(*socket, libc::IPPROTO_IP, libc::IP_TTL, &mut ttl, std::mem::size_of::<libc::c_int>() as libc::socklen_t);
libc::sendto(*socket, data.as_ptr(), data.len() as libc::size_t, 0, (to_address as *const InetAddress).cast::<libc::sockaddr>(), std::mem::size_of::<InetAddress>() as libc::socklen_t);
libc::setsockopt(*socket, libc::IPPROTO_IP, libc::IP_TTL, (&mut ttl as *mut libc::c_int).cast(), std::mem::size_of::<libc::c_int>() as libc::socklen_t);
libc::sendto(*socket, data.as_ptr().cast(), data.len() as libc::size_t, 0, (to_address as *const InetAddress).cast(), std::mem::size_of::<InetAddress>() as libc::socklen_t);
ttl = 255;
libc::setsockopt(*socket, libc::IPPROTO_IP, libc::IP_TTL, &mut ttl, std::mem::size_of::<libc::c_int>() as libc::socklen_t);
libc::setsockopt(*socket, libc::IPPROTO_IP, libc::IP_TTL, (&mut ttl as *mut libc::c_int).cast(), std::mem::size_of::<libc::c_int>() as libc::socklen_t);
}
}
}
#[cfg(windows)]
#[inline(always)]
pub fn fast_udp_socket_send_buffer(socket: &RawOsSocket, to_address: &InetAddress, data: &[u8], packet_ttl: i32) {
pub fn fast_udp_socket_sendto(socket: &FastUDPRawOsSocket, to_address: &InetAddress, data: &[u8], packet_ttl: i32) {
}
impl<H: FastUDPSocketPacketHandler + 'static> FastUDPSocket<H> {
pub fn new(device_name: &str, address: &InetAddress, handler: &Arc<H>) -> Result<FastUDPSocket<H>, String> {
#[cfg(unix)]
#[inline(always)]
fn fast_udp_socket_recvfrom(socket: &FastUDPRawOsSocket, buf: &mut Buffer, from_address: &mut InetAddress) -> i32 {
unsafe {
let mut addrlen = std::mem::size_of::<InetAddress>() as libc::socklen_t;
libc::recvfrom(*socket, buf.as_mut_ptr().cast(), Buffer::CAPACITY as libc::size_t, 0, (from_address as *mut InetAddress).cast(), &mut addrlen) as i32
}
}
// Integer incremented to select sockets on a mostly round robin basis. This
// isn't synchronized since if all cores don't see it the same there is no
// significant impact. It's just a faster way to pick a socket for sending
// than a random number generator.
static mut SOCKET_SPIN_INT: usize = 0;
impl<H: FastUDPSocketPacketHandler + Send + Sync + 'static> FastUDPSocket<H> {
pub fn new(device_name: &str, address: &InetAddress, handler: &Arc<H>) -> Result<FastUDPSocket<H>, &'static str> {
let thread_count = num_cpus::get();
let mut s = FastUDPSocket{
handler: handler.clone(),
threads: Vec::new(),
thread_run: Arc::new(AtomicBool::new(true)),
threads: Vec::new(),
sockets: Vec::new(),
bind_address: address.clone()
};
let device_name_c = CString::from(device_name);
let af = match address.family() {
InetAddressFamily::IPv4 => AfInet,
InetAddressFamily::IPv6 => AfInet6,
_ => { return Err(String::from("unrecognized address family")); }
};
for _ in 0..thread_count {
let thread_socket = unsafe { bind_udp_socket(&device_name_c, address, af) };
if thread_socket.is_some() {
let thread_socket = bind_udp_socket(device_name, address);
if thread_socket.is_ok() {
let thread_socket = thread_socket.unwrap();
s.sockets.push(thread_socket);
let thread_run = s.thread_run.clone();
let handler_weak = Arc::downgrade(handler);
let handler_weak = Arc::downgrade(&s.handler);
s.threads.push(std::thread::spawn(move || {
let mut from_address = InetAddress::new();
while thread_run.load(Ordering::Relaxed) {
let mut buf = Buffer::new();
let mut addrlen = std::mem::size_of::<InetAddress>() as libc::socklen_t;
let read_length = unsafe { libc::recvfrom(thread_socket, buf.as_mut_ptr(), Buffer::CAPACITY as libc::size_t, 0, (&mut from_address as *mut InetAddress).cast::<libc::sockaddr>(), &mut addrlen) };
let read_length = fast_udp_socket_recvfrom(&thread_socket, &mut buf, &mut from_address);
if read_length > 0 {
let handler = handler_weak.upgrade();
if handler.is_some() {
@ -172,37 +200,125 @@ impl<H: FastUDPSocketPacketHandler + 'static> FastUDPSocket<H> {
}
}
if s.threads.is_empty() {
return Err(String::from("unable to bind to address for IPv4 or IPv6"));
if s.sockets.is_empty() {
return Err("unable to bind to address for IPv4 or IPv6");
}
Ok(s)
}
/// Get a socket suitable for sending.
/// Send from this socket.
/// This actually picks a thread's socket and sends from it. Since all
/// are bound to the same IP:port which one is chosen doesn't matter.
/// Sockets are thread safe.
pub fn send(&self, to_address: &InetAddress, data: &[u8], packet_ttl: i32) {
let mut i;
unsafe {
i = SOCKET_SPIN_INT;
SOCKET_SPIN_INT = i + 1;
i %= self.sockets.len();
}
let s = self.sockets.get(i).unwrap();
fast_udp_socket_sendto(s, to_address, data, packet_ttl);
}
/// Get the number of threads this socket is currently running.
#[inline(always)]
pub fn socket(&self) -> RawOsSocket {
return *self.sockets.get(0).unwrap();
pub fn thread_count(&self) -> usize {
self.threads.len()
}
}
impl<H: FastUDPSocketPacketHandler + 'static> Drop for FastUDPSocket<H> {
impl<H: FastUDPSocketPacketHandler + Send + Sync + 'static> Drop for FastUDPSocket<H> {
#[cfg(windows)]
fn drop(&mut self) {
self.thread_run.store(false, Ordering::Relaxed);
// TODO
for t in self.threads.iter() {
t.join()
}
}
#[cfg(unix)]
fn drop(&mut self) {
let tmp: [u8; 1] = [0];
self.thread_run.store(false, Ordering::Relaxed);
for s in self.sockets.iter() {
unsafe {
libc::sendto(*s as libc::c_int, tmp.as_ptr(), 0, 0, (&self.bind_address as *const InetAddress).cast::<libc::sockaddr>(), std::mem::size_of::<InetAddress>() as libc::socklen_t);
libc::sendto(*s as libc::c_int, tmp.as_ptr().cast(), 0, 0, (&self.bind_address as *const InetAddress).cast(), std::mem::size_of::<InetAddress>() as libc::socklen_t);
}
}
for s in self.sockets.iter() {
unsafe {
libc::shutdown(*s as libc::c_int, libc::SHUT_RDWR);
}
}
for s in self.sockets.iter() {
unsafe {
libc::close(*s as libc::c_int);
}
}
for t in self.threads.iter() {
t.join()
while !self.threads.is_empty() {
self.threads.pop().unwrap().join().expect("unable to join to thread");
}
}
}
#[cfg(test)]
mod tests {
use crate::fastudp::*;
use zerotier_core::{InetAddress, Buffer};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
#[allow(dead_code)]
struct TestPacketHandler {
cnt: AtomicU32,
side: &'static str
}
impl FastUDPSocketPacketHandler for TestPacketHandler {
#[allow(unused)]
fn incoming_udp_packet(&self, raw_socket: &FastUDPRawOsSocket, from_adddress: &InetAddress, data: Buffer) {
self.cnt.fetch_add(1, Ordering::Relaxed);
//println!("{}: {} bytes from {} (socket: {})", self.side, data.len(), from_adddress.to_string().as_str(), *raw_socket);
}
}
#[test]
fn test_udp_bind_and_transfer() {
{
let ba1 = InetAddress::new_from_string("127.0.0.1/23333");
assert!(ba1.is_some());
let ba1 = ba1.unwrap();
let h1: Arc<TestPacketHandler> = Arc::new(TestPacketHandler {
cnt: AtomicU32::new(0),
side: "Alice",
});
let s1 = FastUDPSocket::new("lo0", &ba1, &h1);
assert!(s1.is_ok());
let s1 = s1.ok().unwrap();
let ba2 = InetAddress::new_from_string("127.0.0.1/23334");
assert!(ba2.is_some());
let ba2 = ba2.unwrap();
let h2: Arc<TestPacketHandler> = Arc::new(TestPacketHandler {
cnt: AtomicU32::new(0),
side: "Bob",
});
let s2 = FastUDPSocket::new("lo0", &ba2, &h2);
assert!(s2.is_ok());
let s2 = s2.ok().unwrap();
let data_bytes = [0_u8; 1024];
loop {
s1.send(&ba2, &data_bytes, 0);
s2.send(&ba1, &data_bytes, 0);
if h1.cnt.load(Ordering::Relaxed) > 10000 && h2.cnt.load(Ordering::Relaxed) > 10000 {
break;
}
}
}
//println!("FastUDPSocket shutdown successful");
}
}

View file

@ -1 +1,2 @@
pub mod fastudpsocket;
pub use fastudpsocket::*;