diff --git a/controller/src/controller.rs b/controller/src/controller.rs index 821aad9e2..b899f69a3 100644 --- a/controller/src/controller.rs +++ b/controller/src/controller.rs @@ -6,7 +6,7 @@ use tokio::time::{Duration, Instant}; use zerotier_utils::tokio; use zerotier_network_hypervisor::protocol::{verbs, PacketBuffer}; -use zerotier_network_hypervisor::vl1::{HostSystem, Identity, InnerProtocol, PacketHandlerResult, Path, Peer}; +use zerotier_network_hypervisor::vl1::{HostSystem, Identity, InnerProtocol, PacketHandlerResult, Path, PathFilter, Peer}; use zerotier_network_hypervisor::vl2::NetworkId; use zerotier_utils::dictionary::Dictionary; @@ -40,6 +40,31 @@ impl Controller { } } +impl PathFilter for Controller { + fn check_path( + &self, + _id: &Identity, + _endpoint: &zerotier_network_hypervisor::vl1::Endpoint, + _local_socket: Option<&HostSystemImpl::LocalSocket>, + _local_interface: Option<&HostSystemImpl::LocalInterface>, + ) -> bool { + true + } + + fn get_path_hints( + &self, + _id: &Identity, + ) -> Option< + Vec<( + zerotier_network_hypervisor::vl1::Endpoint, + Option, + Option, + )>, + > { + None + } +} + impl InnerProtocol for Controller { fn handle_packet( &self, diff --git a/controller/src/database.rs b/controller/src/database.rs index 7a9a2ff4f..bbcde8652 100644 --- a/controller/src/database.rs +++ b/controller/src/database.rs @@ -1,20 +1,20 @@ +use std::error::Error; + use async_trait::async_trait; -use zerotier_network_hypervisor::vl1::Address; +use zerotier_network_hypervisor::vl1::{Address, NodeStorage}; use zerotier_network_hypervisor::vl2::NetworkId; use crate::model::*; #[async_trait] -pub trait Database: Sync + Send + Sized + 'static { - type Error; +pub trait Database: Sync + Send + NodeStorage + 'static { + async fn get_network(&self, id: NetworkId) -> Result, Box>; + async fn save_network(&self, obj: &Network) -> Result<(), Box>; - async fn get_network(&self, id: NetworkId) -> Result, Self::Error>; - async fn save_network(&self, obj: &Network) -> Result<(), Self::Error>; + async fn list_members(&self, network_id: NetworkId) -> Result, Box>; + async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Box>; + async fn save_member(&self, obj: &Member) -> Result<(), Box>; - async fn list_members(&self, network_id: NetworkId) -> Result, Self::Error>; - async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Self::Error>; - async fn save_member(&self, obj: &Member) -> Result<(), Self::Error>; - - async fn log_request(&self, obj: &RequestLogItem) -> Result<(), Self::Error>; + async fn log_request(&self, obj: &RequestLogItem) -> Result<(), Box>; } diff --git a/controller/src/filedatabase.rs b/controller/src/filedatabase.rs index cf7d04cfb..a8940d9f7 100644 --- a/controller/src/filedatabase.rs +++ b/controller/src/filedatabase.rs @@ -1,14 +1,16 @@ use std::error::Error; use std::path::{Path, PathBuf}; -use std::sync::Arc; +use std::str::FromStr; use async_trait::async_trait; use serde::de::DeserializeOwned; use serde::Serialize; -use zerotier_network_hypervisor::vl1::Address; +use zerotier_network_hypervisor::vl1::{Address, Identity, NodeStorage}; use zerotier_network_hypervisor::vl2::NetworkId; + +use zerotier_utils::io::{fs_restrict_permissions, read_limit}; use zerotier_utils::json::{json_patch, to_json_pretty}; use zerotier_utils::tokio::fs; use zerotier_utils::tokio::io::ErrorKind; @@ -16,6 +18,8 @@ use zerotier_utils::tokio::io::ErrorKind; use crate::database::Database; use crate::model::*; +const IDENTITY_SECRET_FILENAME: &'static str = "identity.secret"; + pub struct FileDatabase { base: PathBuf, cache: PathBuf, @@ -30,11 +34,11 @@ fn member_path(base: &PathBuf, network_id: NetworkId, member_id: Address) -> Pat } impl FileDatabase { - pub async fn new>(base_path: P) -> Arc { + pub async fn new>(base_path: P) -> Self { let base: PathBuf = base_path.as_ref().into(); let cache: PathBuf = base_path.as_ref().join("cache"); let _ = fs::create_dir_all(&cache).await; - Arc::new(Self { base, cache }) + Self { base, cache } } /// Merge an object with its cached instance and save the result to the 'cache' path. @@ -59,11 +63,31 @@ impl FileDatabase { } } +impl NodeStorage for FileDatabase { + fn load_node_identity(&self) -> Option { + let id_data = read_limit(self.base.join(IDENTITY_SECRET_FILENAME), 4096); + if id_data.is_err() { + return None; + } + let id_data = Identity::from_str(String::from_utf8_lossy(id_data.unwrap().as_slice()).as_ref()); + if id_data.is_err() { + return None; + } + Some(id_data.unwrap()) + } + + fn save_node_identity(&self, id: &Identity) { + assert!(id.secret.is_some()); + let id_secret_str = id.to_secret_string(); + let secret_path = self.base.join(IDENTITY_SECRET_FILENAME); + assert!(std::fs::write(&secret_path, id_secret_str.as_bytes()).is_ok()); + assert!(fs_restrict_permissions(&secret_path)); + } +} + #[async_trait] impl Database for FileDatabase { - type Error = Box; - - async fn get_network(&self, id: NetworkId) -> Result, Self::Error> { + async fn get_network(&self, id: NetworkId) -> Result, Box> { let r = fs::read(network_path(&self.base, id)).await; if let Ok(raw) = r { let r = serde_json::from_slice::(raw.as_slice()); @@ -83,7 +107,7 @@ impl Database for FileDatabase { } } - async fn save_network(&self, obj: &Network) -> Result<(), Self::Error> { + async fn save_network(&self, obj: &Network) -> Result<(), Box> { let _ = fs::create_dir_all(self.base.join(obj.id.to_string())).await; let _ = fs::create_dir_all(self.cache.join(obj.id.to_string())).await; @@ -97,7 +121,7 @@ impl Database for FileDatabase { Ok(()) } - async fn list_members(&self, network_id: NetworkId) -> Result, Self::Error> { + async fn list_members(&self, network_id: NetworkId) -> Result, Box> { let mut members = Vec::new(); let mut dir = fs::read_dir(self.base.join(network_id.to_string())).await?; while let Ok(Some(ent)) = dir.next_entry().await { @@ -117,7 +141,7 @@ impl Database for FileDatabase { Ok(members) } - async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Self::Error> { + async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Box> { let r = fs::read(member_path(&self.base, network_id, node_id)).await; if let Ok(raw) = r { let r = serde_json::from_slice::(raw.as_slice()); @@ -139,7 +163,7 @@ impl Database for FileDatabase { } } - async fn save_member(&self, obj: &Member) -> Result<(), Self::Error> { + async fn save_member(&self, obj: &Member) -> Result<(), Box> { let base_member_path = member_path(&self.base, obj.network_id, obj.node_id); if !fs::metadata(&base_member_path).await.is_ok() { fs::write(base_member_path, to_json_pretty(obj).as_bytes()).await?; @@ -153,7 +177,7 @@ impl Database for FileDatabase { Ok(()) } - async fn log_request(&self, obj: &RequestLogItem) -> Result<(), Self::Error> { + async fn log_request(&self, obj: &RequestLogItem) -> Result<(), Box> { println!("{}", obj.to_string()); Ok(()) } diff --git a/controller/src/main.rs b/controller/src/main.rs index 0621ea5d7..39d4f0114 100644 --- a/controller/src/main.rs +++ b/controller/src/main.rs @@ -2,44 +2,94 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; - use std::time::Duration; +use clap::{Arg, Command}; + +use zerotier_network_controller::controller::Controller; +use zerotier_network_controller::database::Database; +use zerotier_network_controller::filedatabase::FileDatabase; + +use zerotier_network_hypervisor::{VERSION_MAJOR, VERSION_MINOR, VERSION_REVISION}; use zerotier_utils::exitcode; +use zerotier_utils::tokio::runtime::Runtime; use zerotier_vl1_service::VL1Service; -fn main() { - std::process::exit( - if let Ok(_tokio_runtime) = zerotier_utils::tokio::runtime::Builder::new_multi_thread().enable_all().build() { - let test_inner = Arc::new(zerotier_network_hypervisor::vl1::DummyInnerProtocol::default()); - let test_path_filter = Arc::new(zerotier_network_hypervisor::vl1::DummyPathFilter::default()); - let datadir = open_datadir(&flags); - let svc = VL1Service::new(datadir, test_inner, test_path_filter, zerotier_vl1_service::VL1Settings::default()); - if svc.is_ok() { - let svc = svc.unwrap(); - svc.node().init_default_roots(); +async fn run(database: Arc, runtime: &Runtime) -> i32 { + let controller = Controller::new(database.clone(), runtime.handle().clone()); - // Wait for kill signal on Unix-like platforms. - #[cfg(unix)] - { - let term = Arc::new(AtomicBool::new(false)); - let _ = signal_hook::flag::register(libc::SIGINT, term.clone()); - let _ = signal_hook::flag::register(libc::SIGTERM, term.clone()); - let _ = signal_hook::flag::register(libc::SIGQUIT, term.clone()); - while !term.load(Ordering::Relaxed) { - std::thread::sleep(Duration::from_secs(1)); - } - } - - println!("Terminate signal received, shutting down..."); - exitcode::OK - } else { - eprintln!("FATAL: error launching service: {}", svc.err().unwrap().to_string()); - exitcode::ERR_IOERR - } - } else { - eprintln!("FATAL: error launching service: can't start async runtime"); - exitcode::ERR_IOERR - }, + let svc = VL1Service::new( + database.clone(), + controller.clone(), + controller.clone(), + zerotier_vl1_service::VL1Settings::default(), ); + if svc.is_ok() { + let svc = svc.unwrap(); + svc.node().init_default_roots(); + + // Wait for kill signal on Unix-like platforms. + #[cfg(unix)] + { + let term = Arc::new(AtomicBool::new(false)); + let _ = signal_hook::flag::register(libc::SIGINT, term.clone()); + let _ = signal_hook::flag::register(libc::SIGTERM, term.clone()); + let _ = signal_hook::flag::register(libc::SIGQUIT, term.clone()); + while !term.load(Ordering::Relaxed) { + std::thread::sleep(Duration::from_secs(1)); + } + } + + println!("Terminate signal received, shutting down..."); + exitcode::OK + } else { + eprintln!("FATAL: error launching service: {}", svc.err().unwrap().to_string()); + exitcode::ERR_IOERR + } +} + +fn main() { + const REQUIRE_ONE_OF_ARGS: [&'static str; 2] = ["postgres", "filedb"]; + let global_args = Command::new("zerotier-controller") + .arg( + Arg::new("filedb") + .short('f') + .long("filedb") + .takes_value(true) + .forbid_empty_values(true) + .value_name("path") + .help(Some("Use filesystem database at path")) + .required_unless_present_any(&REQUIRE_ONE_OF_ARGS), + ) + .arg( + Arg::new("postgres") + .short('p') + .long("postgres") + .takes_value(true) + .forbid_empty_values(true) + .value_name("path") + .help(Some("Connect to postgres with parameters in JSON file")) + .required_unless_present_any(&REQUIRE_ONE_OF_ARGS), + ) + .version(format!("{}.{}.{}", VERSION_MAJOR, VERSION_MINOR, VERSION_REVISION).as_str()) + .arg_required_else_help(true) + .try_get_matches_from(std::env::args()) + .unwrap_or_else(|e| { + let _ = e.print(); + std::process::exit(exitcode::ERR_USAGE); + }); + + if let Ok(tokio_runtime) = zerotier_utils::tokio::runtime::Builder::new_multi_thread().enable_all().build() { + tokio_runtime.block_on(async { + if let Some(filedb_base_path) = global_args.value_of("filedb") { + std::process::exit(run(Arc::new(FileDatabase::new(filedb_base_path).await), &tokio_runtime).await); + } else { + eprintln!("FATAL: no database type selected."); + std::process::exit(exitcode::ERR_USAGE); + }; + }); + } else { + eprintln!("FATAL: error launching service: can't start async runtime"); + std::process::exit(exitcode::ERR_IOERR) + } } diff --git a/service/src/cli/rootset.rs b/service/src/cli/rootset.rs index d51385b46..474cab25b 100644 --- a/service/src/cli/rootset.rs +++ b/service/src/cli/rootset.rs @@ -8,6 +8,7 @@ use crate::{exitcode, Flags}; use zerotier_network_hypervisor::vl1::RootSet; +use zerotier_utils::io::{read_limit, DEFAULT_FILE_IO_READ_LIMIT}; use zerotier_utils::json::to_json_pretty; use zerotier_utils::marshalable::Marshalable; @@ -26,7 +27,7 @@ pub fn cmd(_: Flags, cmd_args: &ArgMatches) -> i32 { let path = path.unwrap(); let secret_arg = secret_arg.unwrap(); let secret = crate::utils::parse_cli_identity(secret_arg, true); - let json_data = crate::utils::read_limit(path, crate::utils::DEFAULT_FILE_IO_READ_LIMIT); + let json_data = read_limit(path, DEFAULT_FILE_IO_READ_LIMIT); if secret.is_err() { eprintln!("ERROR: unable to parse '{}' or read as a file.", secret_arg); return exitcode::ERR_IOERR; @@ -62,7 +63,7 @@ pub fn cmd(_: Flags, cmd_args: &ArgMatches) -> i32 { let path = sc_args.value_of("path"); if path.is_some() { let path = path.unwrap(); - let json_data = crate::utils::read_limit(path, crate::utils::DEFAULT_FILE_IO_READ_LIMIT); + let json_data = read_limit(path, DEFAULT_FILE_IO_READ_LIMIT); if json_data.is_err() { eprintln!("ERROR: unable to read '{}'.", path); return exitcode::ERR_IOERR; @@ -90,7 +91,7 @@ pub fn cmd(_: Flags, cmd_args: &ArgMatches) -> i32 { let path = sc_args.value_of("path"); if path.is_some() { let path = path.unwrap(); - let json_data = crate::utils::read_limit(path, 1048576); + let json_data = read_limit(path, DEFAULT_FILE_IO_READ_LIMIT); if json_data.is_err() { eprintln!("ERROR: unable to read '{}'.", path); return exitcode::ERR_IOERR; diff --git a/service/src/utils.rs b/service/src/utils.rs index 52979dff7..095232379 100644 --- a/service/src/utils.rs +++ b/service/src/utils.rs @@ -1,7 +1,7 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. -use std::io::Read; use std::path::Path; +use std::str::FromStr; use zerotier_network_hypervisor::vl1::Identity; use zerotier_utils::io::read_limit; diff --git a/utils/src/io.rs b/utils/src/io.rs index 9a48e85c2..eb8166f96 100644 --- a/utils/src/io.rs +++ b/utils/src/io.rs @@ -5,7 +5,7 @@ use std::io::Read; use std::path::Path; /// Default sanity limit parameter for read_limit() used throughout the service. -pub const DEFAULT_FILE_IO_READ_LIMIT: usize = 1048576; +pub const DEFAULT_FILE_IO_READ_LIMIT: usize = 262144; /// Convenience function to read up to limit bytes from a file. /// diff --git a/vl1-service/src/vl1service.rs b/vl1-service/src/vl1service.rs index 6a3c9650d..1ca455f7d 100644 --- a/vl1-service/src/vl1service.rs +++ b/vl1-service/src/vl1service.rs @@ -16,9 +16,6 @@ use crate::settings::VL1Settings; use crate::sys::udp::{udp_test_bind, BoundUdpPort, UdpPacketHandler}; use crate::LocalSocket; -/// This can be adjusted to trade thread count for maximum I/O concurrency. -const MAX_PER_SOCKET_CONCURRENCY: usize = 8; - /// Update UDP bindings every this many seconds. const UPDATE_UDP_BINDINGS_EVERY_SECS: usize = 10;