Controller work and a basic in-filesystem database, slightly different from the v1 version, with support for live editing of the network in the filesystem.

Adam Ierymenko 2022-09-22 17:29:10 -04:00
parent be000c2046
commit 373adb028d
GPG key ID: C8877CF2D7A5D7F3
10 changed files with 372 additions and 124 deletions
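
To orient the diffs below, here is a hypothetical sketch (not part of the commit) of how the new FileDatabase and the reworked Controller constructor fit together. The crate/module paths, the data directory, and the way the runtime handle is obtained are assumptions; only the signatures visible in the diffs are taken as given.

use std::sync::Arc;

use controller::controller::Controller;       // hypothetical module path
use controller::filedatabase::FileDatabase;   // hypothetical module path

async fn start_controller(runtime: tokio::runtime::Handle) {
    // FileDatabase::new() takes a base directory and creates a "live" subdirectory next to
    // the canonical records; per the commit message, that tree is what supports live editing
    // of the network in the filesystem. The path here is a made-up example.
    let db = Arc::new(FileDatabase::new("/var/lib/zt-controller").await);

    // After this commit Controller::new() is a plain (non-async) constructor: it stores the
    // database, a Reaper for deferred cleanup tasks, and the tokio runtime handle.
    let _controller = Controller::new(db, runtime);
}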

@@ -23,7 +23,7 @@ pub struct Controller<DatabaseImpl: Database> {
 }
 
 impl<DatabaseImpl: Database> Controller<DatabaseImpl> {
-    pub async fn new(database: Arc<DatabaseImpl>, runtime: tokio::runtime::Handle) -> Arc<Self> {
+    pub fn new(database: Arc<DatabaseImpl>, runtime: tokio::runtime::Handle) -> Arc<Self> {
         Arc::new(Self { database, reaper: Reaper::new(&runtime), runtime })
     }

@@ -1,19 +1,18 @@
 use async_trait::async_trait;
 
-use zerotier_network_hypervisor::vl1::{Address, NodeStorage};
+use zerotier_network_hypervisor::vl1::Address;
 use zerotier_network_hypervisor::vl2::NetworkId;
 
 use crate::model::*;
 
 #[async_trait]
-pub trait Database: NodeStorage + Sync + Send + 'static {
+pub trait Database: Sync + Send + Sized + 'static {
     type Error;
 
     async fn get_network(&self, id: NetworkId) -> Result<Option<Network>, Self::Error>;
-    async fn save_network(&self, obj: Network) -> Result<(), Self::Error>;
+    async fn save_network(&self, obj: &Network) -> Result<(), Self::Error>;
 
-    async fn get_network_members(&self, id: NetworkId) -> Result<Vec<Address>, Self::Error>;
+    async fn list_members(&self, network_id: NetworkId) -> Result<Vec<Address>, Self::Error>;
     async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result<Option<Member>, Self::Error>;
-    async fn save_member(&self, network_id: NetworkId, node_id: Address) -> Result<Option<Member>, Self::Error>;
+    async fn save_member(&self, obj: &Member) -> Result<(), Self::Error>;
 }

@@ -0,0 +1,156 @@
use std::error::Error;
use std::path::{Path, PathBuf};

use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde::Serialize;

use zerotier_network_hypervisor::vl1::Address;
use zerotier_network_hypervisor::vl2::NetworkId;
use zerotier_utils::json::{json_patch, to_json_pretty};
use zerotier_utils::tokio::fs;
use zerotier_utils::tokio::io::ErrorKind;

use crate::database::Database;
use crate::model::*;

pub struct FileDatabase {
    base: PathBuf,
    live: PathBuf,
}

fn network_path(base: &PathBuf, network_id: NetworkId) -> PathBuf {
    base.join(network_id.to_string()).join("config.json")
}

fn member_path(base: &PathBuf, network_id: NetworkId, member_id: Address) -> PathBuf {
    base.join(network_id.to_string()).join(format!("m{}.json", member_id.to_string()))
}

impl FileDatabase {
    pub async fn new<P: AsRef<Path>>(base_path: P) -> Self {
        let base: PathBuf = base_path.as_ref().into();
        let live: PathBuf = base_path.as_ref().join("live");
        let _ = fs::create_dir_all(&live).await;
        Self { base, live }
    }

    async fn merge_with_live<O: Serialize + DeserializeOwned>(&self, live_path: PathBuf, changes: O) -> O {
        if let Ok(changes) = serde_json::to_value(&changes) {
            if let Ok(old_raw_json) = fs::read(&live_path).await {
                if let Ok(mut patched) = serde_json::from_slice::<serde_json::Value>(old_raw_json.as_slice()) {
                    json_patch(&mut patched, &changes, 64);
                    if let Ok(patched) = serde_json::from_value::<O>(patched) {
                        if let Ok(patched_json) = serde_json::to_vec(&patched) {
                            if let Ok(to_replace) = fs::read(&live_path).await {
                                if to_replace.as_slice().eq(patched_json.as_slice()) {
                                    return patched;
                                }
                            }
                            let _ = fs::write(live_path, patched_json.as_slice()).await;
                            return patched;
                        }
                    }
                }
            }
        }
        // TODO: report error
        return changes;
    }
}

#[async_trait]
impl Database for FileDatabase {
    type Error = Box<dyn Error>;

    async fn get_network(&self, id: NetworkId) -> Result<Option<Network>, Self::Error> {
        let r = fs::read(network_path(&self.base, id)).await;
        if let Ok(raw) = r {
            let r = serde_json::from_slice::<Network>(raw.as_slice());
            if let Ok(network) = r {
                return Ok(Some(self.merge_with_live(network_path(&self.live, id), network).await));
            } else {
                return Err(Box::new(r.err().unwrap()));
            }
        } else {
            let e = r.unwrap_err();
            if matches!(e.kind(), ErrorKind::NotFound) {
                let _ = fs::remove_dir_all(self.live.join(id.to_string())).await;
                return Ok(None);
            } else {
                return Err(Box::new(e));
            }
        }
    }

    async fn save_network(&self, obj: &Network) -> Result<(), Self::Error> {
        let _ = fs::create_dir_all(self.base.join(obj.id.to_string())).await;
        let _ = fs::create_dir_all(self.live.join(obj.id.to_string())).await;
        let base_network_path = network_path(&self.base, obj.id);
        if !fs::metadata(&base_network_path).await.is_ok() {
            fs::write(base_network_path, to_json_pretty(obj).as_bytes()).await?;
        }
        fs::write(network_path(&self.live, obj.id), serde_json::to_vec(obj)?.as_slice()).await?;
        Ok(())
    }

    async fn list_members(&self, network_id: NetworkId) -> Result<Vec<Address>, Self::Error> {
        let mut members = Vec::new();
        let mut dir = fs::read_dir(self.base.join(network_id.to_string())).await?;
        while let Ok(Some(ent)) = dir.next_entry().await {
            let osname = ent.file_name();
            let name = osname.to_string_lossy();
            if name.len() == (zerotier_network_hypervisor::protocol::ADDRESS_SIZE_STRING + 6)
                && name.starts_with("m")
                && name.ends_with(".json")
            {
                if let Ok(member_address) = u64::from_str_radix(&name[1..11], 16) {
                    if let Some(member_address) = Address::from_u64(member_address) {
                        members.push(member_address);
                    }
                }
            }
        }
        Ok(members)
    }

    async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result<Option<Member>, Self::Error> {
        let r = fs::read(member_path(&self.base, network_id, node_id)).await;
        if let Ok(raw) = r {
            let r = serde_json::from_slice::<Member>(raw.as_slice());
            if let Ok(member) = r {
                return Ok(Some(
                    self.merge_with_live(member_path(&self.live, network_id, node_id), member).await,
                ));
            } else {
                return Err(Box::new(r.err().unwrap()));
            }
        } else {
            let e = r.unwrap_err();
            if matches!(e.kind(), ErrorKind::NotFound) {
                let _ = fs::remove_file(member_path(&self.live, network_id, node_id)).await;
                return Ok(None);
            } else {
                return Err(Box::new(e));
            }
        }
    }

    async fn save_member(&self, obj: &Member) -> Result<(), Self::Error> {
        let base_member_path = member_path(&self.base, obj.network_id, obj.node_id);
        if !fs::metadata(&base_member_path).await.is_ok() {
            fs::write(base_member_path, to_json_pretty(obj).as_bytes()).await?;
        }
        fs::write(
            member_path(&self.live, obj.network_id, obj.node_id),
            serde_json::to_vec(obj)?.as_slice(),
        )
        .await?;
        Ok(())
    }
}
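
For orientation, the on-disk layout implied by network_path(), member_path(), and the "live" subdirectory created in FileDatabase::new() looks roughly like this (network ID and member address are placeholders):

    <base>/<network-id>/config.json                  canonical network record (pretty JSON, written only if not already present)
    <base>/<network-id>/m<member-address>.json       canonical member record
    <base>/live/<network-id>/config.json             live-editable copy, rewritten on save and consulted by merge_with_live() on read
    <base>/live/<network-id>/m<member-address>.json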

@@ -2,4 +2,5 @@
 pub mod controller;
 pub mod database;
+pub mod filedatabase;
 pub mod model;

@@ -1,11 +1,11 @@
 // (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md.
 
+use std::collections::{HashMap, HashSet};
 use std::hash::Hash;
 
 use serde::{Deserialize, Serialize};
 
-use zerotier_network_hypervisor::vl1::Address;
-use zerotier_network_hypervisor::vl1::InetAddress;
+use zerotier_network_hypervisor::vl1::{Address, Endpoint, Identity, InetAddress};
 use zerotier_network_hypervisor::vl2::NetworkId;
 
 #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -46,6 +46,12 @@ pub struct IpAssignmentPool {
     ip_range_end: InetAddress,
 }
 
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct Tag {
+    pub id: u32,
+    pub value: u32,
+}
+
 #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct Network {
     pub id: NetworkId,
@@ -85,43 +91,97 @@ impl Hash for Network {
 #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct Member {
+    #[serde(rename = "address")]
+    pub node_id: Address,
     #[serde(rename = "networkId")]
     pub network_id: NetworkId,
-    pub address: Address,
+    pub identity: Option<Identity>,
     pub name: String,
-    pub description: String,
     #[serde(rename = "creationTime")]
     pub creation_time: i64,
-    #[serde(rename = "revision")]
-    pub last_modified_time: i64,
+    pub authorized: bool,
+    #[serde(rename = "lastAuthorizedTime")]
+    pub last_authorized_time: Option<i64>,
+    #[serde(rename = "lastDeauthorizedTime")]
+    pub last_deauthorized_time: Option<i64>,
     #[serde(rename = "ipAssignments")]
-    pub ip_assignments: Vec<InetAddress>,
+    pub ip_assignments: HashSet<InetAddress>,
     #[serde(rename = "noAutoAssignIps")]
     pub no_auto_assign_ips: bool,
-    #[serde(rename = "vMajor")]
-    pub version_major: u16,
-    #[serde(rename = "vMinor")]
-    pub version_minor: u16,
-    #[serde(rename = "vRev")]
-    pub version_revision: u16,
-    #[serde(rename = "vProto")]
-    pub version_protocol: u16,
-    pub authorized: bool,
+    /// If true this member is a full Ethernet bridge.
     #[serde(rename = "activeBridge")]
     pub bridge: bool,
+    pub tags: Vec<Tag>,
     #[serde(rename = "ssoExempt")]
     pub sso_exempt: bool,
+    /// If true this node is explicitly listed in every member's network configuration.
+    #[serde(rename = "advertised")]
+    pub advertised: bool,
+    /// Most recently generated and signed network configuration for this member in binary format.
+    #[serde(rename = "networkConfig")]
+    pub network_config: Option<Vec<u8>>,
+    /// API object type documentation field, not actually edited/used.
     #[serde(default = "ObjectType::member")]
     pub objtype: ObjectType,
 }
 
-impl Hash for Member {
-    #[inline(always)]
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.network_id.hash(state);
-        self.address.hash(state);
-    }
-}
+/// A complete network with all member configuration information for import/export or blob storage.
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct NetworkExport {
+    pub network: Network,
+    pub members: HashMap<Address, Member>,
+}
+
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[repr(u8)]
+pub enum AuthorizationResult {
+    #[serde(rename = "r")]
+    Rejected = 0,
+    #[serde(rename = "rs")]
+    RejectedViaSSO = 1,
+    #[serde(rename = "rt")]
+    RejectedViaToken = 2,
+    #[serde(rename = "ro")]
+    RejectedTooOld = 3,
+    #[serde(rename = "a")]
+    Approved = 16,
+    #[serde(rename = "as")]
+    ApprovedViaSSO = 17,
+    #[serde(rename = "at")]
+    ApprovedViaToken = 18,
+}
+
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct QueryLogItem {
+    #[serde(rename = "nwid")]
+    pub network_id: NetworkId,
+    #[serde(rename = "nid")]
+    pub node_id: Address,
+    #[serde(rename = "cid")]
+    pub controller_node_id: Address,
+    #[serde(rename = "md")]
+    pub metadata: Vec<u8>,
+    #[serde(rename = "ts")]
+    pub timestamp: i64,
+    #[serde(rename = "v")]
+    pub version: (u16, u16, u16, u16),
+    #[serde(rename = "s")]
+    pub source_remote_endpoint: Endpoint,
+    #[serde(rename = "sh")]
+    pub source_hops: u8,
+    #[serde(rename = "r")]
+    pub result: AuthorizationResult,
+}

@@ -142,8 +142,7 @@ impl Dictionary {
     /// Write a dictionary in transport format to a byte vector.
     pub fn to_bytes(&self) -> Vec<u8> {
-        let mut b: Vec<u8> = Vec::new();
-        b.reserve(32 * self.0.len());
+        let mut b: Vec<u8> = Vec::with_capacity(32 * self.0.len());
         let _ = self.write_to(&mut b);
         b
     }

@@ -12,9 +12,8 @@ path = "src/main.rs"
 [dependencies]
 zerotier-network-hypervisor = { path = "../network-hypervisor" }
 zerotier-crypto = { path = "../crypto" }
-zerotier-utils = { path = "../utils" }
+zerotier-utils = { path = "../utils", features = ["tokio"] }
 zerotier-vl1-service = { path = "../vl1-service" }
-#tokio = { version = "^1", features = ["fs", "io-util", "io-std", "net", "parking_lot", "process", "rt", "rt-multi-thread", "signal", "sync", "time"], default-features = false }
 serde = { version = "^1", features = ["derive"], default-features = false }
 serde_json = { version = "^1", features = ["std"], default-features = false }
 parking_lot = { version = "^0", features = [], default-features = false }

@@ -208,30 +208,35 @@ fn main() {
         Some(("leave", cmd_args)) => todo!(),
         Some(("service", _)) => {
             drop(global_args); // free unnecessary heap before starting service as we're done with CLI args
-            let test_inner = Arc::new(zerotier_network_hypervisor::vl1::DummyInnerProtocol::default());
-            let test_path_filter = Arc::new(zerotier_network_hypervisor::vl1::DummyPathFilter::default());
-            let datadir = open_datadir(&flags);
-            let svc = VL1Service::new(datadir, test_inner, test_path_filter, zerotier_vl1_service::VL1Settings::default());
-            if svc.is_ok() {
-                let svc = svc.unwrap();
-                svc.node().init_default_roots();
-
-                #[cfg(unix)]
-                {
-                    let term = Arc::new(AtomicBool::new(false));
-                    let _ = signal_hook::flag::register(libc::SIGINT, term.clone());
-                    let _ = signal_hook::flag::register(libc::SIGTERM, term.clone());
-                    let _ = signal_hook::flag::register(libc::SIGQUIT, term.clone());
-                    while !term.load(Ordering::Relaxed) {
-                        std::thread::sleep(Duration::from_secs(1));
-                    }
-                }
-
-                println!("Terminate signal received, shutting down...");
-                exitcode::OK
+            if let Ok(_tokio_runtime) = zerotier_utils::tokio::runtime::Builder::new_multi_thread().enable_all().build() {
+                let test_inner = Arc::new(zerotier_network_hypervisor::vl1::DummyInnerProtocol::default());
+                let test_path_filter = Arc::new(zerotier_network_hypervisor::vl1::DummyPathFilter::default());
+                let datadir = open_datadir(&flags);
+                let svc = VL1Service::new(datadir, test_inner, test_path_filter, zerotier_vl1_service::VL1Settings::default());
+                if svc.is_ok() {
+                    let svc = svc.unwrap();
+                    svc.node().init_default_roots();
+
+                    // Wait for kill signal on Unix-like platforms.
+                    #[cfg(unix)]
+                    {
+                        let term = Arc::new(AtomicBool::new(false));
+                        let _ = signal_hook::flag::register(libc::SIGINT, term.clone());
+                        let _ = signal_hook::flag::register(libc::SIGTERM, term.clone());
+                        let _ = signal_hook::flag::register(libc::SIGQUIT, term.clone());
+                        while !term.load(Ordering::Relaxed) {
+                            std::thread::sleep(Duration::from_secs(1));
+                        }
+                    }
+
+                    println!("Terminate signal received, shutting down...");
+                    exitcode::OK
+                } else {
+                    eprintln!("FATAL: error launching service: {}", svc.err().unwrap().to_string());
+                    exitcode::ERR_IOERR
+                }
             } else {
-                println!("FATAL: error launching service: {}", svc.err().unwrap().to_string());
+                eprintln!("FATAL: error launching service: can't start async runtime");
                 exitcode::ERR_IOERR
             }
         }

@@ -10,18 +10,45 @@ use std::ptr::{null, null_mut};
 use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
 use std::sync::Arc;
 
-#[cfg(unix)]
-use std::os::unix::io::RawFd;
-
 use crate::localinterface::LocalInterface;
 
 #[allow(unused_imports)]
 use num_traits::AsPrimitive;
 
+use zerotier_network_hypervisor::protocol::{PacketBufferPool, PooledPacketBuffer};
 use zerotier_network_hypervisor::vl1::inetaddress::*;
+use zerotier_utils::ms_monotonic;
 
 use crate::sys::{getifaddrs, ipv6};
 
+fn socket_read_concurrency() -> usize {
+    const MAX_PER_SOCKET_CONCURRENCY: usize = 8;
+    static mut THREADS_PER_SOCKET: usize = 0;
+    unsafe {
+        let mut t = THREADS_PER_SOCKET;
+        if t == 0 {
+            t = std::thread::available_parallelism()
+                .unwrap()
+                .get()
+                .max(1)
+                .min(MAX_PER_SOCKET_CONCURRENCY);
+            THREADS_PER_SOCKET = t;
+        }
+        t
+    }
+}
+
+pub trait UdpPacketHandler: Send + Sync + 'static {
+    fn incoming_udp_packet(
+        self: &Arc<Self>,
+        time_ticks: i64,
+        socket: &Arc<BoundUdpSocket>,
+        source_address: &InetAddress,
+        packet: PooledPacketBuffer,
+    );
+}
+
 /// A local port to which one or more UDP sockets is bound.
 ///
 /// To bind a port we must bind sockets to each interface/IP pair directly. Sockets must
@@ -36,7 +63,7 @@ pub struct BoundUdpSocket {
     pub address: InetAddress,
     pub interface: LocalInterface,
     last_receive_time: AtomicI64,
-    fd: RawFd,
+    fd: i32,
     lock: parking_lot::RwLock<()>,
     open: AtomicBool,
 }
@@ -97,32 +124,6 @@
         return false;
     }
 
-    /// Receive a packet or return None if this UDP socket is being closed.
-    #[cfg(unix)]
-    pub fn blocking_receive<B: AsMut<[u8]>>(&self, mut buffer: B, current_time: i64) -> Option<(usize, InetAddress)> {
-        unsafe {
-            let _hold = self.lock.read();
-            let b = buffer.as_mut();
-            let mut from = InetAddress::new();
-            while self.open.load(Ordering::Relaxed) {
-                let mut addrlen = std::mem::size_of::<InetAddress>().as_();
-                let s = libc::recvfrom(
-                    self.fd.as_(),
-                    b.as_mut_ptr().cast(),
-                    b.len().as_(),
-                    0,
-                    (&mut from as *mut InetAddress).cast(),
-                    &mut addrlen,
-                );
-                if s > 0 {
-                    self.last_receive_time.store(current_time, Ordering::Relaxed);
-                    return Some((s as usize, from));
-                }
-            }
-            return None;
-        }
-    }
-
     #[cfg(unix)]
     fn close(&self) {
         unsafe {
@@ -170,11 +171,13 @@ impl BoundUdpPort {
     /// The caller can check the 'sockets' member variable after calling to determine which if any bindings were
     /// successful. Any errors that occurred are returned as tuples of (interface, address, error). The second vector
     /// returned contains newly bound sockets.
-    pub fn update_bindings(
+    pub fn update_bindings<UdpPacketHandlerImpl: UdpPacketHandler>(
         &mut self,
         interface_prefix_blacklist: &Vec<String>,
         cidr_blacklist: &Vec<InetAddress>,
-    ) -> (Vec<(LocalInterface, InetAddress, std::io::Error)>, Vec<Arc<BoundUdpSocket>>) {
+        buffer_pool: &Arc<PacketBufferPool>,
+        handler: &Arc<UdpPacketHandlerImpl>,
+    ) -> Vec<(LocalInterface, InetAddress, std::io::Error)> {
         let mut existing_bindings: HashMap<LocalInterface, HashMap<InetAddress, Arc<BoundUdpSocket>>> = HashMap::with_capacity(4);
         for s in self.sockets.drain(..) {
             existing_bindings
@@ -184,7 +187,6 @@
         }
 
         let mut errors: Vec<(LocalInterface, InetAddress, std::io::Error)> = Vec::new();
-        let mut new_sockets = Vec::new();
 
         getifaddrs::for_each_address(|address, interface| {
             let interface_str = interface.to_string();
             let mut addr_with_port = address.clone();
@@ -220,8 +222,36 @@
                         lock: parking_lot::RwLock::new(()),
                         open: AtomicBool::new(true),
                     });
-                    self.sockets.push(s.clone());
-                    new_sockets.push(s);
+
+                    for _ in 0..socket_read_concurrency() {
+                        let ss = s.clone();
+                        let bp = buffer_pool.clone();
+                        let h = handler.clone();
+                        std::thread::spawn(move || unsafe {
+                            let _hold = ss.lock.read();
+                            let mut from = InetAddress::new();
+                            while ss.open.load(Ordering::Relaxed) {
+                                let mut b = bp.get();
+                                let mut addrlen = std::mem::size_of::<InetAddress>().as_();
+                                let s = libc::recvfrom(
+                                    ss.fd.as_(),
+                                    b.entire_buffer_mut().as_mut_ptr().cast(),
+                                    b.capacity().as_(),
+                                    0,
+                                    (&mut from as *mut InetAddress).cast(),
+                                    &mut addrlen,
+                                );
+                                if s > 0 {
+                                    b.set_size_unchecked(s as usize);
+                                    let time_ticks = ms_monotonic();
+                                    ss.last_receive_time.store(time_ticks, Ordering::Relaxed);
+                                    h.incoming_udp_packet(time_ticks, &ss, &from, b);
+                                }
+                            }
+                        });
+                    }
+
+                    self.sockets.push(s);
                 } else {
                     errors.push((
                         interface.clone(),
@@ -246,7 +276,7 @@
             }
         }
 
-        (errors, new_sockets)
+        errors
     }
 }
@@ -273,7 +303,7 @@ pub fn udp_test_bind(port: u16) -> bool {
 #[allow(unused_variables)]
 #[cfg(unix)]
-unsafe fn bind_udp_to_device(device_name: &str, address: &InetAddress) -> Result<RawFd, &'static str> {
+unsafe fn bind_udp_to_device(device_name: &str, address: &InetAddress) -> Result<i32, &'static str> {
     let (af, sa_len) = match address.family() {
         AF_INET => (AF_INET, std::mem::size_of::<libc::sockaddr_in>().as_()),
         AF_INET6 => (AF_INET6, std::mem::size_of::<libc::sockaddr_in6>().as_()),
@@ -429,5 +459,5 @@ unsafe fn bind_udp_to_device(device_name: &str, address: &InetAddress) -> Result
         return Err("bind to address failed");
     }
 
-    Ok(s as RawFd)
+    Ok(s as i32)
 }

@@ -13,7 +13,7 @@ use zerotier_utils::{ms_monotonic, ms_since_epoch};
 use crate::constants::UNASSIGNED_PRIVILEGED_PORTS;
 use crate::settings::VL1Settings;
-use crate::sys::udp::{udp_test_bind, BoundUdpPort};
+use crate::sys::udp::{udp_test_bind, BoundUdpPort, UdpPacketHandler};
 use crate::LocalSocket;
 
 /// This can be adjusted to trade thread count for maximum I/O concurrency.
@@ -37,7 +37,7 @@ pub struct VL1Service<
     storage: Arc<NodeStorageImpl>,
     inner: Arc<InnerProtocolImpl>,
     path_filter: Arc<PathFilterImpl>,
-    buffer_pool: PacketBufferPool,
+    buffer_pool: Arc<PacketBufferPool>,
     node_container: Option<Node<Self>>,
 }
@@ -67,10 +67,10 @@ impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'stati
             storage,
             inner,
             path_filter,
-            buffer_pool: PacketBufferPool::new(
+            buffer_pool: Arc::new(PacketBufferPool::new(
                 std::thread::available_parallelism().map_or(2, |c| c.get() + 2),
                 PacketBufferFactory::new(),
-            ),
+            )),
             node_container: None,
         };
         service.node_container.replace(Node::new(&service, &*service.storage, true, false)?);
@@ -162,43 +162,20 @@ impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'stati
             state
         };
 
-        let per_socket_concurrency = std::thread::available_parallelism()
-            .map_or(1, |c| c.get())
-            .min(MAX_PER_SOCKET_CONCURRENCY);
         for (_, binding) in state.udp_sockets.iter() {
             let mut binding = binding.write();
-            let (_, mut new_sockets) = binding.update_bindings(&state.settings.interface_prefix_blacklist, &state.settings.cidr_blacklist);
-            for s in new_sockets.drain(..) {
-                for _ in 0..per_socket_concurrency {
-                    let self_copy = self.clone();
-                    let s_copy = s.clone();
-                    std::thread::spawn(move || loop {
-                        let local_socket = LocalSocket::new(&s_copy);
-                        loop {
-                            let mut buf = self_copy.buffer_pool.get();
-                            let now = ms_monotonic();
-                            if let Some((bytes, from)) = s_copy.blocking_receive(unsafe { buf.entire_buffer_mut() }, now) {
-                                unsafe { buf.set_size_unchecked(bytes) };
-                                self_copy.node().handle_incoming_physical_packet(
-                                    &*self_copy,
-                                    &*self_copy.inner,
-                                    &Endpoint::IpUdp(from),
-                                    &local_socket,
-                                    &s_copy.interface,
-                                    buf,
-                                );
-                            } else {
-                                break;
-                            }
-                        }
-                    });
-                }
-            }
+            let _ = binding.update_bindings(
+                &state.settings.interface_prefix_blacklist,
+                &state.settings.cidr_blacklist,
+                &self.buffer_pool,
+                self,
+            );
+            // TODO: if no bindings were successful do something with errors
         }
     }
 
     fn background_task_daemon(self: Arc<Self>) {
-        std::thread::sleep(Duration::from_secs(1));
+        std::thread::sleep(Duration::from_millis(500));
         let mut udp_binding_check_every: usize = 0;
         loop {
             if !self.state.read().running {
@@ -213,6 +190,28 @@ impl<NodeStorageImpl: NodeStorage + 'static, PathFilterImpl: PathFilter + 'stati
     }
 }
 
+impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl: InnerProtocol> UdpPacketHandler
+    for VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
+{
+    #[inline(always)]
+    fn incoming_udp_packet(
+        self: &Arc<Self>,
+        _time_ticks: i64,
+        socket: &Arc<crate::sys::udp::BoundUdpSocket>,
+        source_address: &InetAddress,
+        packet: zerotier_network_hypervisor::protocol::PooledPacketBuffer,
+    ) {
+        self.node().handle_incoming_physical_packet(
+            &*self,
+            &*self.inner,
+            &Endpoint::IpUdp(source_address.clone()),
+            &LocalSocket::new(socket),
+            &socket.interface,
+            packet,
+        );
+    }
+}
+
 impl<NodeStorageImpl: NodeStorage, PathFilterImpl: PathFilter, InnerProtocolImpl: InnerProtocol> HostSystem
     for VL1Service<NodeStorageImpl, PathFilterImpl, InnerProtocolImpl>
 {