diff --git a/controller/Cargo.toml b/controller/Cargo.toml index 4867b09d1..056bcad04 100644 --- a/controller/Cargo.toml +++ b/controller/Cargo.toml @@ -18,3 +18,5 @@ serde_json = { version = "^1", features = ["std"], default-features = false } serde_yaml = "^0" clap = { version = "^3", features = ["std", "suggestions"], default-features = false } notify = { version = "^5", features = ["macos_fsevent"], default-features = false } +tokio-postgres = "^0" +pin-utils = "^0" diff --git a/controller/src/database.rs b/controller/src/database.rs index 34fcfc34a..d7d37bc23 100644 --- a/controller/src/database.rs +++ b/controller/src/database.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use std::error::Error; use zerotier_network_hypervisor::vl1::{Address, InetAddress, NodeStorage}; use zerotier_network_hypervisor::vl2::NetworkId; @@ -8,6 +7,8 @@ use zerotier_utils::tokio::sync::broadcast::Receiver; use crate::model::*; +pub type Error = Box; + /// Database change relevant to the controller and that was NOT initiated by the controller. #[derive(Clone, Debug)] pub enum Change { @@ -21,13 +22,13 @@ pub enum Change { #[async_trait] pub trait Database: Sync + Send + NodeStorage + 'static { - async fn list_networks(&self) -> Result, Box>; - async fn get_network(&self, id: NetworkId) -> Result, Box>; - async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Box>; + async fn list_networks(&self) -> Result, Error>; + async fn get_network(&self, id: NetworkId) -> Result, Error>; + async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Error>; - async fn list_members(&self, network_id: NetworkId) -> Result, Box>; - async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Box>; - async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Box>; + async fn list_members(&self, network_id: NetworkId) -> Result, Error>; + async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Error>; + async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Error>; /// Get a receiver that can be used to receive changes made to networks and members, if supported. /// @@ -48,11 +49,7 @@ pub trait Database: Sync + Send + NodeStorage + 'static { /// /// The default trait implementation uses a brute force method. This should be reimplemented if a /// more efficient way is available. - async fn list_members_deauthorized_after( - &self, - network_id: NetworkId, - cutoff: i64, - ) -> Result, Box> { + async fn list_members_deauthorized_after(&self, network_id: NetworkId, cutoff: i64) -> Result, Error> { let mut v = Vec::new(); let members = self.list_members(network_id).await?; for a in members.iter() { @@ -69,7 +66,7 @@ pub trait Database: Sync + Send + NodeStorage + 'static { /// /// The default trait implementation uses a brute force method. This should be reimplemented if a /// more efficient way is available. - async fn is_ip_assigned(&self, network_id: NetworkId, ip: &InetAddress) -> Result> { + async fn is_ip_assigned(&self, network_id: NetworkId, ip: &InetAddress) -> Result { let members = self.list_members(network_id).await?; for a in members.iter() { if let Some(m) = self.get_member(network_id, *a).await? { @@ -81,5 +78,5 @@ pub trait Database: Sync + Send + NodeStorage + 'static { return Ok(false); } - async fn log_request(&self, obj: RequestLogItem) -> Result<(), Box>; + async fn log_request(&self, obj: RequestLogItem) -> Result<(), Error>; } diff --git a/controller/src/filedatabase.rs b/controller/src/filedatabase.rs index 42efd6066..9625c2a8b 100644 --- a/controller/src/filedatabase.rs +++ b/controller/src/filedatabase.rs @@ -1,4 +1,3 @@ -use std::error::Error; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::atomic::{AtomicU64, Ordering}; @@ -19,7 +18,7 @@ use zerotier_utils::tokio::task::JoinHandle; use zerotier_utils::tokio::time::{sleep, Duration, Instant}; use crate::cache::Cache; -use crate::database::{Change, Database}; +use crate::database::{Change, Database, Error}; use crate::model::*; const IDENTITY_SECRET_FILENAME: &'static str = "identity.secret"; @@ -45,7 +44,7 @@ pub struct FileDatabase { // TODO: should cache at least hashes and detect changes in the filesystem live. impl FileDatabase { - pub async fn new>(runtime: Handle, base_path: P) -> Result, Box> { + pub async fn new>(runtime: Handle, base_path: P) -> Result, Error> { let base_path: PathBuf = base_path.as_ref().into(); let _ = fs::create_dir_all(&base_path).await?; @@ -240,7 +239,7 @@ impl FileDatabase { .join(format!("M{}.yaml", member_id.to_string())) } - async fn load_object(path: &Path) -> Result, Box> { + async fn load_object(path: &Path) -> Result, Error> { if let Ok(raw) = fs::read(path).await { return Ok(Some(serde_yaml::from_slice::(raw.as_slice())?)); } else { @@ -299,7 +298,7 @@ impl NodeStorage for FileDatabase { #[async_trait] impl Database for FileDatabase { - async fn list_networks(&self) -> Result, Box> { + async fn list_networks(&self) -> Result, Error> { let mut networks = Vec::new(); if let Some(controller_address) = self.get_controller_address() { let controller_address_shift24 = u64::from(controller_address).wrapping_shl(24); @@ -323,7 +322,7 @@ impl Database for FileDatabase { Ok(networks) } - async fn get_network(&self, id: NetworkId) -> Result, Box> { + async fn get_network(&self, id: NetworkId) -> Result, Error> { let mut network = Self::load_object::(self.network_path(id).as_path()).await?; if let Some(network) = network.as_mut() { // FileDatabase stores networks by their "network number" and automatically adapts their IDs @@ -340,7 +339,7 @@ impl Database for FileDatabase { Ok(network) } - async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Box> { + async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Error> { if !generate_change_notification { let _ = self.cache.on_network_updated(obj.clone()); } @@ -350,7 +349,7 @@ impl Database for FileDatabase { return Ok(()); } - async fn list_members(&self, network_id: NetworkId) -> Result, Box> { + async fn list_members(&self, network_id: NetworkId) -> Result, Error> { let mut members = Vec::new(); let mut dir = fs::read_dir(self.base_path.join(format!("N{:06x}", network_id.network_no()))).await?; while let Ok(Some(ent)) = dir.next_entry().await { @@ -372,7 +371,7 @@ impl Database for FileDatabase { Ok(members) } - async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Box> { + async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Error> { let mut member = Self::load_object::(self.member_path(network_id, node_id).as_path()).await?; if let Some(member) = member.as_mut() { if member.network_id != network_id { @@ -384,7 +383,7 @@ impl Database for FileDatabase { Ok(member) } - async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Box> { + async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Error> { if !generate_change_notification { let _ = self.cache.on_member_updated(obj.clone()); } @@ -398,7 +397,7 @@ impl Database for FileDatabase { Some(self.change_sender.subscribe()) } - async fn log_request(&self, obj: RequestLogItem) -> Result<(), Box> { + async fn log_request(&self, obj: RequestLogItem) -> Result<(), Error> { println!("{}", obj.to_string()); Ok(()) } diff --git a/controller/src/lib.rs b/controller/src/lib.rs index 0d663136d..389d99ee7 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -7,5 +7,6 @@ pub(crate) mod cache; pub mod database; pub mod filedatabase; pub mod model; +pub mod postgresdatabase; pub use controller::*; diff --git a/controller/src/postgresdatabase.rs b/controller/src/postgresdatabase.rs new file mode 100644 index 000000000..85fdcdffe --- /dev/null +++ b/controller/src/postgresdatabase.rs @@ -0,0 +1,410 @@ +use std::collections::HashMap; +use std::ops::Deref; +use std::str::FromStr; +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use pin_utils::pin_mut; +use serde::{Deserialize, Serialize}; +use tokio_postgres::types::Type; +use tokio_postgres::{Client, Statement}; + +use zerotier_crypto::verified::Verified; + +use zerotier_network_hypervisor::vl1::{Address, Identity, InetAddress, NodeStorage}; +use zerotier_network_hypervisor::vl2::rule::Rule; +use zerotier_network_hypervisor::vl2::NetworkId; + +use zerotier_utils::futures_util::{Stream, StreamExt}; +use zerotier_utils::tokio; +use zerotier_utils::tokio::runtime::Handle; +use zerotier_utils::tokio::sync::broadcast::{channel, Receiver, Sender}; +use zerotier_utils::tokio::task::JoinHandle; + +use crate::database::*; +use crate::model::{Member, Network, RequestLogItem}; + +const RECONNECT_RATE_LIMIT: tokio::time::Duration = tokio::time::Duration::from_millis(250); + +const GET_NETWORK_SQL: &'static str = " +SELECT + n.capabilities, + n.enable_broadcast, + n.mtu, + n.multicast_limit, + n.name, + n.private, + n.rules, + n.v4_assign_mode, + n.v6_assign_mode, + n.sso_enabled, + (CASE WHEN n.sso_enabled THEN o.client_id ELSE NULL END) as client_id, + (CASE WHEN n.sso_enabled THEN o.authorization_endpoint ELSE NULL END) as authorization_endpoint, + d.domain, + d.servers, + ARRAY(SELECT CONCAT(host(ip_range_start),'|', host(ip_range_end)) FROM ztc_network_assignment_pool WHERE network_id = n.id) AS assignment_pool, + ARRAY(SELECT CONCAT(host(address),'/',bits::text,'|',COALESCE(host(via), 'NULL')) FROM ztc_network_route WHERE network_id = n.id) AS routes +FROM + ztc_network n + LEFT OUTER JOIN ztc_org o ON o.owner_id = n.owner_id + LEFT OUTER JOIN ztc_network_oidc_config noc ON noc.network_id = n.id + LEFT OUTER JOIN ztc_oidc_config oc ON noc.client_id = oc.client_id AND o.org_id = oc.org_id + LEFT OUTER JOIN ztc_network_dns d ON d.network_id = n.id +WHERE + id = $1 AND + deleted = false"; + +const GET_NETWORK_MEMBERS_WITH_CAPABILITIES_SQL: &'static str = " +SELECT + m.id, + m.capabilities +FROM + ztc_member m +WHERE + network_id = $1 AND + authorized = true AND + deleted = false AND + capabilities IS NOT NULL AND + capabilities != '[]' +"; + +struct PostgresConnection { + s_list_networks: Statement, + s_list_members: Statement, + s_get_network: Statement, + s_get_network_members_with_capabilities: Statement, + client: Client, + connection_task: JoinHandle<()>, +} + +impl PostgresConnection { + async fn new(runtime: &Handle, postgres_path: &str) -> Result, Error> { + let (client, connection) = tokio_postgres::connect(postgres_path, tokio_postgres::NoTls).await?; + Ok(Box::new(Self { + s_list_networks: client + .prepare_typed( + "SELECT id FROM ztc_network WHERE controller_id = $1 AND deleted = false", + &[Type::TEXT], + ) + .await?, + s_list_members: client + .prepare_typed("SELECT id FROM ztc_member WHERE network_id = $1 AND deleted = false", &[Type::TEXT]) + .await?, + s_get_network: client.prepare_typed(GET_NETWORK_SQL, &[Type::TEXT]).await?, + s_get_network_members_with_capabilities: client + .prepare_typed(GET_NETWORK_MEMBERS_WITH_CAPABILITIES_SQL, &[Type::TEXT]) + .await?, + client, + connection_task: runtime.spawn(async move { + if let Err(e) = connection.await { + eprintln!("ERROR: postgresql connection error: {}", e.to_string()); + } + }), + })) + } +} + +impl Drop for PostgresConnection { + fn drop(&mut self) { + self.connection_task.abort(); + } +} + +struct ConnectionHolder<'a>(Option>, &'a PostgresDatabase); + +impl<'a> Deref for ConnectionHolder<'a> { + type Target = PostgresConnection; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.0.as_ref().unwrap() + } +} + +impl<'a> Drop for ConnectionHolder<'a> { + fn drop(&mut self) { + let mut connections = self.1.connections.lock().unwrap(); + connections.0.push(self.0.take().unwrap()); + let _ = connections.1.send(()); // unblock any waiting get_connection() requests + } +} + +pub struct PostgresDatabase { + local_controller_id_str: String, + local_identity: Verified, + connections: Mutex<(Vec>, Sender<()>)>, + postgres_path: String, + runtime: Handle, +} + +impl PostgresDatabase { + pub async fn new( + runtime: Handle, + postgres_path: String, + num_connections: usize, + local_identity: Verified, + ) -> Result, Error> { + assert!(num_connections > 0); + let (sender, _) = channel(4096); + let mut connections = Vec::with_capacity(num_connections); + for _ in 0..num_connections { + connections.push(PostgresConnection::new(&runtime, postgres_path.as_str()).await?); + } + Ok(Arc::new(Self { + local_controller_id_str: local_identity.address.to_string(), + local_identity, + connections: Mutex::new((connections, sender)), + postgres_path, + runtime, + })) + } + + async fn get_connection(&self) -> Result { + loop { + let mut receiver = { + let mut connections = self.connections.lock().unwrap(); + if let Some(c) = connections.0.pop() { + if c.client.is_closed() { + break; + } else { + return Ok(ConnectionHolder(Some(c), self)); + } + } + connections.1.subscribe() + }; + let _ = receiver.recv().await; // wait for a connection to be returned + } + tokio::time::sleep(RECONNECT_RATE_LIMIT).await; // rate limit reconnection attempts + return Ok(ConnectionHolder( + Some(PostgresConnection::new(&self.runtime, self.postgres_path.as_str()).await?), + self, + )); + } +} + +impl NodeStorage for PostgresDatabase { + fn load_node_identity(&self) -> Option> { + Some(self.local_identity.clone()) + } + + fn save_node_identity(&self, _: &Verified) { + eprintln!("FATAL: NodeStorage::save_node_identity() not implemented in PostgresDatabase, identity must be pregenerated"); + panic!(); + } +} + +#[async_trait] +impl Database for PostgresDatabase { + async fn list_networks(&self) -> Result, Error> { + let c = self.get_connection().await?; + let rs = c.client.query_raw(&c.s_list_networks, &[&self.local_controller_id_str]).await?; + pin_mut!(rs); + let mut r = Vec::with_capacity(rs.size_hint().0.min(64)); + while let Some(Ok(row)) = rs.next().await { + r.push(NetworkId::from_str(row.get(0))?); + } + Ok(r) + } + + #[allow(unused_variables)] + async fn get_network(&self, id: NetworkId) -> Result, Error> { + let (nw, with_caps) = { + let c = self.get_connection().await?; + let network_id_string = id.to_string(); + if let Some(r) = c.client.query_opt(&c.s_get_network, &[&network_id_string]).await? { + if let Ok(with_caps) = c + .client + .query(&c.s_get_network_members_with_capabilities, &[&network_id_string]) + .await + { + (r, with_caps) + } else { + (r, Vec::new()) + } + } else { + return Ok(None); + } + }; + + let mut capabilities: Option<&str> = nw.get(0); + let enable_broadcast: bool = nw.get(1); + let mtu: i32 = nw.get(2); + let multicast_limit: i64 = nw.get(3); + let name: &str = nw.get(4); + let private: bool = nw.get(5); + let mut rules: Option<&str> = nw.get(6); + let v4_assign_mode: &str = nw.get(7); + let v6_assign_mode: &str = nw.get(8); + let sso_enabled: bool = nw.get(9); + let mut client_id: Option<&str> = nw.get(10); + let mut authorization_endpoint: Option<&str> = nw.get(11); + let mut domain: Option<&str> = nw.get(12); + let mut servers: Option<&str> = nw.get(13); + let mut assignment_pool: Option<&str> = nw.get(14); + let mut routes: Option<&str> = nw.get(15); + + filter_null_string(&mut capabilities); + filter_null_string(&mut rules); + filter_null_string(&mut client_id); + filter_null_string(&mut authorization_endpoint); + filter_null_string(&mut domain); + filter_null_string(&mut servers); + filter_null_string(&mut assignment_pool); + filter_null_string(&mut routes); + + let mut rules = if let Some(rules) = rules { + serde_json::from_str::>(rules)? + } else { + Vec::new() + }; + + // Capabilities are being deprecated in V2 as they are complex and rarely used. To handle networks + // that have them configured (there aren't many) we translate them into special portions of the + // general rule set that match on the capability owner's address. + if let Some(capabilities) = capabilities { + let capabilities_vec = serde_json::from_str::>(capabilities)?; + let mut capabilities = HashMap::with_capacity(capabilities_vec.len()); + for c in capabilities_vec.iter() { + capabilities.insert(c.id, c); + } + let mut members_by_cap = HashMap::with_capacity(with_caps.len()); + for wc in with_caps.iter() { + if let Ok(member_id) = Address::from_str(wc.get(0)) { + if let Ok(cap_ids) = serde_json::from_str::>(wc.get(1)) { + for cap_id in cap_ids.iter() { + members_by_cap + .entry(*cap_id) + .or_insert_with(|| Vec::with_capacity(4)) + .push(member_id); + } + } + } + } + if !members_by_cap.is_empty() { + let mut base_rules = rules.clone(); + rules.clear(); + + for (cap_id, member_ids) in members_by_cap.iter() { + if let Some(cap) = capabilities.get(cap_id) { + let mut or = false; + for m in member_ids.iter() { + rules.push(Rule::match_source_zerotier_address(false, or, *m)); + or = true; + } + for r in cap.rules.iter() { + rules.push(r.clone()); + } + rules.push(Rule::action_accept()); + } + } + + if base_rules.is_empty() { + rules.push(Rule::action_accept()); + } else { + for r in base_rules.drain(..) { + rules.push(r); + } + } + } + } + + Ok(Some(Network { + id, + name: name.to_string(), + multicast_limit: if multicast_limit < 0 || multicast_limit > (u32::MAX as i64) { + None + } else { + Some(multicast_limit as u32) + }, + enable_broadcast: Some(enable_broadcast), + v4_assign_mode: Some(serde_json::from_str(v4_assign_mode)?), + v6_assign_mode: Some(serde_json::from_str(v6_assign_mode)?), + ip_assignment_pools: todo!(), + ip_routes: todo!(), + dns: todo!(), + rules: todo!(), + credential_ttl: None, + min_supported_version: None, + mtu: if mtu < 0 || mtu > (u16::MAX as i32) { + None + } else { + Some(mtu as u16) + }, + private, + learn_members: Some(true), + })) + } + + async fn save_network(&self, obj: Network, generate_change_notification: bool) -> Result<(), Error> { + todo!() + } + + async fn list_members(&self, network_id: NetworkId) -> Result, Error> { + let network_id_string = network_id.to_string(); + let c = self.get_connection().await?; + let rs = c.client.query_raw(&c.s_list_members, &[&network_id_string]).await?; + pin_mut!(rs); + let mut r = Vec::with_capacity(rs.size_hint().0.min(64)); + while let Some(Ok(row)) = rs.next().await { + r.push(Address::from_str(row.get(0))?); + } + Ok(r) + } + + async fn get_member(&self, network_id: NetworkId, node_id: Address) -> Result, Error> { + todo!() + } + + async fn save_member(&self, obj: Member, generate_change_notification: bool) -> Result<(), Error> { + todo!() + } + + async fn changes(&self) -> Option> { + // TODO + None + } + + async fn list_members_deauthorized_after(&self, network_id: NetworkId, cutoff: i64) -> Result, Error> { + let mut v = Vec::new(); + let members = self.list_members(network_id).await?; + for a in members.iter() { + if let Some(m) = self.get_member(network_id, *a).await? { + if m.last_deauthorized_time.unwrap_or(i64::MIN) >= cutoff { + v.push(m.node_id); + } + } + } + Ok(v) + } + + async fn is_ip_assigned(&self, network_id: NetworkId, ip: &InetAddress) -> Result { + let members = self.list_members(network_id).await?; + for a in members.iter() { + if let Some(m) = self.get_member(network_id, *a).await? { + if m.ip_assignments.iter().any(|ip2| ip2.ip_bytes().eq(ip.ip_bytes())) { + return Ok(true); + } + } + } + return Ok(false); + } + + async fn log_request(&self, obj: RequestLogItem) -> Result<(), Error> { + todo!() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct V1Capability { + pub id: u32, + pub rules: Vec, +} + +fn filter_null_string(s: &mut Option<&str>) { + if let Some(ss) = s.as_ref() { + let ss = (*ss).trim(); + if ss.is_empty() || ss == "null" || ss == "NULL" { + let _ = s.take(); + } + } +} diff --git a/utils/Cargo.toml b/utils/Cargo.toml index b548cf60f..fed728f4d 100644 --- a/utils/Cargo.toml +++ b/utils/Cargo.toml @@ -7,12 +7,13 @@ version = "0.1.0" [features] default = [] -tokio = ["dep:tokio"] +tokio = ["dep:tokio", "dep:futures-util"] [dependencies] serde = { version = "^1", features = ["derive"], default-features = false } serde_json = { version = "^1", features = ["std"], default-features = false } tokio = { version = "^1", default-features = false, features = ["fs", "io-util", "io-std", "net", "process", "rt", "rt-multi-thread", "signal", "sync", "time"], optional = true } +futures-util = { version = "^0", optional = true } [target."cfg(windows)".dependencies] winapi = { version = "^0", features = ["handleapi", "ws2ipdef", "ws2tcpip"] } diff --git a/utils/src/defer.rs b/utils/src/defer.rs new file mode 100644 index 000000000..b66b983b4 --- /dev/null +++ b/utils/src/defer.rs @@ -0,0 +1,13 @@ +/// Defer execution of a closure until dropped. +struct Defer(Option); + +impl Drop for Defer { + fn drop(&mut self) { + self.0.take().map(|f| f()); + } +} + +/// Defer execution of a closure until the return value is dropped. +pub fn defer(f: F) -> impl Drop { + Defer(Some(f)) +} diff --git a/utils/src/lib.rs b/utils/src/lib.rs index 4c3d54eec..690a8cf14 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -3,6 +3,7 @@ pub mod arrayvec; pub mod blob; pub mod buffer; +pub mod defer; pub mod dictionary; pub mod error; #[allow(unused)] @@ -15,6 +16,8 @@ pub mod json; pub mod marshalable; pub mod memory; pub mod pool; +#[cfg(feature = "tokio")] +pub mod reaper; pub mod ringbuffer; pub mod ringbuffermap; pub mod sync; @@ -22,10 +25,10 @@ pub mod thing; pub mod varint; #[cfg(feature = "tokio")] -pub mod reaper; +pub use tokio; #[cfg(feature = "tokio")] -pub use tokio; +pub use futures_util; /// Initial value that should be used for monotonic tick time variables. pub const NEVER_HAPPENED_TICKS: i64 = 0;