From a24be3dbe5c5971209d31b74bd96bd974eb69065 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Wed, 29 Mar 2023 12:37:27 -0400 Subject: [PATCH] Rewrite fileDB to be simpler in controller. Simplify controller. ZOMG IT BUILDS AND TESTS PASS. --- {controller/src => attic}/cache.rs | 91 ++++- controller/Cargo.toml | 1 - controller/src/controller.rs | 95 +++-- controller/src/database.rs | 44 ++- controller/src/filedatabase.rs | 503 ++++++++------------------ controller/src/lib.rs | 2 - controller/src/main.rs | 65 +++- controller/src/model/member.rs | 24 +- controller/src/model/mod.rs | 9 - network-hypervisor/src/vl1/address.rs | 173 +++++++-- network-hypervisor/src/vl1/api.rs | 11 + network-hypervisor/src/vl1/node.rs | 2 +- network-hypervisor/src/vl1/peermap.rs | 7 +- service/src/vl1/vl1service.rs | 9 +- 14 files changed, 544 insertions(+), 492 deletions(-) rename {controller/src => attic}/cache.rs (50%) diff --git a/controller/src/cache.rs b/attic/cache.rs similarity index 50% rename from controller/src/cache.rs rename to attic/cache.rs index 107fc9426..420ac622f 100644 --- a/controller/src/cache.rs +++ b/attic/cache.rs @@ -1,14 +1,15 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md. -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::error::Error; use std::mem::replace; +use std::ops::Bound; use std::sync::{Mutex, RwLock}; use crate::database::Database; use crate::model::{Member, Network}; -use zerotier_network_hypervisor::vl1::Address; +use zerotier_network_hypervisor::vl1::{Address, PartialAddress}; use zerotier_network_hypervisor::vl2::NetworkId; /// Network and member cache used by database implementations to implement change detection. @@ -18,7 +19,7 @@ use zerotier_network_hypervisor::vl2::NetworkId; /// okay but calls out of order will result in extra updated events being generated for /// movements forward and backward in time. Calls must be temporally ordered. pub struct Cache { - by_nwid: RwLock>)>>, + by_nwid: RwLock>)>>, } impl Cache { @@ -34,12 +35,14 @@ impl Cache { let networks = db.list_networks().await?; for network_id in networks { if let Some(network) = db.get_network(&network_id).await? { - let network_entry = by_nwid.entry(network_id.clone()).or_insert_with(|| (network, Mutex::new(HashMap::new()))); + let network_entry = by_nwid + .entry(network_id.clone()) + .or_insert_with(|| (network, Mutex::new(BTreeMap::new()))); let mut by_node_id = network_entry.1.lock().unwrap(); let members = db.list_members(&network_id).await?; for node_id in members { if let Some(member) = db.get_member(&network_id, &node_id).await? { - let _ = by_node_id.insert(node_id.clone(), member); + let _ = by_node_id.insert(node_id, member); } } } @@ -49,6 +52,7 @@ impl Cache { } /// Update a network if changed, returning whether or not any update was made and the old version if any. + /// /// A value of (true, None) indicates that there was no network by that ID in which case it is added. pub fn on_network_updated(&self, network: Network) -> (bool, Option) { let mut by_nwid = self.by_nwid.write().unwrap(); @@ -59,31 +63,80 @@ impl Cache { (false, None) } } else { - let _ = by_nwid.insert(network.id.clone(), (network.clone(), Mutex::new(HashMap::new()))); + assert!(by_nwid + .insert(network.id.clone(), (network.clone(), Mutex::new(BTreeMap::new()))) + .is_none()); (true, None) } } /// Update a member if changed, returning whether or not any update was made and the old version if any. - /// A value of (true, None) indicates that there was no member with that ID. If there is no network with - /// the member's network ID (false, None) is returned and no action is taken. + /// + /// A value of (true, None) indicates that there was no member with that ID and that it was added. If + /// there is no network with the member's network ID (false, None) is returned and no action is taken. pub fn on_member_updated(&self, member: Member) -> (bool, Option) { let by_nwid = self.by_nwid.read().unwrap(); if let Some(network) = by_nwid.get(&member.network_id) { let mut by_node_id = network.1.lock().unwrap(); - if let Some(prev_member) = by_node_id.get_mut(&member.node_id) { - if !member.eq(prev_member) { - (true, Some(replace(prev_member, member))) - } else { - (false, None) + if let Some(exact_address_match) = by_node_id.get_mut(&member.node_id) { + if !member.eq(exact_address_match) { + return (true, Some(std::mem::replace(exact_address_match, member))); } } else { - let _ = by_node_id.insert(member.node_id.clone(), member); - (true, None) + let mut partial_address_match = None; + for m in by_node_id.range_mut::, Bound<&PartialAddress>)>(( + Bound::Included(&member.node_id), + Bound::Unbounded, + )) { + if m.0.matches_partial(&member.node_id) { + if partial_address_match.is_some() { + return (false, None); + } + let _ = partial_address_match.insert(m.1); + } else { + break; + } + } + + if let Some(partial_address_match) = partial_address_match { + if !member.eq(partial_address_match) { + return (true, Some(std::mem::replace(partial_address_match, member))); + } else { + return (false, None); + } + } + + let mut partial_address_match = None; + for m in by_node_id + .range_mut::, Bound<&PartialAddress>)>(( + Bound::Unbounded, + Bound::Included(&member.node_id), + )) + .rev() + { + if m.0.matches_partial(&member.node_id) { + if partial_address_match.is_some() { + return (false, None); + } + let _ = partial_address_match.insert(m.1); + } else { + break; + } + } + + if let Some(partial_address_match) = partial_address_match { + if !member.eq(partial_address_match) { + return (true, Some(std::mem::replace(partial_address_match, member))); + } else { + return (false, None); + } + } + + assert!(by_node_id.insert(member.node_id.clone(), member).is_none()); + return (true, None); } - } else { - (false, None) } + return (false, None); } /// Delete a network, returning it if it existed. @@ -91,7 +144,7 @@ impl Cache { let mut by_nwid = self.by_nwid.write().unwrap(); let network = by_nwid.remove(&network_id)?; let mut members = network.1.lock().unwrap(); - Some((network.0, members.drain().map(|(_, v)| v).collect())) + Some((network.0, members.values().cloned().collect())) } /// Delete a member, returning it if it existed. @@ -99,6 +152,6 @@ impl Cache { let by_nwid = self.by_nwid.read().unwrap(); let network = by_nwid.get(&network_id)?; let mut members = network.1.lock().unwrap(); - members.remove(&node_id) + members.remove(&node_id.to_partial()) } } diff --git a/controller/Cargo.toml b/controller/Cargo.toml index a143d2bcc..5c6873288 100644 --- a/controller/Cargo.toml +++ b/controller/Cargo.toml @@ -15,7 +15,6 @@ zerotier-service = { path = "../service" } async-trait = "^0" serde = { version = "^1", features = ["derive"], default-features = false } serde_json = { version = "^1", features = ["std"], default-features = false } -serde_yaml = "^0" clap = { version = "^3", features = ["std", "suggestions"], default-features = false } notify = { version = "^5", features = ["macos_fsevent"], default-features = false } tokio-postgres = "^0" diff --git a/controller/src/controller.rs b/controller/src/controller.rs index a6736c342..02f2b5e81 100644 --- a/controller/src/controller.rs +++ b/controller/src/controller.rs @@ -15,12 +15,12 @@ use zerotier_network_hypervisor::vl2::multicastauthority::MulticastAuthority; use zerotier_network_hypervisor::vl2::v1::networkconfig::*; use zerotier_network_hypervisor::vl2::v1::Revocation; use zerotier_network_hypervisor::vl2::NetworkId; +use zerotier_service::vl1::VL1Service; use zerotier_utils::buffer::OutOfBoundsError; use zerotier_utils::cast::cast_ref; use zerotier_utils::reaper::Reaper; use zerotier_utils::tokio; use zerotier_utils::{ms_monotonic, ms_since_epoch}; -use zerotier_vl1_service::VL1Service; use crate::database::*; use crate::model::{AuthenticationResult, Member, RequestLogItem, CREDENTIAL_WINDOW_SIZE_DEFAULT}; @@ -53,7 +53,7 @@ impl Controller { local_identity: IdentitySecret, database: Arc, ) -> Result, Box> { - let c = Arc::new_cyclic(|self_ref| Self { + Ok(Arc::new_cyclic(|self_ref| Self { self_ref: self_ref.clone(), reaper: Reaper::new(&runtime), runtime, @@ -62,16 +62,22 @@ impl Controller { multicast_authority: MulticastAuthority::new(), daemons: Mutex::new(Vec::with_capacity(2)), recently_authorized: RwLock::new(HashMap::new()), - }); + })) + } - if let Some(cw) = c.database.changes().await.map(|mut ch| { - let self2 = c.self_ref.clone(); - c.runtime.spawn(async move { + /// Start this controller's background tasks. + /// + /// Note that the controller only holds a Weak> to avoid circular references. + pub async fn start(&self, app: &Arc>) { + if let Some(cw) = self.database.changes().await.map(|mut ch| { + let controller_weak = self.self_ref.clone(); + let app_weak = Arc::downgrade(app); + self.runtime.spawn(async move { loop { if let Ok(change) = ch.recv().await { - if let Some(self2) = self2.upgrade() { - self2.reaper.add( - self2.runtime.spawn(self2.clone().handle_change_notification(change)), + if let (Some(controller), Some(app)) = (controller_weak.upgrade(), app_weak.upgrade()) { + controller.reaper.add( + controller.runtime.spawn(controller.clone().handle_change_notification(app, change)), Instant::now().checked_add(REQUEST_TIMEOUT).unwrap(), ); } else { @@ -81,19 +87,19 @@ impl Controller { } }) }) { - c.daemons.lock().unwrap().push(cw); + self.daemons.lock().unwrap().push(cw); } - let self2 = c.self_ref.clone(); - c.daemons.lock().unwrap().push(c.runtime.spawn(async move { + let controller_weak = self.self_ref.clone(); + self.daemons.lock().unwrap().push(self.runtime.spawn(async move { let sleep_duration = Duration::from_millis((protocol::VL2_DEFAULT_MULTICAST_LIKE_EXPIRE / 2).min(2500) as u64); loop { tokio::time::sleep(sleep_duration).await; - if let Some(self2) = self2.upgrade() { + if let Some(controller) = controller_weak.upgrade() { let time_ticks = ms_monotonic(); - self2.multicast_authority.clean(time_ticks); - self2.recently_authorized.write().unwrap().retain(|_, by_network| { + controller.multicast_authority.clean(time_ticks); + controller.recently_authorized.write().unwrap().retain(|_, by_network| { by_network.retain(|_, timeout| *timeout > time_ticks); !by_network.is_empty() }); @@ -102,12 +108,10 @@ impl Controller { } } })); - - Ok(c) } /// Launched as a task when the DB informs us of a change. - async fn handle_change_notification(self: Arc, change: Change) { + async fn handle_change_notification(self: Arc, app: Arc>, change: Change) { match change { Change::NetworkCreated(_) => {} Change::NetworkChanged(_, _) => {} @@ -115,10 +119,10 @@ impl Controller { Change::MemberCreated(_) => {} Change::MemberChanged(old_member, new_member) => { if !new_member.authorized() && old_member.authorized() { - self.deauthorize_member(&new_member).await; + self.deauthorize_member(&app, &new_member).await; } } - Change::MemberDeleted(member) => self.deauthorize_member(&member).await, + Change::MemberDeleted(member) => self.deauthorize_member(&app, &member).await, } } @@ -201,15 +205,16 @@ impl Controller { let time_clock = ms_since_epoch(); let mut revocations = Vec::with_capacity(1); if let Ok(all_network_members) = self.database.list_members(&member.network_id).await { - for m in all_network_members.iter() { - if member.node_id != *m { - if let Some(peer) = app.node.peer(m) { + for other_member in all_network_members.iter() { + if member.node_id != *other_member && member.node_id.is_complete() && other_member.is_complete() { + let node_id = member.node_id.as_complete().unwrap(); + if let Some(peer) = app.node.peer(node_id) { revocations.clear(); revocations.push(Revocation::new( &member.network_id, time_clock, - &member.node_id, - m, + node_id, + other_member.as_complete().unwrap(), &self.local_identity, false, )); @@ -239,7 +244,7 @@ impl Controller { } let network = network.unwrap(); - let mut member = self.database.get_member(&network_id, &source_identity.address).await?; + let mut member = self.database.get_member(&network_id, &source_identity.address.to_partial()).await?; let mut member_changed = false; let mut authentication_result = AuthenticationResult::Rejected; @@ -253,7 +258,7 @@ impl Controller { if !member_authorized { if member.is_none() { if network.learn_members.unwrap_or(true) { - let _ = member.insert(Member::new(source_identity.address.clone(), network_id.clone())); + let _ = member.insert(Member::new(source_identity.clone(), network_id.clone())); member_changed = true; } else { return Ok((AuthenticationResult::Rejected, None)); @@ -289,6 +294,16 @@ impl Controller { let member_authorized = member_authorized; let authentication_result = authentication_result; + // Pin full address and full identity if these aren't pinned already. + if !member.node_id.is_complete() { + member.node_id = source_identity.address.to_partial(); + member_changed = true; + } + if member.identity.is_none() { + let _ = member.identity.insert(source_identity.clone().remove_typestate()); + member_changed = true; + } + // Generate network configuration if the member is authorized. let network_config = if authentication_result.approved() { // We should not be able to make it here if this is still false. @@ -325,8 +340,7 @@ impl Controller { nc.rules.reserve(deauthed_members_still_in_window.len() + 1); let mut or = false; for dead in deauthed_members_still_in_window.iter() { - nc.rules - .push(vl2::rule::Rule::match_source_zerotier_address(false, or, dead.to_partial())); + nc.rules.push(vl2::rule::Rule::match_source_zerotier_address(false, or, dead.clone())); or = true; } nc.rules.push(vl2::rule::Rule::action_drop()); @@ -455,21 +469,26 @@ impl InnerProtocolLayer for Controller { }; // Launch handler as an async background task. - let app: &VL1Service = cast_ref(app).unwrap(); - let app = app.get(); - let (self2, source, source_remote_endpoint) = (self.self_ref.upgrade().unwrap(), source.clone(), source_path.endpoint.clone()); + let app = app.concrete_self::>().unwrap().get_self_arc(); // can't be a dead pointer since we're in a handler being called by it + let (controller, source, source_remote_endpoint) = (self.self_ref.upgrade().unwrap(), source.clone(), source_path.endpoint.clone()); self.reaper.add( self.runtime.spawn(async move { let node_id = source.identity.address.clone(); let now = ms_since_epoch(); - let (result, config) = match self2.authorize(&source.identity, &network_id, now).await { + let result = match controller.authorize(&source.identity, &network_id, now).await { Result::Ok((result, Some(config))) => { //println!("{}", serde_yaml::to_string(&config).unwrap()); - self2.send_network_config(app.as_ref(), &app.node, cast_ref(source.as_ref()).unwrap(), &config, Some(message_id)); - (result, Some(config)) + controller.send_network_config( + app.as_ref(), + &app.node, + cast_ref(source.as_ref()).unwrap(), + &config, + Some(message_id), + ); + result } - Result::Ok((result, None)) => (result, None), + Result::Ok((result, None)) => result, Result::Err(e) => { #[cfg(debug_assertions)] debug_event!(app, "[vl2] ERROR getting network config: {}", e.to_string()); @@ -477,12 +496,12 @@ impl InnerProtocolLayer for Controller { } }; - let _ = self2 + let _ = controller .database .log_request(RequestLogItem { network_id, node_id, - controller_node_id: self2.local_identity.public.address.clone(), + controller_node_id: controller.local_identity.public.address.clone(), metadata, peer_version: source.version(), peer_protocol_version: source.protocol_version(), diff --git a/controller/src/database.rs b/controller/src/database.rs index 3e9a7a848..1e39d5882 100644 --- a/controller/src/database.rs +++ b/controller/src/database.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use zerotier_crypto::secure_eq; -use zerotier_network_hypervisor::vl1::{Address, InetAddress}; +use zerotier_network_hypervisor::vl1::{InetAddress, PartialAddress}; use zerotier_network_hypervisor::vl2::NetworkId; use zerotier_utils::tokio::sync::broadcast::Receiver; @@ -22,22 +22,50 @@ pub enum Change { #[async_trait] pub trait Database: Sync + Send + 'static { + /// List networks on this controller. async fn list_networks(&self) -> Result, Error>; + + /// Get a network by network ID. async fn get_network(&self, id: &NetworkId) -> Result, Error>; + + /// Save a network. + /// + /// Note that unlike members the network ID is not automatically promoted from legacy to full + /// ID format. async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Error>; - async fn list_members(&self, network_id: &NetworkId) -> Result, Error>; - async fn get_member(&self, network_id: &NetworkId, node_id: &Address) -> Result, Error>; + /// List members of a network. + async fn list_members(&self, network_id: &NetworkId) -> Result, Error>; + + /// Get a member of network. + /// + /// If node_id is not a complete address, the best unique match should be returned. None should + /// be returned not only if the member is not found but if node_id is ambiguous (would match more + /// than one member). + async fn get_member(&self, network_id: &NetworkId, node_id: &PartialAddress) -> Result, Error>; + + /// Save a modified member to a network. + /// + /// Note that member modifications can include the automatic replacement of a less specific address + /// in node_id with a fully specific address. This happens the first time a member added with an + /// incomplete address is actually seen. In that case the implementation must correctly find the + /// best matching existing member and replace it with a member identified by the fully specified + /// address, removing and re-adding if needed. + /// + /// This must also handle the (rare) case when someone may try to save a member with a less + /// specific address than the one currently in the database. In that case the "old" more specific + /// address should replace the less specific address in the node_id field. This can only happen if + /// an external user manually does this. The controller won't do this automatically. async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Error>; + /// Save a log entry for a request this controller has handled. + async fn log_request(&self, obj: RequestLogItem) -> Result<(), Error>; + /// Get a receiver that can be used to receive changes made to networks and members, if supported. /// /// The receiver returned is a broadcast receiver. This can be called more than once if there are /// multiple parts of the controller that listen. /// - /// Changes should NOT be broadcast on call to save_network() or save_member(). They should only - /// be broadcast when externally generated changes occur. - /// /// The default implementation returns None indicating that change following is not supported. /// Change following is required for instant deauthorization with revocations and other instant /// changes in response to modifications to network and member configuration. @@ -49,7 +77,7 @@ pub trait Database: Sync + Send + 'static { /// /// The default trait implementation uses a brute force method. This should be reimplemented if a /// more efficient way is available. - async fn list_members_deauthorized_after(&self, network_id: &NetworkId, cutoff: i64) -> Result, Error> { + async fn list_members_deauthorized_after(&self, network_id: &NetworkId, cutoff: i64) -> Result, Error> { let mut v = Vec::new(); let members = self.list_members(network_id).await?; for a in members.iter() { @@ -77,6 +105,4 @@ pub trait Database: Sync + Send + 'static { } return Ok(false); } - - async fn log_request(&self, obj: RequestLogItem) -> Result<(), Error>; } diff --git a/controller/src/filedatabase.rs b/controller/src/filedatabase.rs index 73c369d68..7d528f46c 100644 --- a/controller/src/filedatabase.rs +++ b/controller/src/filedatabase.rs @@ -1,404 +1,201 @@ +use std::collections::BTreeMap; +use std::mem::replace; use std::path::{Path, PathBuf}; -use std::str::FromStr; -use std::sync::{Arc, Mutex, Weak}; +use std::sync::{Arc, Weak}; + +use serde::{Deserialize, Serialize}; use async_trait::async_trait; -use notify::{RecursiveMode, Watcher}; -use serde::de::DeserializeOwned; +use zerotier_utils::tokio::io::AsyncWriteExt; -use zerotier_network_hypervisor::vl1::Address; +use crate::database; +use crate::database::Change; +use crate::model::{Member, Network, RequestLogItem}; + +use zerotier_network_hypervisor::vl1::PartialAddress; use zerotier_network_hypervisor::vl2::NetworkId; -use zerotier_utils::reaper::Reaper; -use zerotier_utils::tokio::fs; -use zerotier_utils::tokio::runtime::Handle; -use zerotier_utils::tokio::sync::broadcast::{channel, Receiver, Sender}; -use zerotier_utils::tokio::task::JoinHandle; -use zerotier_utils::tokio::time::{sleep, Duration, Instant}; +use zerotier_utils::tokio; +use zerotier_utils::tokio::sync::{broadcast, mpsc}; -use crate::cache::Cache; -use crate::database::{Change, Database, Error}; -use crate::model::*; - -const EVENT_HANDLER_TASK_TIMEOUT: Duration = Duration::from_secs(10); - -/// An in-filesystem database that permits live editing. -/// -/// A cache is maintained that contains the actual objects. When an object is live edited, -/// once it successfully reads and loads it is merged with the cached object and saved to -/// the cache. The cache will also contain any ephemeral data, generated data, etc. -/// -/// The file format is YAML instead of JSON for better human friendliness and the layout -/// is different from V1 so it'll need a converter to use with V1 FileDb controller data. pub struct FileDatabase { - base_path: PathBuf, - controller_address: Address, - change_sender: Sender, - tasks: Reaper, - cache: Cache, - daemon: JoinHandle<()>, + db_path: PathBuf, + log: Option>, + data: tokio::sync::Mutex<(BTreeMap, bool)>, + change_sender: broadcast::Sender, + file_write_notify_sender: mpsc::Sender<()>, + file_writer: tokio::task::JoinHandle<()>, } -// TODO: should cache at least hashes and detect changes in the filesystem live. +#[derive(Serialize, Deserialize)] +struct FileDbNetwork { + pub config: Network, + pub members: BTreeMap, +} impl FileDatabase { - pub async fn new>(runtime: Handle, base_path: P, controller_address: Address) -> Result, Error> { - let base_path: PathBuf = base_path.as_ref().into(); + pub async fn new(db_path: &Path, log_path: Option<&Path>) -> Result, Box> { + let data_bytes = tokio::fs::read(db_path).await; + let mut data: BTreeMap = BTreeMap::new(); + if let Err(e) = data_bytes { + if !matches!(e.kind(), tokio::io::ErrorKind::NotFound) { + return Err(Box::new(e)); + } + } else { + data = serde_json::from_slice(data_bytes.as_ref().unwrap().as_slice())?; + } - let (change_sender, _) = channel(256); - let db_weak_tmp: Arc>> = Arc::new(Mutex::new(Weak::default())); - let db_weak = db_weak_tmp.clone(); - let runtime2 = runtime.clone(); + let log = if let Some(log_path) = log_path { + Some(tokio::sync::Mutex::new( + tokio::fs::OpenOptions::new().append(true).create(true).mode(0o600).open(log_path).await?, + )) + } else { + None + }; - let db = Arc::new(Self { - base_path: base_path.clone(), - controller_address: controller_address.clone(), - change_sender, - tasks: Reaper::new(&runtime2), - cache: Cache::new(), - daemon: runtime2.spawn(async move { - let mut watcher = notify::recommended_watcher(move |event: notify::Result| { - if let Ok(event) = event { - match event.kind { - notify::EventKind::Create(_) | notify::EventKind::Modify(_) | notify::EventKind::Remove(_) => { - if let Some(db) = db_weak.lock().unwrap().upgrade() { - let controller_address2 = controller_address.clone(); - db.clone().tasks.add( - runtime.spawn(async move { - if let Some(path0) = event.paths.first() { - if let Some((record_type, network_id, node_id)) = - Self::record_type_from_path(controller_address2, path0.as_path()) - { - // Paths to objects that were deleted or changed. Changed includes adding new objects. - let mut deleted = None; - let mut changed = None; - - match event.kind { - notify::EventKind::Create(create_kind) => match create_kind { - notify::event::CreateKind::File => { - changed = Some(path0.as_path()); - } - _ => {} - }, - notify::EventKind::Modify(modify_kind) => match modify_kind { - notify::event::ModifyKind::Data(_) => { - changed = Some(path0.as_path()); - } - notify::event::ModifyKind::Name(rename_mode) => match rename_mode { - notify::event::RenameMode::Both => { - if event.paths.len() >= 2 { - if let Some(path1) = event.paths.last() { - deleted = Some(path0.as_path()); - changed = Some(path1.as_path()); - } - } - } - notify::event::RenameMode::From => { - deleted = Some(path0.as_path()); - } - notify::event::RenameMode::To => { - changed = Some(path0.as_path()); - } - _ => {} - }, - _ => {} - }, - notify::EventKind::Remove(remove_kind) => match remove_kind { - notify::event::RemoveKind::File => { - deleted = Some(path0.as_path()); - } - _ => {} - }, - _ => {} - } - - if deleted.is_some() { - match record_type { - RecordType::Network => { - if let Some((network, members)) = db.cache.on_network_deleted(network_id) { - let _ = db.change_sender.send(Change::NetworkDeleted(network, members)); - } - } - RecordType::Member => { - if let Some(node_id) = node_id { - if let Some(member) = db.cache.on_member_deleted(network_id, node_id) { - let _ = db.change_sender.send(Change::MemberDeleted(member)); - } - } - } - _ => {} - } - } - - if let Some(changed) = changed { - match record_type { - RecordType::Network => { - if let Ok(Some(new_network)) = Self::load_object::(changed).await { - match db.cache.on_network_updated(new_network.clone()) { - (true, Some(old_network)) => { - let _ = db - .change_sender - .send(Change::NetworkChanged(old_network, new_network)); - } - (true, None) => { - let _ = db.change_sender.send(Change::NetworkCreated(new_network)); - } - _ => {} - } - } - } - RecordType::Member => { - if let Ok(Some(new_member)) = Self::load_object::(changed).await { - match db.cache.on_member_updated(new_member.clone()) { - (true, Some(old_member)) => { - let _ = - db.change_sender.send(Change::MemberChanged(old_member, new_member)); - } - (true, None) => { - let _ = db.change_sender.send(Change::MemberCreated(new_member)); - } - _ => {} - } - } - } - _ => {} - } - } - } - } - }), - Instant::now().checked_add(EVENT_HANDLER_TASK_TIMEOUT).unwrap(), + let (file_write_notify_sender, mut file_write_notify_receiver) = mpsc::channel(16); + let db = Arc::new_cyclic(|self_weak: &Weak| { + let self_weak = self_weak.clone(); + Self { + db_path: db_path.to_path_buf(), + log, + data: tokio::sync::Mutex::new((data, false)), + change_sender: broadcast::channel(16).0, + file_write_notify_sender, + file_writer: tokio::task::spawn(async move { + loop { + file_write_notify_receiver.recv().await; + if let Some(db) = self_weak.upgrade() { + let mut data = db.data.lock().await; + if data.1 { + let json = zerotier_utils::json::to_json_pretty(&data.0); + if let Err(e) = tokio::fs::write(db.db_path.as_path(), json.as_bytes()).await { + eprintln!( + "WARNING: controller changes not persisted! unable to write file database to '{}': {}", + db.db_path.to_string_lossy(), + e.to_string() ); + } else { + data.1 = false; } } - _ => {} + } else { + break; } } - }) - .expect("FATAL: unable to start filesystem change listener"); - let _ = watcher.configure( - notify::Config::default() - .with_compare_contents(true) - .with_poll_interval(std::time::Duration::from_secs(2)), - ); - watcher - .watch(&base_path, RecursiveMode::Recursive) - .expect("FATAL: unable to watch base path"); - - loop { - // Any periodic background stuff can be put here. Adjust timing as needed. - sleep(Duration::from_secs(10)).await; - } - }), + }), + } }); - db.cache.load_all(db.as_ref()).await?; - *db_weak_tmp.lock().unwrap() = Arc::downgrade(&db); // this starts the daemon tasks and starts watching for file changes - Ok(db) } - - fn network_path(&self, network_id: &NetworkId) -> PathBuf { - self.base_path.join(format!("N{:06x}", network_id.network_no())).join("config.yaml") - } - - fn member_path(&self, network_id: &NetworkId, member_id: &Address) -> PathBuf { - self.base_path - .join(format!("N{:06x}", network_id.network_no())) - .join(format!("M{}.yaml", member_id.to_string())) - } - - async fn load_object(path: &Path) -> Result, Error> { - if let Ok(raw) = fs::read(path).await { - return Ok(Some(serde_yaml::from_slice::(raw.as_slice())?)); - } else { - return Ok(None); - } - } - - /// Get record type and also the number after it: network number or address. - fn record_type_from_path(controller_address: Address, p: &Path) -> Option<(RecordType, NetworkId, Option
)> { - let parent = p.parent()?.file_name()?.to_string_lossy(); - if parent.len() == 7 && (parent.starts_with("N") || parent.starts_with('n')) { - let network_id = NetworkId::Full(controller_address, u32::from_str_radix(&parent[1..], 16).ok()?); - if let Some(file_name) = p.file_name().map(|p| p.to_string_lossy().to_lowercase()) { - if file_name.eq("config.yaml") { - return Some((RecordType::Network, network_id, None)); - } else if file_name.len() == 16 && file_name.starts_with("m") && file_name.ends_with(".yaml") { - return Some(( - RecordType::Member, - network_id, - Some(Address::from_str(&file_name.as_str()[1..file_name.len() - 5]).ok()?), - )); - } - } - } - return None; - } } impl Drop for FileDatabase { fn drop(&mut self) { - self.daemon.abort(); + self.file_writer.abort(); } } #[async_trait] -impl Database for FileDatabase { - async fn list_networks(&self) -> Result, Error> { - let mut networks = Vec::new(); - let mut dir = fs::read_dir(&self.base_path).await?; - while let Ok(Some(ent)) = dir.next_entry().await { - if ent.file_type().await.map_or(false, |t| t.is_dir()) { - let osname = ent.file_name(); - let name = osname.to_string_lossy(); - if name.len() == 7 && name.starts_with("N") { - if fs::metadata(ent.path().join("config.yaml")).await.is_ok() { - if let Ok(network_no) = u32::from_str_radix(&name[1..], 16) { - networks.push(NetworkId::Full(self.controller_address.clone(), network_no)); - } - } +impl database::Database for FileDatabase { + async fn list_networks(&self) -> Result, database::Error> { + Ok(self.data.lock().await.0.keys().cloned().collect()) + } + + async fn get_network(&self, id: &NetworkId) -> Result, database::Error> { + Ok(self.data.lock().await.0.get(id).map(|x| x.config.clone())) + } + + async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), database::Error> { + let mut data = self.data.lock().await; + if let Some(nw) = data.0.get_mut(&obj.id) { + if !nw.config.eq(&obj) { + let old = replace(&mut nw.config, obj); + if generate_change_notification { + let _ = self.change_sender.send(Change::NetworkChanged(old, nw.config.clone())); } + let _ = self.file_write_notify_sender.send(()).await; } - } - Ok(networks) - } - - async fn get_network(&self, id: &NetworkId) -> Result, Error> { - let mut network = Self::load_object::(self.network_path(id).as_path()).await?; - if let Some(network) = network.as_mut() { - // FileDatabase stores networks by their "network number" and automatically adapts their IDs - // if the controller's identity changes. This is done to make it easy to just clone networks, - // including storing them in "git." - let network_id_should_be = NetworkId::Full(self.controller_address.clone(), network.id.network_no()); - if network.id != network_id_should_be { - network.id = network_id_should_be; - let _ = self.save_network(network.clone(), false).await?; + } else { + data.0 + .insert(obj.id.clone(), FileDbNetwork { config: obj.clone(), members: BTreeMap::new() }); + if generate_change_notification { + let _ = self.change_sender.send(Change::NetworkCreated(obj)); } + let _ = self.file_write_notify_sender.send(()).await; } - Ok(network) - } - - async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Error> { - if !generate_change_notification { - let _ = self.cache.on_network_updated(obj.clone()); - } - let base_network_path = self.network_path(&obj.id); - let _ = fs::create_dir_all(base_network_path.parent().unwrap()).await; - let _ = fs::write(base_network_path, serde_yaml::to_string(&obj)?.as_bytes()).await?; return Ok(()); } - async fn list_members(&self, network_id: &NetworkId) -> Result, Error> { - let mut members = Vec::new(); - let mut dir = fs::read_dir(self.base_path.join(format!("N{:06x}", network_id.network_no()))).await?; - while let Ok(Some(ent)) = dir.next_entry().await { - if ent.file_type().await.map_or(false, |t| t.is_file() || t.is_symlink()) { - let osname = ent.file_name(); - let name = osname.to_string_lossy(); - if name.starts_with("M") && name.ends_with(".yaml") { - if let Ok(member_address) = Address::from_str(&name[1..name.len() - 5]) { - members.push(member_address); - } - } - } - } - Ok(members) + async fn list_members(&self, network_id: &NetworkId) -> Result, database::Error> { + Ok(self + .data + .lock() + .await + .0 + .get(network_id) + .map_or_else(|| Vec::new(), |x| x.members.keys().cloned().collect())) } - async fn get_member(&self, network_id: &NetworkId, node_id: &Address) -> Result, Error> { - let mut member = Self::load_object::(self.member_path(&network_id, node_id).as_path()).await?; - if let Some(member) = member.as_mut() { - if member.network_id.eq(network_id) { - // Also auto-update member network IDs, see get_network(). - member.network_id = network_id.clone(); - self.save_member(member.clone(), false).await?; - } - } - Ok(member) + async fn get_member(&self, network_id: &NetworkId, node_id: &PartialAddress) -> Result, database::Error> { + Ok(self + .data + .lock() + .await + .0 + .get_mut(network_id) + .and_then(|x| node_id.find_unique_match(&x.members).cloned())) } - async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Error> { - if !generate_change_notification { - let _ = self.cache.on_member_updated(obj.clone()); - } - let base_member_path = self.member_path(&obj.network_id, &obj.node_id); - let _ = fs::create_dir_all(base_member_path.parent().unwrap()).await; - let _ = fs::write(base_member_path, serde_yaml::to_string(&obj)?.as_bytes()).await?; - Ok(()) - } + async fn save_member(&self, mut obj: Member, generate_change_notification: bool) -> Result<(), database::Error> { + let mut data = self.data.lock().await; + if let Some(nw) = data.0.get_mut(&obj.network_id) { + if let Some(member) = obj.node_id.find_unique_match_mut(&mut nw.members) { + if !obj.eq(member) { + if member.node_id.specificity_bytes() != obj.node_id.specificity_bytes() { + // If the specificity of the node_id has changed we have to delete and re-add the entry. - async fn changes(&self) -> Option> { - Some(self.change_sender.subscribe()) - } + let old_node_id = member.node_id.clone(); + let old = nw.members.remove(&old_node_id); - async fn log_request(&self, obj: RequestLogItem) -> Result<(), Error> { - println!("{}", obj.to_string()); - Ok(()) - } -} + if old_node_id.specificity_bytes() > obj.node_id.specificity_bytes() { + obj.node_id = old_node_id; + } -#[cfg(test)] -mod tests { - #[allow(unused_imports)] - use super::*; - use std::sync::atomic::{AtomicUsize, Ordering}; - use zerotier_network_hypervisor::vl1::identity::Identity; + nw.members.insert(obj.node_id.clone(), obj.clone()); - /* TODO - #[allow(unused)] - #[test] - fn test_db() { - if let Ok(tokio_runtime) = zerotier_utils::tokio::runtime::Builder::new_current_thread().enable_all().build() { - let _ = tokio_runtime.block_on(async { - let node_id = Address::from_u64(0xdeadbeefu64).unwrap(); - let network_id = NetworkId::from_u64(0xfeedbeefcafebabeu64).unwrap(); - - let test_dir = std::env::temp_dir().join("zt_filedatabase_test"); - println!("test filedatabase is in: {}", test_dir.as_os_str().to_str().unwrap()); - - let _ = std::fs::remove_dir_all(&test_dir); - let controller_id = Identity::generate(false); - - assert!(fs::create_dir_all(&test_dir).await.is_ok()); - let db = Arc::new( - FileDatabase::new(tokio_runtime.handle().clone(), test_dir, controller_id.public.address.clone()) - .await - .expect("new db"), - ); - - let change_count = Arc::new(AtomicUsize::new(0)); - - let db2 = db.clone(); - let change_count2 = change_count.clone(); - tokio_runtime.spawn(async move { - let mut change_receiver = db2.changes().await.unwrap(); - loop { - if let Ok(change) = change_receiver.recv().await { - change_count2.fetch_add(1, Ordering::SeqCst); - //println!("[FileDatabase] {:#?}", change); - } else { - break; + if generate_change_notification { + let _ = self.change_sender.send(Change::MemberChanged(old.unwrap(), obj)); + } + } else { + let old = replace(member, obj); + if generate_change_notification { + let _ = self.change_sender.send(Change::MemberChanged(old, member.clone())); } } - }); - - let mut test_network = Network::new(network_id); - db.save_network(test_network.clone(), true).await.expect("network save error"); - - let mut test_member = Member::new_without_identity(node_id, network_id); - for x in 0..3 { - test_member.name = x.to_string(); - db.save_member(test_member.clone(), true).await.expect("member save error"); - - zerotier_utils::tokio::task::yield_now().await; - sleep(Duration::from_millis(100)).await; - zerotier_utils::tokio::task::yield_now().await; - - let test_member2 = db.get_member(&network_id, &node_id).await.unwrap().unwrap(); - assert!(test_member == test_member2); + let _ = self.file_write_notify_sender.send(()).await; } - }); + } else { + let _ = nw.members.insert(obj.node_id.clone(), obj.clone()); + if generate_change_notification { + let _ = self.change_sender.send(Change::MemberCreated(obj)); + } + let _ = self.file_write_notify_sender.send(()).await; + } } + return Ok(()); + } + + async fn log_request(&self, obj: RequestLogItem) -> Result<(), database::Error> { + if let Some(log) = self.log.as_ref() { + let mut json_line = zerotier_utils::json::to_json(&obj); + json_line.push('\n'); + let _ = log.lock().await.write_all(json_line.as_bytes()).await; + } + Ok(()) + } + + async fn changes(&self) -> Option> { + Some(self.change_sender.subscribe()) } - */ } diff --git a/controller/src/lib.rs b/controller/src/lib.rs index 0d663136d..597e14d0e 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -2,8 +2,6 @@ mod controller; -pub(crate) mod cache; - pub mod database; pub mod filedatabase; pub mod model; diff --git a/controller/src/main.rs b/controller/src/main.rs index ea058fb6f..e80967639 100644 --- a/controller/src/main.rs +++ b/controller/src/main.rs @@ -1,5 +1,7 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md. +use std::path::Path; +use std::str::FromStr; use std::sync::Arc; use zerotier_network_controller::database::Database; @@ -7,23 +9,24 @@ use zerotier_network_controller::filedatabase::FileDatabase; use zerotier_network_controller::Controller; use zerotier_network_hypervisor::vl1::identity::IdentitySecret; use zerotier_network_hypervisor::{VERSION_MAJOR, VERSION_MINOR, VERSION_REVISION}; +use zerotier_service::vl1::{VL1Service, VL1Settings}; use zerotier_utils::exitcode; +use zerotier_utils::tokio; use zerotier_utils::tokio::runtime::Runtime; -use zerotier_vl1_service::VL1Service; -async fn run(identity: IdentitySecret, runtime: &Runtime) -> i32 { - match Controller::new(database.clone(), runtime.handle().clone()).await { +async fn run(database: Arc, identity: IdentitySecret, runtime: &Runtime) -> i32 { + match Controller::new(runtime.handle().clone(), identity.clone(), database.clone()).await { Err(err) => { eprintln!("FATAL: error initializing handler: {}", err.to_string()); exitcode::ERR_CONFIG } - Ok(handler) => match VL1Service::new(identity, handler.clone(), zerotier_vl1_service::VL1Settings::default()) { + Ok(handler) => match VL1Service::new(identity, handler.clone(), VL1Settings::default()) { Err(err) => { eprintln!("FATAL: error launching service: {}", err.to_string()); exitcode::ERR_IOERR } Ok(svc) => { - svc.node().init_default_roots(); + svc.node.init_default_roots(); handler.start(&svc).await; zerotier_utils::wait_for_process_abort(); println!("Terminate signal received, shutting down..."); @@ -36,13 +39,33 @@ async fn run(identity: IdentitySecret, runtime: &Runtime) -> i32 { fn main() { const REQUIRE_ONE_OF_ARGS: [&'static str; 2] = ["postgres", "filedb"]; let global_args = clap::Command::new("zerotier-controller") + .arg( + clap::Arg::new("identity") + .short('i') + .long("identity") + .takes_value(true) + .forbid_empty_values(true) + .value_name("identity") + .help(Some("Path to secret ZeroTier identity")) + .required(true), + ) + .arg( + clap::Arg::new("logfile") + .short('l') + .long("logfile") + .takes_value(true) + .forbid_empty_values(true) + .value_name("logfile") + .help(Some("Path to log file")) + .required(false), + ) .arg( clap::Arg::new("filedb") .short('f') .long("filedb") .takes_value(true) .forbid_empty_values(true) - .value_name("path") + .value_name("filedb") .help(Some("Use filesystem database at path")) .required_unless_present_any(&REQUIRE_ONE_OF_ARGS), ) @@ -52,8 +75,8 @@ fn main() { .long("postgres") .takes_value(true) .forbid_empty_values(true) - .value_name("path") - .help(Some("Connect to postgres with parameters in YAML file")) + .value_name("postgres") + .help(Some("Connect to postgres with supplied URL")) .required_unless_present_any(&REQUIRE_ONE_OF_ARGS), ) .version(format!("{}.{}.{}", VERSION_MAJOR, VERSION_MINOR, VERSION_REVISION).as_str()) @@ -64,23 +87,39 @@ fn main() { std::process::exit(exitcode::ERR_USAGE); }); - if let Ok(tokio_runtime) = zerotier_utils::tokio::runtime::Builder::new_multi_thread().enable_all().build() { + if let Ok(tokio_runtime) = tokio::runtime::Builder::new_multi_thread().enable_all().build() { tokio_runtime.block_on(async { - if let Some(filedb_base_path) = global_args.value_of("filedb") { - let file_db = FileDatabase::new(tokio_runtime.handle().clone(), filedb_base_path).await; + let identity = if let Ok(identity_data) = tokio::fs::read(global_args.value_of("identity").unwrap()).await { + if let Ok(identity) = IdentitySecret::from_str(String::from_utf8_lossy(identity_data.as_slice()).as_ref()) { + identity + } else { + eprintln!("FATAL: invalid secret identity"); + std::process::exit(exitcode::ERR_CONFIG); + } + } else { + eprintln!("FATAL: unable to read secret identity"); + std::process::exit(exitcode::ERR_IOERR); + }; + + let db: Arc = if let Some(filedb_path) = global_args.value_of("filedb") { + let file_db = FileDatabase::new(Path::new(filedb_path), global_args.value_of("logfile").map(|l| Path::new(l))).await; if file_db.is_err() { eprintln!( "FATAL: unable to open filesystem database at {}: {}", - filedb_base_path, + filedb_path, file_db.as_ref().err().unwrap().to_string() ); std::process::exit(exitcode::ERR_IOERR) } - std::process::exit(run(file_db.unwrap(), &tokio_runtime).await); + file_db.unwrap() + } else if let Some(_postgres_url) = global_args.value_of("postgres") { + panic!("not implemented yet"); } else { eprintln!("FATAL: no database type selected."); std::process::exit(exitcode::ERR_USAGE); }; + + std::process::exit(run(db, identity, &tokio_runtime).await); }); } else { eprintln!("FATAL: can't start async runtime"); diff --git a/controller/src/model/member.rs b/controller/src/model/member.rs index 5c3e4e755..b73d3e9ce 100644 --- a/controller/src/model/member.rs +++ b/controller/src/model/member.rs @@ -5,17 +5,28 @@ use std::hash::Hash; use serde::{Deserialize, Serialize}; -use zerotier_network_hypervisor::vl1::{Address, InetAddress}; +use zerotier_crypto::typestate::Valid; +use zerotier_network_hypervisor::vl1::identity::Identity; +use zerotier_network_hypervisor::vl1::{InetAddress, PartialAddress}; use zerotier_network_hypervisor::vl2::NetworkId; #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] pub struct Member { + /// Member node ID + /// + /// This can be a partial address if it was manually added as such by a user. As soon as a node matching + /// this partial is seen, this will be replaced by a full specificity PartialAddress from the querying + /// node's full identity. The 'identity' field will also be populated in this case. #[serde(rename = "address")] - pub node_id: Address, + pub node_id: PartialAddress, #[serde(rename = "networkId")] pub network_id: NetworkId, + /// Full identity of this node, if known. + #[serde(skip_serializing_if = "Option::is_none")] + pub identity: Option, + /// A short name that can also be used for DNS, etc. #[serde(skip_serializing_if = "String::is_empty")] #[serde(default)] @@ -68,11 +79,11 @@ pub struct Member { } impl Member { - /// Create a new network member without specifying a "pinned" identity. - pub fn new(node_id: Address, network_id: NetworkId) -> Self { + pub fn new(node_identity: Valid, network_id: NetworkId) -> Self { Self { - node_id, + node_id: node_identity.address.to_partial(), network_id, + identity: Some(node_identity.remove_typestate()), name: String::new(), last_authorized_time: None, last_deauthorized_time: None, @@ -85,7 +96,8 @@ impl Member { } } - /// Check whether this member is authorized, which is true if the last authorized time is after last deauthorized time. + /// Check whether this member is authorized. + /// This is true if the last authorized time is after last deauthorized time. pub fn authorized(&self) -> bool { self.last_authorized_time .map_or(false, |la| self.last_deauthorized_time.map_or(true, |ld| la > ld)) diff --git a/controller/src/model/mod.rs b/controller/src/model/mod.rs index fab148844..65cb65154 100644 --- a/controller/src/model/mod.rs +++ b/controller/src/model/mod.rs @@ -6,8 +6,6 @@ mod network; pub use member::*; pub use network::*; -use std::collections::HashMap; - use serde::{Deserialize, Serialize}; use zerotier_network_hypervisor::vl1::{Address, Endpoint}; @@ -20,13 +18,6 @@ pub enum RecordType { RequestLogItem, } -/// A complete network with all member configuration information for import/export or blob storage. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct NetworkExport { - pub network: Network, - pub members: HashMap, -} - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[repr(u8)] pub enum AuthenticationResult { diff --git a/network-hypervisor/src/vl1/address.rs b/network-hypervisor/src/vl1/address.rs index 4de542530..7aca0d4fb 100644 --- a/network-hypervisor/src/vl1/address.rs +++ b/network-hypervisor/src/vl1/address.rs @@ -1,8 +1,10 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md. use std::borrow::Borrow; +use std::collections::BTreeMap; use std::fmt::Debug; use std::hash::Hash; +use std::ops::Bound; use std::str::FromStr; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -24,8 +26,11 @@ pub struct Address(pub(super) [u8; Self::SIZE_BYTES]); /// A partial address, which is bytes and the number of bytes of specificity (similar to a CIDR IP address). /// /// Partial addresses are looked up to get full addresses (and identities) via roots using WHOIS messages. -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct PartialAddress(pub(super) Address, pub(super) u16); +#[derive(Clone, PartialEq, Eq)] +pub struct PartialAddress { + pub(super) address: Address, + pub(super) specificity: u16, +} impl Address { pub const SIZE_BYTES: usize = 48; @@ -60,21 +65,25 @@ impl Address { } /// Get a partial address object (with full specificity) for this address - #[inline(always)] + #[inline] pub fn to_partial(&self) -> PartialAddress { - PartialAddress(Address(self.0), Self::SIZE_BYTES as u16) + PartialAddress { + address: Address(self.0), + specificity: Self::SIZE_BYTES as u16, + } } /// Get a partial address covering the 40-bit legacy address. + #[inline] pub fn to_legacy_partial(&self) -> PartialAddress { - PartialAddress( - Address({ + PartialAddress { + address: Address({ let mut tmp = [0u8; PartialAddress::MAX_SIZE_BYTES]; tmp[..PartialAddress::LEGACY_SIZE_BYTES].copy_from_slice(&self.0[..PartialAddress::LEGACY_SIZE_BYTES]); tmp }), - PartialAddress::LEGACY_SIZE_BYTES as u16, - ) + specificity: PartialAddress::LEGACY_SIZE_BYTES as u16, + } } #[inline(always)] @@ -195,8 +204,11 @@ impl PartialAddress { && b[0] != Address::RESERVED_PREFIX && b[..Self::LEGACY_SIZE_BYTES].iter().any(|i| *i != 0) { - let mut a = Self(Address([0u8; Address::SIZE_BYTES]), b.len() as u16); - a.0 .0[..b.len()].copy_from_slice(b); + let mut a = Self { + address: Address([0u8; Address::SIZE_BYTES]), + specificity: b.len() as u16, + }; + a.address.0[..b.len()].copy_from_slice(b); Ok(a) } else { Err(InvalidParameterError("invalid address")) @@ -206,14 +218,14 @@ impl PartialAddress { #[inline] pub(crate) fn from_legacy_address_bytes(b: &[u8; 5]) -> Result { if b[0] != Address::RESERVED_PREFIX && b.iter().any(|i| *i != 0) { - Ok(Self( - Address({ + Ok(Self { + address: Address({ let mut tmp = [0u8; Self::MAX_SIZE_BYTES]; tmp[..5].copy_from_slice(b); tmp }), - Self::LEGACY_SIZE_BYTES as u16, - )) + specificity: Self::LEGACY_SIZE_BYTES as u16, + }) } else { Err(InvalidParameterError("invalid address")) } @@ -223,14 +235,14 @@ impl PartialAddress { pub(crate) fn from_legacy_address_u64(mut b: u64) -> Result { b &= 0xffffffffff; if b.wrapping_shr(32) != (Address::RESERVED_PREFIX as u64) && b != 0 { - Ok(Self( - Address({ + Ok(Self { + address: Address({ let mut tmp = [0u8; Self::MAX_SIZE_BYTES]; tmp[..5].copy_from_slice(&b.to_be_bytes()[..5]); tmp }), - Self::LEGACY_SIZE_BYTES as u16, - )) + specificity: Self::LEGACY_SIZE_BYTES as u16, + }) } else { Err(InvalidParameterError("invalid address")) } @@ -238,45 +250,60 @@ impl PartialAddress { #[inline(always)] pub fn as_bytes(&self) -> &[u8] { - debug_assert!(self.1 >= Self::MIN_SIZE_BYTES as u16); - &self.0 .0[..self.1 as usize] + debug_assert!(self.specificity >= Self::MIN_SIZE_BYTES as u16); + &self.address.0[..self.specificity as usize] } #[inline(always)] pub(crate) fn legacy_bytes(&self) -> &[u8; 5] { - debug_assert!(self.1 >= Self::MIN_SIZE_BYTES as u16); - memory::array_range::(&self.0 .0) + debug_assert!(self.specificity >= Self::MIN_SIZE_BYTES as u16); + memory::array_range::(&self.address.0) } #[inline(always)] pub(crate) fn legacy_u64(&self) -> u64 { - u64::from_be(memory::load_raw(&self.0 .0)).wrapping_shr(24) + u64::from_be(memory::load_raw(&self.address.0)).wrapping_shr(24) } + /// Returns true if this partial address matches a full length address up to this partial's specificity. #[inline(always)] - pub(super) fn matches(&self, k: &Address) -> bool { - debug_assert!(self.1 >= Self::MIN_SIZE_BYTES as u16); - let l = self.1 as usize; - self.0 .0[..l].eq(&k.0[..l]) + pub fn matches(&self, k: &Address) -> bool { + debug_assert!(self.specificity >= Self::MIN_SIZE_BYTES as u16); + let l = self.specificity as usize; + self.address.0[..l].eq(&k.0[..l]) + } + + /// Returns true if this partial address matches another up to the lower of the two addresses' specificities. + #[inline(always)] + pub fn matches_partial(&self, k: &PartialAddress) -> bool { + debug_assert!(self.specificity >= Self::MIN_SIZE_BYTES as u16); + let l = self.specificity.min(k.specificity) as usize; + self.address.0[..l].eq(&k.address.0[..l]) } /// Get the number of bits of specificity in this address #[inline(always)] - pub fn specificity(&self) -> usize { - (self.1 * 8) as usize + pub fn specificity_bits(&self) -> usize { + (self.specificity * 8) as usize + } + + /// Get the number of bytes of specificity in this address (only 8 bit increments in specificity are allowed) + #[inline(always)] + pub fn specificity_bytes(&self) -> usize { + self.specificity as usize } /// Returns true if this address has legacy 40 bit specificity (V1 ZeroTier address) #[inline(always)] pub fn is_legacy(&self) -> bool { - self.1 == Self::LEGACY_SIZE_BYTES as u16 + self.specificity == Self::LEGACY_SIZE_BYTES as u16 } - /// Get a full length address if this partial address is actually complete (384 bits of specificity) - #[inline(always)] - pub fn as_address(&self) -> Option<&Address> { - if self.1 == Self::MAX_SIZE_BYTES as u16 { - Some(&self.0) + /// Get a complete address from this partial if it is in fact complete. + #[inline] + pub fn as_complete(&self) -> Option<&Address> { + if self.specificity == Self::MAX_SIZE_BYTES as u16 { + Some(&self.address) } else { None } @@ -285,14 +312,84 @@ impl PartialAddress { /// Returns true if specificity is at the maximum value (384 bits) #[inline(always)] pub fn is_complete(&self) -> bool { - self.1 == Self::MAX_SIZE_BYTES as u16 + self.specificity == Self::MAX_SIZE_BYTES as u16 + } + + /// Efficiently find an entry in a BTreeMap of partial addresses that uniquely matches this partial. + /// + /// This returns None if there is no match or if this partial matches more than one entry, in which + /// case it's ambiguous and may be unsafe to use. This should be prohibited at other levels of the + /// system but is checked for here as well. + #[inline] + pub fn find_unique_match<'a, T>(&self, map: &'a BTreeMap) -> Option<&'a T> { + // Search for an exact or more specific match. + let mut m = None; + + // First search for exact or more specific matches, which would appear later in the sorted key list. + let mut pos = map.range((Bound::Included(self), Bound::Unbounded)); + while let Some(e) = pos.next() { + if self.matches_partial(e.0) { + if m.is_some() { + // Ambiguous! + return None; + } + let _ = m.insert(e.1); + } else { + break; + } + } + + // Then search for less specific matches or verify that the match we found above is not ambiguous. + let mut pos = map.range((Bound::Unbounded, Bound::Excluded(self))); + while let Some(e) = pos.next_back() { + if self.matches_partial(e.0) { + if m.is_some() { + return None; + } + let _ = m.insert(e.1); + } else { + break; + } + } + + return m; + } + + /// Efficiently find an entry in a BTreeMap of partial addresses that uniquely matches this partial. + /// + /// This returns None if there is no match or if this partial matches more than one entry, in which + /// case it's ambiguous and may be unsafe to use. This should be prohibited at other levels of the + /// system but is checked for here as well. + #[inline] + pub fn find_unique_match_mut<'a, T>(&self, map: &'a mut BTreeMap) -> Option<&'a mut T> { + // This not only saves some repetition but is in fact the only way to easily do this. The same code as + // find_unique_match() but with range_mut() doesn't compile because the second range_mut() would + // borrow 'map' a second time (since 'm' may have it borrowed). This is primarily due to the too-limited + // API of BTreeMap which is missing a good way to find the nearest match. This should be safe since + // we do not mutate the map and the signature of find_unique_match_mut() should properly guarantee + // that the semantics of mutable references are obeyed in the calling context. + unsafe { std::mem::transmute(self.find_unique_match::(map)) } + } +} + +impl Ord for PartialAddress { + #[inline(always)] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.address.cmp(&other.address).then(self.specificity.cmp(&other.specificity)) + } +} + +impl PartialOrd for PartialAddress { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } } impl ToString for PartialAddress { fn to_string(&self) -> String { if self.is_legacy() { - hex::to_string(&self.0 .0[..Self::LEGACY_SIZE_BYTES]) + hex::to_string(&self.address.0[..Self::LEGACY_SIZE_BYTES]) } else { base24::encode(self.as_bytes()) } @@ -315,7 +412,7 @@ impl Hash for PartialAddress { #[inline(always)] fn hash(&self, state: &mut H) { // Since this contains a random hash, the first 64 bits should be enough for a local HashMap etc. - state.write_u64(memory::load_raw(&self.0 .0)) + state.write_u64(memory::load_raw(&self.address.0)) } } diff --git a/network-hypervisor/src/vl1/api.rs b/network-hypervisor/src/vl1/api.rs index ceb364655..056cccaa4 100644 --- a/network-hypervisor/src/vl1/api.rs +++ b/network-hypervisor/src/vl1/api.rs @@ -97,6 +97,17 @@ pub trait ApplicationLayer: Sync + Send + 'static { /// Called to get the current time in milliseconds since epoch from the real-time clock. /// This needs to be accurate to about one second resolution or better. fn time_clock(&self) -> i64; + + /// Get this application implementation cast to its concrete type. + /// + /// The default implementation just returns None, but this can be implemented using the cast_ref() + /// function in zerotier_utils::cast to return the concrete implementation of this type. It's exposed + /// in this interface for convenience since it's common for inner protocol or other handlers to want + /// to get 'app' as its concrete type to access internal fields and methods. Implement it if possible + /// and convenient. + fn concrete_self(&self) -> Option<&T> { + None + } } /// Result of a packet handler in the InnerProtocolLayer trait. diff --git a/network-hypervisor/src/vl1/node.rs b/network-hypervisor/src/vl1/node.rs index a2b552208..d61d51f14 100644 --- a/network-hypervisor/src/vl1/node.rs +++ b/network-hypervisor/src/vl1/node.rs @@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex, RwLock}; use std::time::Duration; use super::address::{Address, PartialAddress}; -use super::api::{ApplicationLayer, InnerProtocolLayer, PacketHandlerResult}; +use super::api::{ApplicationLayer, InnerProtocolLayer}; use super::debug_event; use super::endpoint::Endpoint; use super::event::Event; diff --git a/network-hypervisor/src/vl1/peermap.rs b/network-hypervisor/src/vl1/peermap.rs index f6d11122e..b4094eacd 100644 --- a/network-hypervisor/src/vl1/peermap.rs +++ b/network-hypervisor/src/vl1/peermap.rs @@ -39,8 +39,11 @@ impl PeerMap { /// Get a matching peer for a partial address of any specificity, but return None if the match is ambiguous. pub fn get_unambiguous(&self, address: &PartialAddress) -> Option>> { - let mm = self.maps[address.0 .0[0] as usize].read().unwrap(); - let matches = mm.range::<[u8; 48], (Bound<&[u8; 48]>, Bound<&[u8; 48]>)>((Bound::Included(&address.0 .0), Bound::Unbounded)); + let mm = self.maps[address.address.0[0] as usize].read().unwrap(); + let matches = mm.range::<[u8; Address::SIZE_BYTES], (Bound<&[u8; Address::SIZE_BYTES]>, Bound<&[u8; Address::SIZE_BYTES]>)>(( + Bound::Included(&address.address.0), + Bound::Unbounded, + )); let mut r = None; for m in matches { if address.matches(m.0) { diff --git a/service/src/vl1/vl1service.rs b/service/src/vl1/vl1service.rs index 395985bf0..69963c8fe 100644 --- a/service/src/vl1/vl1service.rs +++ b/service/src/vl1/vl1service.rs @@ -10,6 +10,7 @@ use zerotier_crypto::random; use zerotier_network_hypervisor::protocol::{PacketBufferFactory, PacketBufferPool}; use zerotier_network_hypervisor::vl1::identity::IdentitySecret; use zerotier_network_hypervisor::vl1::*; +use zerotier_utils::cast::cast_ref; use zerotier_utils::{ms_monotonic, ms_since_epoch}; use super::vl1settings::{VL1Settings, UNASSIGNED_PRIVILEGED_PORTS}; @@ -67,7 +68,8 @@ impl VL1Service { Ok(service) } - pub fn get(&self) -> Arc { + #[inline] + pub fn get_self_arc(&self) -> Arc { self.self_ref.upgrade().unwrap() } @@ -281,6 +283,11 @@ impl ApplicationLayer for VL1Service fn time_clock(&self) -> i64 { ms_since_epoch() } + + #[inline(always)] + fn concrete_self(&self) -> Option<&T> { + cast_ref(self) + } } impl Drop for VL1Service {