diff --git a/controller/src/database.rs b/controller/src/database.rs index b12237800..89fa965a2 100644 --- a/controller/src/database.rs +++ b/controller/src/database.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use std::error::Error; use zerotier_network_hypervisor::vl1::{Address, InetAddress, NodeStorage}; use zerotier_network_hypervisor::vl2::NetworkId; @@ -17,14 +18,12 @@ pub enum Change { #[async_trait] pub trait Database: Sync + Send + NodeStorage + 'static { - type Error: std::error::Error + Send + 'static; + async fn get_network(&self, id: NetworkId) -> Result, Box>; + async fn save_network(&self, obj: Network) -> Result<(), Box>; - async fn get_network(&self, id: NetworkId) -> Result, Self::Error>; - async fn save_network(&self, obj: Network) -> Result<(), Self::Error>; - - async fn list_members(&self, network_id: NetworkId) -> Result, Self::Error>; - async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Self::Error>; - async fn save_member(&self, obj: Member) -> Result<(), Self::Error>; + async fn list_members(&self, network_id: NetworkId) -> Result, Box>; + async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Box>; + async fn save_member(&self, obj: Member) -> Result<(), Box>; /// Get a receiver that can be used to receive changes made to networks and members, if supported. /// @@ -45,7 +44,11 @@ pub trait Database: Sync + Send + NodeStorage + 'static { /// /// The default trait implementation uses a brute force method. This should be reimplemented if a /// more efficient way is available. - async fn list_members_deauthorized_after(&self, network_id: NetworkId, cutoff: i64) -> Result, Self::Error> { + async fn list_members_deauthorized_after( + &self, + network_id: NetworkId, + cutoff: i64, + ) -> Result, Box> { let mut v = Vec::new(); let members = self.list_members(network_id).await?; for a in members.iter() { @@ -62,7 +65,7 @@ pub trait Database: Sync + Send + NodeStorage + 'static { /// /// The default trait implementation uses a brute force method. This should be reimplemented if a /// more efficient way is available. - async fn is_ip_assigned(&self, network_id: NetworkId, ip: &InetAddress) -> Result { + async fn is_ip_assigned(&self, network_id: NetworkId, ip: &InetAddress) -> Result> { let members = self.list_members(network_id).await?; for a in members.iter() { if let Some(m) = self.get_member(network_id, *a).await? { @@ -74,5 +77,5 @@ pub trait Database: Sync + Send + NodeStorage + 'static { return Ok(false); } - async fn log_request(&self, obj: RequestLogItem) -> Result<(), Self::Error>; + async fn log_request(&self, obj: RequestLogItem) -> Result<(), Box>; } diff --git a/controller/src/filedatabase.rs b/controller/src/filedatabase.rs index fae0ea12c..8960e21b5 100644 --- a/controller/src/filedatabase.rs +++ b/controller/src/filedatabase.rs @@ -1,5 +1,4 @@ use std::error::Error; -use std::fmt::Display; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::atomic::{AtomicU64, Ordering}; @@ -21,35 +20,6 @@ use crate::model::*; const IDENTITY_SECRET_FILENAME: &'static str = "identity.secret"; -#[derive(Debug)] -pub enum FileDatabaseError { - InvalidYaml(String), - IoError(String), -} - -impl Display for FileDatabaseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::InvalidYaml(e) => f.write_str(format!("invalid YAML ({})", e).as_str()), - Self::IoError(e) => f.write_str(format!("I/O error ({})", e).as_str()), - } - } -} - -impl Error for FileDatabaseError {} - -impl From for FileDatabaseError { - fn from(e: serde_yaml::Error) -> Self { - Self::InvalidYaml(e.to_string()) - } -} - -impl From for FileDatabaseError { - fn from(e: zerotier_utils::tokio::io::Error) -> Self { - Self::IoError(e.to_string()) - } -} - /// An in-filesystem database that permits live editing. /// /// A cache is maintained that contains the actual objects. When an object is live edited, @@ -65,7 +35,7 @@ pub struct FileDatabase { // TODO: should cache at least hashes and detect changes in the filesystem live. impl FileDatabase { - pub async fn new>(base_path: P) -> Result> { + pub async fn new>(base_path: P) -> Result> { let base_path: PathBuf = base_path.as_ref().into(); let _ = fs::create_dir_all(&base_path).await?; @@ -86,6 +56,11 @@ impl FileDatabase { } } })?); + let _ = watcher.configure( + notify::Config::default() + .with_compare_contents(true) + .with_poll_interval(std::time::Duration::from_secs(2)), + ); watcher.watch(&base_path, RecursiveMode::Recursive)?; Ok(Self { @@ -151,9 +126,7 @@ impl NodeStorage for FileDatabase { #[async_trait] impl Database for FileDatabase { - type Error = FileDatabaseError; - - async fn get_network(&self, id: NetworkId) -> Result, Self::Error> { + async fn get_network(&self, id: NetworkId) -> Result, Box> { let r = fs::read(self.network_path(id)).await; if let Ok(raw) = r { let mut network = serde_yaml::from_slice::(raw.as_slice())?; @@ -166,7 +139,7 @@ impl Database for FileDatabase { } } - async fn save_network(&self, obj: Network) -> Result<(), Self::Error> { + async fn save_network(&self, obj: Network) -> Result<(), Box> { let base_network_path = self.network_path(obj.id); let _ = fs::create_dir_all(base_network_path.parent().unwrap()).await; //let _ = fs::write(base_network_path, to_json_pretty(&obj).as_bytes()).await?; @@ -174,7 +147,7 @@ impl Database for FileDatabase { return Ok(()); } - async fn list_members(&self, network_id: NetworkId) -> Result, Self::Error> { + async fn list_members(&self, network_id: NetworkId) -> Result, Box> { let mut members = Vec::new(); let mut dir = fs::read_dir(self.base_path.join(network_id.to_string())).await?; while let Ok(Some(ent)) = dir.next_entry().await { @@ -194,7 +167,7 @@ impl Database for FileDatabase { Ok(members) } - async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Self::Error> { + async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Box> { let r = fs::read(self.member_path(network_id, node_id)).await; if let Ok(raw) = r { let mut member = serde_yaml::from_slice::(raw.as_slice())?; @@ -207,7 +180,7 @@ impl Database for FileDatabase { } } - async fn save_member(&self, obj: Member) -> Result<(), Self::Error> { + async fn save_member(&self, obj: Member) -> Result<(), Box> { let base_member_path = self.member_path(obj.network_id, obj.node_id); let _ = fs::create_dir_all(base_member_path.parent().unwrap()).await; //let _ = fs::write(base_member_path, to_json_pretty(&obj).as_bytes()).await?; @@ -219,7 +192,7 @@ impl Database for FileDatabase { Some(self.change_sender.subscribe()) } - async fn log_request(&self, obj: RequestLogItem) -> Result<(), Self::Error> { + async fn log_request(&self, obj: RequestLogItem) -> Result<(), Box> { println!("{}", obj.to_string()); Ok(()) } diff --git a/controller/src/handler.rs b/controller/src/handler.rs index 3578e0f09..3f9c18a25 100644 --- a/controller/src/handler.rs +++ b/controller/src/handler.rs @@ -25,26 +25,26 @@ use crate::model::{AuthorizationResult, Member, RequestLogItem, CREDENTIAL_WINDO const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); /// ZeroTier VL2 network controller packet handler, answers VL2 netconf queries. -pub struct Handler { - inner: Arc>, +pub struct Handler { + inner: Arc, } -struct Inner { - service: RwLock, Handler>>>, +struct Inner { + service: RwLock>>, reaper: Reaper, daemons: Mutex>>, // drop() aborts these runtime: tokio::runtime::Handle, - database: Arc, + database: Arc, local_identity: Identity, } -impl Handler { +impl Handler { /// Start an inner protocol handler answer ZeroTier VL2 network controller queries. - pub async fn new(database: Arc, runtime: tokio::runtime::Handle) -> Result, Box> { + pub async fn new(database: Arc, runtime: tokio::runtime::Handle) -> Result, Box> { if let Some(local_identity) = database.load_node_identity() { assert!(local_identity.secret.is_some()); - let inner = Arc::new(Inner:: { + let inner = Arc::new(Inner { service: RwLock::new(Weak::default()), reaper: Reaper::new(&runtime), daemons: Mutex::new(Vec::with_capacity(1)), @@ -69,7 +69,7 @@ impl Handler { /// won't actually do anything. The reference the handler holds is weak to prevent /// a circular reference, so if the VL1Service is dropped this must be called again to /// tell the controller handler about a new instance. - pub fn set_service(&self, service: &Arc>) { + pub fn set_service(&self, service: &Arc>) { *self.inner.service.write().unwrap() = Arc::downgrade(service); } @@ -98,9 +98,9 @@ impl Handler { } // Default PathFilter implementations permit anything. -impl PathFilter for Handler {} +impl PathFilter for Handler {} -impl InnerProtocol for Handler { +impl InnerProtocol for Handler { fn handle_packet( &self, _host_system: &HostSystemImpl, @@ -245,7 +245,7 @@ impl InnerProtocol for Handler { } } -impl Inner { +impl Inner { fn send_network_config( &self, peer: &Peer, @@ -299,14 +299,14 @@ impl Inner { } } - async fn handle_change_notification(self: Arc, change: Change) {} + async fn handle_change_notification(self: Arc, _change: Change) {} async fn handle_network_config_request( self: &Arc, source_identity: &Identity, network_id: NetworkId, now: i64, - ) -> Result<(AuthorizationResult, Option), DatabaseImpl::Error> { + ) -> Result<(AuthorizationResult, Option), Box> { let network = self.database.get_network(network_id).await?; if network.is_none() { // TODO: send error @@ -424,7 +424,7 @@ impl Inner { } } -impl Drop for Inner { +impl Drop for Inner { fn drop(&mut self) { for h in self.daemons.lock().unwrap().drain(..) { h.abort(); diff --git a/controller/src/main.rs b/controller/src/main.rs index 937241214..713e645b8 100644 --- a/controller/src/main.rs +++ b/controller/src/main.rs @@ -15,7 +15,7 @@ use zerotier_utils::exitcode; use zerotier_utils::tokio::runtime::Runtime; use zerotier_vl1_service::VL1Service; -async fn run(database: Arc, runtime: &Runtime) -> i32 { +async fn run(database: Arc, runtime: &Runtime) -> i32 { let handler = Handler::new(database.clone(), runtime.handle().clone()).await; if handler.is_err() { eprintln!("FATAL: error initializing handler: {}", handler.err().unwrap().to_string()); diff --git a/controller/src/model/network.rs b/controller/src/model/network.rs index 9ef3e6cb1..d3a8a7abc 100644 --- a/controller/src/model/network.rs +++ b/controller/src/model/network.rs @@ -141,7 +141,7 @@ fn troo() -> bool { impl Network { /// Check member IP assignments and return 'true' if IP assignments were created or modified. - pub async fn check_zt_ip_assignments(&self, database: &DatabaseImpl, member: &mut Member) -> bool { + pub async fn check_zt_ip_assignments(&self, database: &DatabaseImpl, member: &mut Member) -> bool { let mut modified = false; if self.v4_assign_mode.zt { diff --git a/network-hypervisor/src/vl1/node.rs b/network-hypervisor/src/vl1/node.rs index 66970ea13..42d7a9af7 100644 --- a/network-hypervisor/src/vl1/node.rs +++ b/network-hypervisor/src/vl1/node.rs @@ -32,6 +32,12 @@ use zerotier_utils::thing::Thing; /// These methods are basically callbacks that the core calls to request or transmit things. They are called /// during calls to things like wire_recieve() and do_background_tasks(). pub trait HostSystem: Sync + Send + 'static { + /// Type for implementation of NodeStorage. + type Storage: NodeStorage + ?Sized; + + /// Path filter implementation for this host. + type PathFilter: PathFilter + ?Sized; + /// Type for local system sockets. type LocalSocket: Sync + Send + Hash + PartialEq + Eq + Clone + ToString + Sized + 'static; @@ -41,6 +47,12 @@ pub trait HostSystem: Sync + Send + 'static { /// A VL1 level event occurred. fn event(&self, event: Event); + /// Get a reference to the local storage implementation at this host. + fn storage(&self) -> &Self::Storage; + + /// Get the path filter implementation for this host. + fn path_filter(&self) -> &Self::PathFilter; + /// Get a pooled packet buffer for internal use. fn get_buffer(&self) -> PooledPacketBuffer; @@ -265,21 +277,20 @@ pub struct Node { } impl Node { - pub fn new( + pub fn new( host_system: &HostSystemImpl, - storage: &NodeStorageImpl, auto_generate_identity: bool, auto_upgrade_identity: bool, ) -> Result { let mut id = { - let id = storage.load_node_identity(); + let id = host_system.storage().load_node_identity(); if id.is_none() { if !auto_generate_identity { return Err(InvalidParameterError("no identity found and auto-generate not enabled")); } else { let id = Identity::generate(); host_system.event(Event::IdentityAutoGenerated(id.clone())); - storage.save_node_identity(&id); + host_system.storage().save_node_identity(&id); id } } else { @@ -290,7 +301,7 @@ impl Node { if auto_upgrade_identity { let old = id.clone(); if id.upgrade()? { - storage.save_node_identity(&id); + host_system.storage().save_node_identity(&id); host_system.event(Event::IdentityAutoUpgraded(old, id.clone())); } } diff --git a/network-hypervisor/src/vl1/peer.rs b/network-hypervisor/src/vl1/peer.rs index 1d1b4843a..cb622713e 100644 --- a/network-hypervisor/src/vl1/peer.rs +++ b/network-hypervisor/src/vl1/peer.rs @@ -138,6 +138,8 @@ impl Peer { fn learn_path(&self, host_system: &HostSystemImpl, new_path: &Arc, time_ticks: i64) { let mut paths = self.paths.lock().unwrap(); + // TODO: check path filter + match &new_path.endpoint { Endpoint::IpUdp(new_ip) => { // If this is an IpUdp endpoint, scan the existing paths and replace any that come from diff --git a/vl1-service/src/vl1service.rs b/vl1-service/src/vl1service.rs index 3766bfa32..18be9a497 100644 --- a/vl1-service/src/vl1service.rs +++ b/vl1-service/src/vl1service.rs @@ -27,9 +27,9 @@ const UPDATE_UDP_BINDINGS_EVERY_SECS: usize = 10; /// whatever inner protocol implementation is using it. This would typically be VL2 but could be /// a test harness or just the controller for a controller that runs stand-alone. pub struct VL1Service< - NodeStorageImpl: NodeStorage + 'static, - PathFilterImpl: PathFilter + 'static, - InnerProtocolImpl: InnerProtocol + 'static, + NodeStorageImpl: NodeStorage + ?Sized + 'static, + PathFilterImpl: PathFilter + ?Sized + 'static, + InnerProtocolImpl: InnerProtocol + ?Sized + 'static, > { state: RwLock, storage: Arc, @@ -46,8 +46,11 @@ struct VL1ServiceMutableState { running: bool, } -impl - VL1Service +impl< + NodeStorageImpl: NodeStorage + ?Sized + 'static, + PathFilterImpl: PathFilter + ?Sized + 'static, + InnerProtocolImpl: InnerProtocol + ?Sized + 'static, + > VL1Service { pub fn new( storage: Arc, @@ -55,7 +58,7 @@ impl, settings: VL1Settings, ) -> Result, Box> { - let mut service = VL1Service { + let mut service = Self { state: RwLock::new(VL1ServiceMutableState { daemons: Vec::with_capacity(2), udp_sockets: HashMap::with_capacity(8), @@ -72,7 +75,7 @@ impl UdpPacketHandler - for VL1Service +impl< + NodeStorageImpl: NodeStorage + ?Sized + 'static, + PathFilterImpl: PathFilter + ?Sized + 'static, + InnerProtocolImpl: InnerProtocol + ?Sized + 'static, + > UdpPacketHandler for VL1Service { #[inline(always)] fn incoming_udp_packet( @@ -209,9 +215,14 @@ impl HostSystem - for VL1Service +impl< + NodeStorageImpl: NodeStorage + ?Sized + 'static, + PathFilterImpl: PathFilter + ?Sized + 'static, + InnerProtocolImpl: InnerProtocol + ?Sized + 'static, + > HostSystem for VL1Service { + type Storage = NodeStorageImpl; + type PathFilter = PathFilterImpl; type LocalSocket = crate::LocalSocket; type LocalInterface = crate::LocalInterface; @@ -227,6 +238,16 @@ impl &Self::Storage { + self.storage.as_ref() + } + + #[inline(always)] + fn path_filter(&self) -> &Self::PathFilter { + self.path_filter.as_ref() + } + #[inline] fn get_buffer(&self) -> zerotier_network_hypervisor::protocol::PooledPacketBuffer { self.buffer_pool.get() @@ -306,8 +327,11 @@ impl Drop - for VL1Service +impl< + NodeStorageImpl: NodeStorage + ?Sized + 'static, + PathFilterImpl: PathFilter + ?Sized + 'static, + InnerProtocolImpl: InnerProtocol + ?Sized + 'static, + > Drop for VL1Service { fn drop(&mut self) { let mut state = self.state.write().unwrap();