Now works without async in the core or the service. Can still be used in the controller.

This commit is contained in:
Adam Ierymenko 2022-09-21 18:27:55 -04:00
parent dca7bb8e85
commit d66f19a2f2
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
9 changed files with 348 additions and 262 deletions

View file

@ -14,8 +14,7 @@ zerotier-network-hypervisor = { path = "../network-hypervisor" }
zerotier-crypto = { path = "../crypto" }
zerotier-utils = { path = "../utils" }
zerotier-vl1-service = { path = "../vl1-service" }
async-trait = "^0"
tokio = { version = "^1", features = ["fs", "io-util", "io-std", "net", "parking_lot", "process", "rt", "rt-multi-thread", "signal", "sync", "time"], default-features = false }
#tokio = { version = "^1", features = ["fs", "io-util", "io-std", "net", "parking_lot", "process", "rt", "rt-multi-thread", "signal", "sync", "time"], default-features = false }
serde = { version = "^1", features = ["derive"], default-features = false }
serde_json = { version = "^1", features = ["std"], default-features = false }
parking_lot = { version = "^0", features = [], default-features = false }
@ -27,3 +26,4 @@ winapi = { version = "^0", features = ["handleapi", "ws2ipdef", "ws2tcpip"] }
[target."cfg(not(windows))".dependencies]
libc = "^0"
signal-hook = "^0"

View file

@ -10,7 +10,7 @@ use zerotier_network_hypervisor::util::marshalable::Marshalable;
use zerotier_network_hypervisor::vl1::RootSet;
use zerotier_utils::json::to_json_pretty;
pub async fn cmd(_: Flags, cmd_args: &ArgMatches) -> i32 {
pub fn cmd(_: Flags, cmd_args: &ArgMatches) -> i32 {
match cmd_args.subcommand() {
Some(("add", _sc_args)) => todo!(),
@ -24,8 +24,8 @@ pub async fn cmd(_: Flags, cmd_args: &ArgMatches) -> i32 {
if path.is_some() && secret_arg.is_some() {
let path = path.unwrap();
let secret_arg = secret_arg.unwrap();
let secret = crate::utils::parse_cli_identity(secret_arg, true).await;
let json_data = crate::utils::read_limit(path, crate::utils::DEFAULT_FILE_IO_READ_LIMIT).await;
let secret = crate::utils::parse_cli_identity(secret_arg, true);
let json_data = crate::utils::read_limit(path, crate::utils::DEFAULT_FILE_IO_READ_LIMIT);
if secret.is_err() {
eprintln!("ERROR: unable to parse '{}' or read as a file.", secret_arg);
return exitcode::ERR_IOERR;
@ -61,7 +61,7 @@ pub async fn cmd(_: Flags, cmd_args: &ArgMatches) -> i32 {
let path = sc_args.value_of("path");
if path.is_some() {
let path = path.unwrap();
let json_data = crate::utils::read_limit(path, crate::utils::DEFAULT_FILE_IO_READ_LIMIT).await;
let json_data = crate::utils::read_limit(path, crate::utils::DEFAULT_FILE_IO_READ_LIMIT);
if json_data.is_err() {
eprintln!("ERROR: unable to read '{}'.", path);
return exitcode::ERR_IOERR;
@ -89,7 +89,7 @@ pub async fn cmd(_: Flags, cmd_args: &ArgMatches) -> i32 {
let path = sc_args.value_of("path");
if path.is_some() {
let path = path.unwrap();
let json_data = crate::utils::read_limit(path, 1048576).await;
let json_data = crate::utils::read_limit(path, 1048576);
if json_data.is_err() {
eprintln!("ERROR: unable to read '{}'.", path);
return exitcode::ERR_IOERR;

View file

@ -28,40 +28,31 @@ pub struct DataDir {
impl NodeStorage for DataDir {
fn load_node_identity(&self) -> Option<Identity> {
todo!()
/*
tokio::runtime::Handle::current().spawn(async {
let id_data = read_limit(self.base_path.join(IDENTITY_SECRET_FILENAME), 4096).await;
if id_data.is_err() {
return None;
}
let id_data = Identity::from_str(String::from_utf8_lossy(id_data.unwrap().as_slice()).as_ref());
if id_data.is_err() {
return None;
}
Some(id_data.unwrap())
})
*/
let id_data = read_limit(self.base_path.join(IDENTITY_SECRET_FILENAME), 4096);
if id_data.is_err() {
return None;
}
let id_data = Identity::from_str(String::from_utf8_lossy(id_data.unwrap().as_slice()).as_ref());
if id_data.is_err() {
return None;
}
Some(id_data.unwrap())
}
fn save_node_identity(&self, id: &Identity) {
/*
tokio::runtime::Handle::current().spawn(async move {
assert!(id.secret.is_some());
let id_secret_str = id.to_secret_string();
let id_public_str = id.to_string();
let secret_path = self.base_path.join(IDENTITY_SECRET_FILENAME);
// TODO: handle errors
let _ = tokio::fs::write(&secret_path, id_secret_str.as_bytes()).await;
assert!(crate::utils::fs_restrict_permissions(&secret_path));
let _ = tokio::fs::write(self.base_path.join(IDENTITY_PUBLIC_FILENAME), id_public_str.as_bytes()).await;
});
*/
assert!(id.secret.is_some());
let id_secret_str = id.to_secret_string();
let id_public_str = id.to_string();
let secret_path = self.base_path.join(IDENTITY_SECRET_FILENAME);
// TODO: handle errors
let _ = std::fs::write(&secret_path, id_secret_str.as_bytes());
assert!(crate::utils::fs_restrict_permissions(&secret_path));
let _ = std::fs::write(self.base_path.join(IDENTITY_PUBLIC_FILENAME), id_public_str.as_bytes());
}
}
impl DataDir {
pub async fn open<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
pub fn open<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
let base_path = path.as_ref().to_path_buf();
if !base_path.is_dir() {
let _ = std::fs::create_dir_all(&base_path);
@ -74,7 +65,7 @@ impl DataDir {
}
let config_path = base_path.join(CONFIG_FILENAME);
let config_data = read_limit(&config_path, DEFAULT_FILE_IO_READ_LIMIT).await;
let config_data = read_limit(&config_path, DEFAULT_FILE_IO_READ_LIMIT);
let config = RwLock::new(Arc::new(if config_data.is_ok() {
let c = serde_json::from_slice::<Config>(config_data.unwrap().as_slice());
if c.is_err() {
@ -93,17 +84,17 @@ impl DataDir {
}
/// Get authorization token for local API, creating and saving if it does not exist.
pub async fn authtoken(&self) -> std::io::Result<String> {
pub fn authtoken(&self) -> std::io::Result<String> {
let authtoken = self.authtoken.lock().clone();
if authtoken.is_empty() {
let authtoken_path = self.base_path.join(AUTH_TOKEN_FILENAME);
let authtoken_bytes = read_limit(&authtoken_path, 4096).await;
let authtoken_bytes = read_limit(&authtoken_path, 4096);
if authtoken_bytes.is_err() {
let mut tmp = String::with_capacity(AUTH_TOKEN_DEFAULT_LENGTH);
for _ in 0..AUTH_TOKEN_DEFAULT_LENGTH {
tmp.push(AUTH_TOKEN_POSSIBLE_CHARS.as_bytes()[(next_u32_secure() as usize) % AUTH_TOKEN_POSSIBLE_CHARS.len()] as char);
}
tokio::fs::write(&authtoken_path, tmp.as_bytes()).await?;
std::fs::write(&authtoken_path, tmp.as_bytes())?;
assert!(crate::utils::fs_restrict_permissions(&authtoken_path));
*self.authtoken.lock() = tmp;
} else {
@ -118,15 +109,15 @@ impl DataDir {
/// Use clone() to get a copy of the configuration if you want to modify it. Then use
/// save_config() to save the modified configuration and update the internal copy in
/// this structure.
pub async fn config(&self) -> Arc<Config> {
pub fn config(&self) -> Arc<Config> {
self.config.read().clone()
}
/// Save a modified copy of the configuration and replace the internal copy in this structure (if it's actually changed).
pub async fn save_config(&self, modified_config: Config) -> std::io::Result<()> {
pub fn save_config(&self, modified_config: Config) -> std::io::Result<()> {
if !modified_config.eq(&self.config.read()) {
let config_data = to_json_pretty(&modified_config);
tokio::fs::write(self.base_path.join(CONFIG_FILENAME), config_data.as_bytes()).await?;
std::fs::write(self.base_path.join(CONFIG_FILENAME), config_data.as_bytes())?;
*self.config.write() = Arc::new(modified_config);
}
Ok(())

View file

@ -8,9 +8,12 @@ pub mod utils;
pub mod vnic;
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use clap::error::{ContextKind, ContextValue};
#[allow(unused_imports)]
use clap::{Arg, ArgMatches, Command};
use zerotier_network_hypervisor::{VERSION_MAJOR, VERSION_MINOR, VERSION_REVISION};
@ -41,8 +44,8 @@ pub struct Flags {
pub auth_token_override: Option<String>,
}
async fn open_datadir(flags: &Flags) -> Arc<DataDir> {
let datadir = DataDir::open(flags.base_path.as_str()).await;
fn open_datadir(flags: &Flags) -> Arc<DataDir> {
let datadir = DataDir::open(flags.base_path.as_str());
if datadir.is_ok() {
return Arc::new(datadir.unwrap());
}
@ -54,50 +57,6 @@ async fn open_datadir(flags: &Flags) -> Arc<DataDir> {
std::process::exit(exitcode::ERR_IOERR);
}
async fn async_main(flags: Flags, global_args: Box<ArgMatches>) -> i32 {
#[allow(unused)]
match global_args.subcommand() {
Some(("help", _)) => {
print_help();
exitcode::OK
}
Some(("version", _)) => {
println!("{}.{}.{}", VERSION_MAJOR, VERSION_MINOR, VERSION_REVISION);
exitcode::OK
}
Some(("status", _)) => todo!(),
Some(("set", cmd_args)) => todo!(),
Some(("peer", cmd_args)) => todo!(),
Some(("network", cmd_args)) => todo!(),
Some(("join", cmd_args)) => todo!(),
Some(("leave", cmd_args)) => todo!(),
Some(("service", _)) => {
drop(global_args); // free unnecessary heap before starting service as we're done with CLI args
let test_inner = Arc::new(zerotier_network_hypervisor::vl1::DummyInnerProtocol::default());
let test_path_filter = Arc::new(zerotier_network_hypervisor::vl1::DummyPathFilter::default());
let datadir = open_datadir(&flags).await;
let svc = VL1Service::new(datadir, test_inner, test_path_filter, zerotier_vl1_service::VL1Settings::default()).await;
if svc.is_ok() {
let svc = svc.unwrap();
svc.node().init_default_roots();
let _ = tokio::signal::ctrl_c().await;
println!("Terminate signal received, shutting down...");
exitcode::OK
} else {
println!("FATAL: error launching service: {}", svc.err().unwrap().to_string());
exitcode::ERR_IOERR
}
}
Some(("identity", cmd_args)) => todo!(),
Some(("rootset", cmd_args)) => cli::rootset::cmd(flags, cmd_args).await,
_ => {
eprintln!("Invalid command line. Use 'help' for help.");
exitcode::ERR_USAGE
}
}
}
fn main() {
let global_args = Box::new({
Command::new("zerotier")
@ -231,11 +190,58 @@ fn main() {
auth_token_override: global_args.value_of("token").map(|t| t.to_string()),
};
std::process::exit(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async_main(flags, global_args)),
);
#[allow(unused)]
let exit_code = match global_args.subcommand() {
Some(("help", _)) => {
print_help();
exitcode::OK
}
Some(("version", _)) => {
println!("{}.{}.{}", VERSION_MAJOR, VERSION_MINOR, VERSION_REVISION);
exitcode::OK
}
Some(("status", _)) => todo!(),
Some(("set", cmd_args)) => todo!(),
Some(("peer", cmd_args)) => todo!(),
Some(("network", cmd_args)) => todo!(),
Some(("join", cmd_args)) => todo!(),
Some(("leave", cmd_args)) => todo!(),
Some(("service", _)) => {
drop(global_args); // free unnecessary heap before starting service as we're done with CLI args
let test_inner = Arc::new(zerotier_network_hypervisor::vl1::DummyInnerProtocol::default());
let test_path_filter = Arc::new(zerotier_network_hypervisor::vl1::DummyPathFilter::default());
let datadir = open_datadir(&flags);
let svc = VL1Service::new(datadir, test_inner, test_path_filter, zerotier_vl1_service::VL1Settings::default());
if svc.is_ok() {
let svc = svc.unwrap();
svc.node().init_default_roots();
#[cfg(unix)]
{
let term = Arc::new(AtomicBool::new(false));
let _ = signal_hook::flag::register(libc::SIGINT, term.clone());
let _ = signal_hook::flag::register(libc::SIGTERM, term.clone());
let _ = signal_hook::flag::register(libc::SIGQUIT, term.clone());
while !term.load(Ordering::Relaxed) {
std::thread::sleep(Duration::from_secs(1));
}
}
println!("Terminate signal received, shutting down...");
exitcode::OK
} else {
println!("FATAL: error launching service: {}", svc.err().unwrap().to_string());
exitcode::ERR_IOERR
}
}
Some(("identity", cmd_args)) => todo!(),
Some(("rootset", cmd_args)) => cli::rootset::cmd(flags, cmd_args),
_ => {
eprintln!("Invalid command line. Use 'help' for help.");
exitcode::ERR_USAGE
}
};
std::process::exit(exit_code);
}

View file

@ -1,11 +1,10 @@
// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md.
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::str::FromStr;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use zerotier_network_hypervisor::vl1::Identity;
/// Default sanity limit parameter for read_limit() used throughout the service.
@ -14,12 +13,12 @@ pub const DEFAULT_FILE_IO_READ_LIMIT: usize = 1048576;
/// Convenience function to read up to limit bytes from a file.
///
/// If the file is larger than limit, the excess is not read.
pub async fn read_limit<P: AsRef<Path>>(path: P, limit: usize) -> std::io::Result<Vec<u8>> {
let mut f = File::open(path).await?;
let bytes = f.metadata().await?.len().min(limit as u64) as usize;
pub fn read_limit<P: AsRef<Path>>(path: P, limit: usize) -> std::io::Result<Vec<u8>> {
let mut f = File::open(path)?;
let bytes = f.metadata()?.len().min(limit as u64) as usize;
let mut v: Vec<u8> = Vec::with_capacity(bytes);
v.resize(bytes, 0);
f.read_exact(v.as_mut_slice()).await?;
f.read_exact(v.as_mut_slice())?;
Ok(v)
}
@ -70,7 +69,7 @@ pub fn is_valid_port(v: &str) -> Result<(), String> {
}
/// Read an identity as either a literal or from a file.
pub async fn parse_cli_identity(input: &str, validate: bool) -> Result<Identity, String> {
pub fn parse_cli_identity(input: &str, validate: bool) -> Result<Identity, String> {
let parse_func = |s: &str| {
Identity::from_str(s).map_or_else(
|e| Err(format!("invalid identity: {}", e.to_string())),
@ -86,7 +85,7 @@ pub async fn parse_cli_identity(input: &str, validate: bool) -> Result<Identity,
let input_p = Path::new(input);
if input_p.is_file() {
read_limit(input_p, 16384).await.map_or_else(
read_limit(input_p, 16384).map_or_else(
|e| Err(e.to_string()),
|v| String::from_utf8(v).map_or_else(|e| Err(e.to_string()), |s| parse_func(s.as_str())),
)

View file

@ -1,22 +1,19 @@
// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md.
use async_trait::async_trait;
use zerotier_network_hypervisor::vl1::{InetAddress, MAC};
use zerotier_network_hypervisor::vl2::MulticastGroup;
/// Virtual network interface
#[async_trait]
pub trait VNIC {
/// Add a new IPv4 or IPv6 address to this interface, returning true on success.
async fn add_ip(&self, ip: &InetAddress) -> bool;
fn add_ip(&self, ip: &InetAddress) -> bool;
/// Remove an IPv4 or IPv6 address, returning true on success.
/// Nothing happens if the address is not found.
async fn remove_ip(&self, ip: &InetAddress) -> bool;
fn remove_ip(&self, ip: &InetAddress) -> bool;
/// Enumerate all IPs on this interface including ones assigned outside ZeroTier.
async fn ips(&self) -> Vec<InetAddress>;
fn ips(&self) -> Vec<InetAddress>;
/// Get the OS-specific device name for this interface, e.g. zt## or tap##.
fn device_name(&self) -> String;
@ -25,8 +22,8 @@ pub trait VNIC {
/// This doesn't do any IGMP snooping. It just reports the groups the port
/// knows about. On some OSes this may not be supported in which case it
/// will return an empty set.
async fn get_multicast_groups(&self) -> std::collections::BTreeSet<MulticastGroup>;
fn get_multicast_groups(&self) -> std::collections::BTreeSet<MulticastGroup>;
/// Inject an Ethernet frame into this port.
async fn put(&self, source_mac: &MAC, dest_mac: &MAC, ethertype: u16, vlan_id: u16, data: &[u8]) -> bool;
fn put(&self, source_mac: &MAC, dest_mac: &MAC, ethertype: u16, vlan_id: u16, data: &[u8]) -> bool;
}

View file

@ -10,7 +10,6 @@ zerotier-network-hypervisor = { path = "../network-hypervisor" }
zerotier-crypto = { path = "../crypto" }
zerotier-utils = { path = "../utils" }
num-traits = "^0"
tokio = { version = "^1", features = ["fs", "io-util", "io-std", "net", "parking_lot", "process", "rt", "rt-multi-thread", "signal", "sync", "time"], default-features = false }
parking_lot = { version = "^0", features = [], default-features = false }
serde = { version = "^1", features = ["derive"], default-features = false }
serde_json = { version = "^1", features = ["std"], default-features = false }

View file

@ -7,11 +7,11 @@ use std::mem::{size_of, transmute, MaybeUninit};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
#[allow(unused_imports)]
use std::ptr::{null, null_mut};
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::Arc;
#[cfg(unix)]
use std::os::unix::io::{FromRawFd, RawFd};
use std::os::unix::io::RawFd;
use crate::localinterface::LocalInterface;
@ -34,19 +34,17 @@ pub struct BoundUdpPort {
/// A socket bound to a specific interface and IP.
pub struct BoundUdpSocket {
pub address: InetAddress,
pub socket: Arc<tokio::net::UdpSocket>,
pub interface: LocalInterface,
pub associated_tasks: parking_lot::Mutex<Vec<tokio::task::JoinHandle<()>>>,
last_receive_time: AtomicI64,
fd: RawFd,
lock: parking_lot::RwLock<()>,
open: AtomicBool,
}
impl Drop for BoundUdpSocket {
fn drop(&mut self) {
let mut associated_tasks = self.associated_tasks.lock();
for t in associated_tasks.drain(..) {
t.abort();
}
self.close();
let _wait_for_close = self.lock.write();
}
}
@ -66,26 +64,82 @@ impl BoundUdpSocket {
};
}
pub fn send_sync_nonblock(&self, dest: &InetAddress, b: &[u8], packet_ttl: u8) -> bool {
let mut ok = false;
#[cfg(unix)]
pub fn send(&self, dest: &InetAddress, data: &[u8], packet_ttl: u8) -> bool {
if dest.family() == self.address.family() {
if packet_ttl > 0 && dest.is_ipv4() {
self.set_ttl(packet_ttl);
ok = self.socket.try_send_to(b, dest.try_into().unwrap()).is_ok();
self.set_ttl(0xff);
let (c_sockaddr, c_addrlen) = dest.c_sockaddr();
if packet_ttl == 0 || !dest.is_ipv4() {
unsafe {
return libc::sendto(
self.fd.as_(),
data.as_ptr().cast(),
data.len().as_(),
0,
c_sockaddr.cast(),
c_addrlen.as_(),
) >= 0;
}
} else {
ok = self.socket.try_send_to(b, dest.try_into().unwrap()).is_ok();
self.set_ttl(packet_ttl);
let ok = unsafe {
libc::sendto(
self.fd.as_(),
data.as_ptr().cast(),
data.len().as_(),
0,
c_sockaddr.cast(),
c_addrlen.as_(),
) >= 0
};
self.set_ttl(0xff);
return ok;
}
}
ok
return false;
}
pub async fn receive<B: AsMut<[u8]> + Send>(&self, mut buffer: B, current_time: i64) -> tokio::io::Result<(usize, SocketAddr)> {
let result = self.socket.recv_from(buffer.as_mut()).await;
if result.is_ok() {
self.last_receive_time.store(current_time, Ordering::Relaxed);
fn close(&self) {
unsafe {
self.open.store(false, Ordering::SeqCst);
let mut timeo: libc::timeval = std::mem::zeroed();
timeo.tv_sec = 0;
timeo.tv_usec = 1;
libc::setsockopt(
self.fd.as_(),
libc::SOL_SOCKET.as_(),
libc::SO_RCVTIMEO.as_(),
(&mut timeo as *mut libc::timeval).cast(),
std::mem::size_of::<libc::timeval>().as_(),
);
libc::shutdown(self.fd.as_(), libc::SHUT_RDWR);
libc::close(self.fd.as_());
}
}
/// Receive a packet or return None if this UDP socket is being closed.
#[cfg(unix)]
pub fn blocking_receive<B: AsMut<[u8]>>(&self, mut buffer: B, current_time: i64) -> Option<(usize, InetAddress)> {
unsafe {
let _hold = self.lock.read();
let b = buffer.as_mut();
let mut from = InetAddress::new();
while self.open.load(Ordering::Relaxed) {
let mut addrlen = std::mem::size_of::<InetAddress>().as_();
let s = libc::recvfrom(
self.fd.as_(),
b.as_mut_ptr().cast(),
b.len().as_(),
0,
(&mut from as *mut InetAddress).cast(),
&mut addrlen,
);
if s > 0 {
self.last_receive_time.store(current_time, Ordering::Relaxed);
return Some((s as usize, from));
}
}
return None;
}
result
}
}
@ -129,7 +183,7 @@ impl BoundUdpPort {
.insert(s.address.clone(), s);
}
let mut errors = Vec::new();
let mut errors: Vec<(LocalInterface, InetAddress, std::io::Error)> = Vec::new();
let mut new_sockets = Vec::new();
getifaddrs::for_each_address(|address, interface| {
let interface_str = interface.to_string();
@ -146,10 +200,10 @@ impl BoundUdpPort {
&& !ipv6::is_ipv6_temporary(interface_str.as_str(), address)
{
let mut found = false;
if let Some(byaddr) = existing_bindings.get(interface) {
if let Some(socket) = byaddr.get(&addr_with_port) {
if let Some(byaddr) = existing_bindings.get_mut(interface) {
if let Some(socket) = byaddr.remove(&addr_with_port) {
found = true;
self.sockets.push(socket.clone());
self.sockets.push(socket);
}
}
@ -157,20 +211,23 @@ impl BoundUdpPort {
let s = unsafe { bind_udp_to_device(interface_str.as_str(), &addr_with_port) };
if s.is_ok() {
let fd = s.unwrap();
let s = tokio::net::UdpSocket::from_std(unsafe { std::net::UdpSocket::from_raw_fd(fd) });
if s.is_ok() {
let s = Arc::new(BoundUdpSocket {
address: addr_with_port,
socket: Arc::new(s.unwrap()),
interface: interface.clone(),
associated_tasks: parking_lot::Mutex::new(Vec::new()),
last_receive_time: AtomicI64::new(i64::MIN),
fd,
lock: parking_lot::RwLock::new(()),
open: AtomicBool::new(true),
});
self.sockets.push(s.clone());
new_sockets.push(s);
} else {
errors.push((interface.clone(), addr_with_port, s.err().unwrap()));
errors.push((
interface.clone(),
addr_with_port,
std::io::Error::new(std::io::ErrorKind::Other, s.err().unwrap()),
));
}
} else {
errors.push((
@ -183,10 +240,24 @@ impl BoundUdpPort {
}
});
for (_, byaddr) in existing_bindings.iter() {
for (_, s) in byaddr.iter() {
s.close();
}
}
(errors, new_sockets)
}
}
impl Drop for BoundUdpPort {
fn drop(&mut self) {
for s in self.sockets.iter() {
s.close();
}
}
}
/// Attempt to bind universally to a given UDP port and then close to determine if we can use it.
///
/// This succeeds if either IPv4 or IPv6 global can be bound.
@ -216,12 +287,25 @@ unsafe fn bind_udp_to_device(device_name: &str, address: &InetAddress) -> Result
return Err("unable to create new UDP socket");
}
assert_ne!(libc::fcntl(s, libc::F_SETFL, libc::O_NONBLOCK), -1);
#[allow(unused_variables)]
let mut setsockopt_results: libc::c_int = 0;
let mut fl: libc::c_int;
let mut fl;
//assert_ne!(libc::fcntl(s, libc::F_SETFL, libc::O_NONBLOCK), -1);
let mut timeo: libc::timeval = std::mem::zeroed();
timeo.tv_sec = 1;
timeo.tv_usec = 0;
setsockopt_results |= libc::setsockopt(
s,
libc::SOL_SOCKET.as_(),
libc::SO_RCVTIMEO.as_(),
(&mut timeo as *mut libc::timeval).cast(),
std::mem::size_of::<libc::timeval>().as_(),
);
debug_assert!(setsockopt_results == 0);
/*
fl = 1;
setsockopt_results |= libc::setsockopt(
s,
@ -231,6 +315,7 @@ unsafe fn bind_udp_to_device(device_name: &str, address: &InetAddress) -> Result
std::mem::size_of::<libc::c_int>().as_(),
);
debug_assert!(setsockopt_results == 0);
*/
fl = 1;
setsockopt_results |= libc::setsockopt(

View file

@ -3,6 +3,8 @@
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::Duration;
use zerotier_crypto::random;
use zerotier_network_hypervisor::protocol::{PacketBufferFactory, PacketBufferPool};
@ -14,8 +16,11 @@ use crate::settings::VL1Settings;
use crate::sys::udp::{udp_test_bind, BoundUdpPort};
use crate::LocalSocket;
use tokio::task::JoinHandle;
use tokio::time::Duration;
/// This can be adjusted to trade thread count for maximum I/O concurrency.
const MAX_PER_SOCKET_CONCURRENCY: usize = 8;
/// Update UDP bindings every this many seconds.
const UPDATE_UDP_BINDINGS_EVERY_SECS: usize = 10;
/// VL1 service that connects to the physical network and hosts an inner protocol like ZeroTier VL2.
///
@ -40,12 +45,13 @@ struct VL1ServiceMutableState {
daemons: Vec<JoinHandle<()>>,
udp_sockets: HashMap<u16, parking_lot::RwLock<BoundUdpPort>>,
settings: VL1Settings,
running: bool,
}
impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'static, InnerProtocolImpl: InnerProtocol + 'static>
VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
{
pub async fn new(
pub fn new(
storage: Arc<NodeStorageImpl>,
inner: Arc<InnerProtocolImpl>,
path_filter: Arc<PathFilterImpl>,
@ -56,6 +62,7 @@ impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'stati
daemons: Vec::with_capacity(2),
udp_sockets: HashMap::with_capacity(8),
settings,
running: true,
}),
storage,
inner,
@ -70,8 +77,10 @@ impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'stati
let service = Arc::new(service);
let mut daemons = Vec::new();
daemons.push(tokio::spawn(service.clone().udp_bind_daemon()));
daemons.push(tokio::spawn(service.clone().node_background_task_daemon()));
let s = service.clone();
daemons.push(std::thread::spawn(move || {
s.background_task_daemon();
}));
service.state.write().daemons = daemons;
Ok(service)
@ -87,122 +96,119 @@ impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'stati
self.state.read().udp_sockets.keys().cloned().collect()
}
async fn udp_bind_daemon(self: Arc<Self>) {
loop {
{
let state = self.state.read();
let mut need_fixed_ports: HashSet<u16> = HashSet::from_iter(state.settings.fixed_ports.iter().cloned());
let mut have_random_port_count = 0;
for (p, _) in state.udp_sockets.iter() {
need_fixed_ports.remove(p);
have_random_port_count += (!state.settings.fixed_ports.contains(p)) as usize;
fn update_udp_bindings(self: &Arc<Self>) {
let state = self.state.read();
let mut need_fixed_ports: HashSet<u16> = HashSet::from_iter(state.settings.fixed_ports.iter().cloned());
let mut have_random_port_count = 0;
for (p, _) in state.udp_sockets.iter() {
need_fixed_ports.remove(p);
have_random_port_count += (!state.settings.fixed_ports.contains(p)) as usize;
}
let desired_random_port_count = state.settings.random_port_count;
let state = if !need_fixed_ports.is_empty() || have_random_port_count != desired_random_port_count {
drop(state);
let mut state = self.state.write();
for p in need_fixed_ports.iter() {
state.udp_sockets.insert(*p, parking_lot::RwLock::new(BoundUdpPort::new(*p)));
}
while have_random_port_count > desired_random_port_count {
let mut most_stale_binding_liveness = (usize::MAX, i64::MAX);
let mut most_stale_binding_port = 0;
for (p, s) in state.udp_sockets.iter() {
if !state.settings.fixed_ports.contains(p) {
let (total_smart_ptr_handles, most_recent_receive) = s.read().liveness();
if total_smart_ptr_handles < most_stale_binding_liveness.0
|| (total_smart_ptr_handles == most_stale_binding_liveness.0
&& most_recent_receive <= most_stale_binding_liveness.1)
{
most_stale_binding_liveness.0 = total_smart_ptr_handles;
most_stale_binding_liveness.1 = most_recent_receive;
most_stale_binding_port = *p;
}
}
}
let desired_random_port_count = state.settings.random_port_count;
let state = if !need_fixed_ports.is_empty() || have_random_port_count != desired_random_port_count {
drop(state);
let mut state = self.state.write();
for p in need_fixed_ports.iter() {
state.udp_sockets.insert(*p, parking_lot::RwLock::new(BoundUdpPort::new(*p)));
}
while have_random_port_count > desired_random_port_count {
let mut most_stale_binding_liveness = (usize::MAX, i64::MAX);
let mut most_stale_binding_port = 0;
for (p, s) in state.udp_sockets.iter() {
if !state.settings.fixed_ports.contains(p) {
let (total_smart_ptr_handles, most_recent_receive) = s.read().liveness();
if total_smart_ptr_handles < most_stale_binding_liveness.0
|| (total_smart_ptr_handles == most_stale_binding_liveness.0
&& most_recent_receive <= most_stale_binding_liveness.1)
{
most_stale_binding_liveness.0 = total_smart_ptr_handles;
most_stale_binding_liveness.1 = most_recent_receive;
most_stale_binding_port = *p;
}
}
}
if most_stale_binding_port != 0 {
have_random_port_count -= state.udp_sockets.remove(&most_stale_binding_port).is_some() as usize;
} else {
break;
}
}
'outer_add_port_loop: while have_random_port_count < desired_random_port_count {
let rn = random::xorshift64_random() as usize;
for i in 0..UNASSIGNED_PRIVILEGED_PORTS.len() {
let p = UNASSIGNED_PRIVILEGED_PORTS[rn.wrapping_add(i) % UNASSIGNED_PRIVILEGED_PORTS.len()];
if !state.udp_sockets.contains_key(&p) && udp_test_bind(p) {
let _ = state.udp_sockets.insert(p, parking_lot::RwLock::new(BoundUdpPort::new(p)));
continue 'outer_add_port_loop;
}
}
let p = 50000 + ((random::xorshift64_random() as u16) % 15535);
if !state.udp_sockets.contains_key(&p) && udp_test_bind(p) {
have_random_port_count += state
.udp_sockets
.insert(p, parking_lot::RwLock::new(BoundUdpPort::new(p)))
.is_none() as usize;
}
}
drop(state);
self.state.read()
if most_stale_binding_port != 0 {
have_random_port_count -= state.udp_sockets.remove(&most_stale_binding_port).is_some() as usize;
} else {
state
};
let num_cores = std::thread::available_parallelism().map_or(1, |c| c.get());
for (_, binding) in state.udp_sockets.iter() {
let mut binding = binding.write();
let (_, mut new_sockets) =
binding.update_bindings(&state.settings.interface_prefix_blacklist, &state.settings.cidr_blacklist);
for s in new_sockets.drain(..) {
// Start one async task per system core. This is technically not necessary because tokio
// schedules and multiplexes, but this enables tokio to grab and schedule packets
// concurrently for up to the number of cores available for any given socket and is
// probably faster than other patterns that involve iterating through sockets and creating
// arrays of futures or using channels.
let mut socket_tasks = Vec::with_capacity(num_cores);
for _ in 0..num_cores {
let self_copy = self.clone();
let s_copy = s.clone();
let local_socket = LocalSocket::new(&s);
socket_tasks.push(tokio::spawn(async move {
loop {
let mut buf = self_copy.buffer_pool.get();
let now = ms_monotonic();
if let Ok((bytes, from_sockaddr)) = s_copy.receive(unsafe { buf.entire_buffer_mut() }, now).await {
unsafe { buf.set_size_unchecked(bytes) };
self_copy.node().handle_incoming_physical_packet(
&*self_copy,
&*self_copy.inner,
&Endpoint::IpUdp(InetAddress::from(from_sockaddr)),
&local_socket,
&s_copy.interface,
buf,
);
}
}
}));
}
debug_assert!(s.associated_tasks.lock().is_empty());
*s.associated_tasks.lock() = socket_tasks;
}
break;
}
}
tokio::time::sleep(Duration::from_secs(10)).await;
'outer_add_port_loop: while have_random_port_count < desired_random_port_count {
let rn = random::xorshift64_random() as usize;
for i in 0..UNASSIGNED_PRIVILEGED_PORTS.len() {
let p = UNASSIGNED_PRIVILEGED_PORTS[rn.wrapping_add(i) % UNASSIGNED_PRIVILEGED_PORTS.len()];
if !state.udp_sockets.contains_key(&p) && udp_test_bind(p) {
let _ = state.udp_sockets.insert(p, parking_lot::RwLock::new(BoundUdpPort::new(p)));
continue 'outer_add_port_loop;
}
}
let p = 50000 + ((random::xorshift64_random() as u16) % 15535);
if !state.udp_sockets.contains_key(&p) && udp_test_bind(p) {
have_random_port_count += state
.udp_sockets
.insert(p, parking_lot::RwLock::new(BoundUdpPort::new(p)))
.is_none() as usize;
}
}
drop(state);
self.state.read()
} else {
state
};
let per_socket_concurrency = std::thread::available_parallelism()
.map_or(1, |c| c.get())
.min(MAX_PER_SOCKET_CONCURRENCY);
for (_, binding) in state.udp_sockets.iter() {
let mut binding = binding.write();
let (_, mut new_sockets) = binding.update_bindings(&state.settings.interface_prefix_blacklist, &state.settings.cidr_blacklist);
for s in new_sockets.drain(..) {
for _ in 0..per_socket_concurrency {
let self_copy = self.clone();
let s_copy = s.clone();
std::thread::spawn(move || loop {
let local_socket = LocalSocket::new(&s_copy);
loop {
let mut buf = self_copy.buffer_pool.get();
let now = ms_monotonic();
if let Some((bytes, from)) = s_copy.blocking_receive(unsafe { buf.entire_buffer_mut() }, now) {
unsafe { buf.set_size_unchecked(bytes) };
self_copy.node().handle_incoming_physical_packet(
&*self_copy,
&*self_copy.inner,
&Endpoint::IpUdp(from),
&local_socket,
&s_copy.interface,
buf,
);
} else {
break;
}
}
});
}
}
}
}
async fn node_background_task_daemon(self: Arc<Self>) {
tokio::time::sleep(Duration::from_secs(1)).await;
fn background_task_daemon(self: Arc<Self>) {
std::thread::sleep(Duration::from_secs(1));
let mut udp_binding_check_every: usize = 0;
loop {
tokio::time::sleep(self.node().do_background_tasks(self.as_ref())).await;
if !self.state.read().running {
break;
}
if (udp_binding_check_every % UPDATE_UDP_BINDINGS_EVERY_SECS) == 0 {
self.update_udp_bindings();
}
udp_binding_check_every = udp_binding_check_every.wrapping_add(1);
std::thread::sleep(self.node().do_background_tasks(self.as_ref()));
}
}
}
@ -238,7 +244,7 @@ impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl
// This is the fast path -- the socket is known to the core so just send it.
if let Some(s) = local_socket {
if let Some(s) = s.0.upgrade() {
s.send_sync_nonblock(address, data, packet_ttl);
s.send(address, data, packet_ttl);
} else {
return;
}
@ -255,7 +261,7 @@ impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl
for _ in 0..p.sockets.len() {
let s = p.sockets.get(i).unwrap();
if s.interface.eq(specific_interface) {
if s.send_sync_nonblock(address, data, packet_ttl) {
if s.send(address, data, packet_ttl) {
break 'socket_search;
}
}
@ -273,7 +279,7 @@ impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl
for _ in 0..p.sockets.len() {
let s = p.sockets.get(i).unwrap();
if !sent_on_interfaces.contains(&s.interface) {
if s.send_sync_nonblock(address, data, packet_ttl) {
if s.send(address, data, packet_ttl) {
sent_on_interfaces.insert(s.interface.clone());
}
}
@ -304,9 +310,12 @@ impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl
{
fn drop(&mut self) {
let mut state = self.state.write();
for d in state.daemons.drain(..) {
d.abort();
}
state.running = false;
state.udp_sockets.clear();
let mut daemons: Vec<JoinHandle<()>> = state.daemons.drain(..).collect();
drop(state);
for d in daemons.drain(..) {
d.join();
}
}
}