From 50675004cea6d53cfaa358fb0947c7a77e6e8347 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Thu, 25 Mar 2021 15:36:12 -0400 Subject: [PATCH] JSON patching for RESTful object update and some other stuff. --- rust-zerotier-core/src/address.rs | 39 +++++++- rust-zerotier-core/src/fingerprint.rs | 21 +++-- rust-zerotier-core/src/identity.rs | 13 ++- rust-zerotier-core/src/mac.rs | 48 +++++++++- rust-zerotier-core/src/networkid.rs | 32 +++++-- service/Cargo.toml | 6 +- service/src/api.rs | 19 ++-- service/src/commands/cert.rs | 2 +- service/src/commands/identity.rs | 5 +- service/src/commands/locator.rs | 2 - service/src/commands/status.rs | 2 - service/src/fastudpsocket.rs | 10 +- service/src/httplistener.rs | 6 +- service/src/localconfig.rs | 130 +++++++++++++------------- service/src/service.rs | 6 +- service/src/utils.rs | 46 ++++++++- service/src/vnic/common.rs | 8 +- 17 files changed, 273 insertions(+), 122 deletions(-) diff --git a/rust-zerotier-core/src/address.rs b/rust-zerotier-core/src/address.rs index 0a7fd5982..12c632a02 100644 --- a/rust-zerotier-core/src/address.rs +++ b/rust-zerotier-core/src/address.rs @@ -14,11 +14,25 @@ #[derive(PartialEq, Eq, Clone, Copy, Ord, PartialOrd)] pub struct Address(pub u64); +impl Default for Address { + #[inline(always)] + fn default() -> Address { + Address(0) + } +} + +impl Address { + #[inline(always)] + fn to_bytes(&self) -> [u8; 5] { + [(self.0 >> 32) as u8, (self.0 >> 24) as u8, (self.0 >> 16) as u8, (self.0 >> 8) as u8, self.0 as u8] + } +} + impl From<&[u8]> for Address { #[inline(always)] fn from(bytes: &[u8]) -> Self { if bytes.len() >= 5 { - Address(((bytes[0] as u64) << 32) | ((bytes[0] as u64) << 24) | ((bytes[0] as u64) << 16) | ((bytes[0] as u64) << 8) | (bytes[0] as u64)) + Address(((bytes[0] as u64) << 32) | ((bytes[1] as u64) << 24) | ((bytes[2] as u64) << 16) | ((bytes[3] as u64) << 8) | (bytes[4] as u64)) } else { Address(0) } @@ -46,14 +60,31 @@ impl From<&str> for Address { } impl serde::Serialize for Address { - fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { serializer.serialize_str(self.to_string().as_str()) } + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + if serializer.is_human_readable() { + serializer.serialize_str(self.to_string().as_str()) + } else { + let b = self.to_bytes(); + serializer.serialize_bytes(b.as_ref()) + } + } } + struct AddressVisitor; + impl<'de> serde::de::Visitor<'de> for AddressVisitor { type Value = Address; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("ZeroTier Address in string format") } + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("ZeroTier Address") } fn visit_str(self, s: &str) -> Result where E: serde::de::Error { Ok(Address::from(s)) } + fn visit_bytes(self, v: &[u8]) -> Result where E: serde::de::Error { Ok(Address::from(v)) } } + impl<'de> serde::Deserialize<'de> for Address { - fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { deserializer.deserialize_str(AddressVisitor) } + fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { + if deserializer.is_human_readable() { + deserializer.deserialize_str(AddressVisitor) + } else { + deserializer.deserialize_bytes(AddressVisitor) + } + } } diff --git a/rust-zerotier-core/src/fingerprint.rs b/rust-zerotier-core/src/fingerprint.rs index 40c015def..21988aa2d 100644 --- a/rust-zerotier-core/src/fingerprint.rs +++ b/rust-zerotier-core/src/fingerprint.rs @@ -22,15 +22,15 @@ use crate::capi as ztcore; #[derive(PartialEq, Eq, Clone)] pub struct Fingerprint { pub address: Address, - pub hash: [u8; 48] + pub hash: [u8; 48], } impl Fingerprint { #[inline(always)] pub(crate) fn new_from_capi(fp: &ztcore::ZT_Fingerprint) -> Fingerprint { - Fingerprint{ + Fingerprint { address: Address(fp.address), - hash: fp.hash + hash: fp.hash, } } @@ -44,9 +44,9 @@ impl Fingerprint { unsafe { if ztcore::ZT_Fingerprint_fromString(cfp.as_mut_ptr(), cs.as_ptr()) != 0 { let fp = cfp.assume_init(); - return Ok(Fingerprint{ + return Ok(Fingerprint { address: Address(fp.address), - hash: fp.hash + hash: fp.hash, }); } } @@ -84,9 +84,13 @@ impl ToString for Fingerprint { } impl serde::Serialize for Fingerprint { - fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { serializer.serialize_str(self.to_string().as_str()) } + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + serializer.serialize_str(self.to_string().as_str()) + } } + struct FingerprintVisitor; + impl<'de> serde::de::Visitor<'de> for FingerprintVisitor { type Value = Fingerprint; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("ZeroTier Fingerprint in string format") } @@ -98,6 +102,9 @@ impl<'de> serde::de::Visitor<'de> for FingerprintVisitor { return Ok(id.ok().unwrap() as Self::Value); } } + impl<'de> serde::Deserialize<'de> for Fingerprint { - fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { deserializer.deserialize_str(FingerprintVisitor) } + fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { + deserializer.deserialize_str(FingerprintVisitor) + } } diff --git a/rust-zerotier-core/src/identity.rs b/rust-zerotier-core/src/identity.rs index e7cbcb0e4..5a9209bc1 100644 --- a/rust-zerotier-core/src/identity.rs +++ b/rust-zerotier-core/src/identity.rs @@ -200,9 +200,13 @@ impl ToString for Identity { } impl serde::Serialize for Identity { - fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { serializer.serialize_str(self.intl_to_string(false).as_str()) } + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + serializer.serialize_str(self.intl_to_string(false).as_str()) + } } + struct IdentityVisitor; + impl<'de> serde::de::Visitor<'de> for IdentityVisitor { type Value = Identity; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("ZeroTier Identity in string format") } @@ -214,8 +218,11 @@ impl<'de> serde::de::Visitor<'de> for IdentityVisitor { return Ok(id.ok().unwrap() as Self::Value); } } + impl<'de> serde::Deserialize<'de> for Identity { - fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { deserializer.deserialize_str(IdentityVisitor) } + fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { + deserializer.deserialize_str(IdentityVisitor) + } } #[cfg(test)] @@ -255,7 +262,7 @@ mod tests { let from_str_fail = Identity::new_from_string("asdf:foo:invalid"); assert!(from_str_fail.is_err()); - let mut to_sign: [u8; 4] = [ 1,2,3,4 ]; + let mut to_sign: [u8; 4] = [1, 2, 3, 4]; let signed = test1.sign(&to_sign); assert!(signed.is_ok()); diff --git a/rust-zerotier-core/src/mac.rs b/rust-zerotier-core/src/mac.rs index 51181d9d8..26d58f27d 100644 --- a/rust-zerotier-core/src/mac.rs +++ b/rust-zerotier-core/src/mac.rs @@ -14,6 +14,20 @@ #[derive(PartialEq, Eq, Clone, Copy, PartialOrd, Ord)] pub struct MAC(pub u64); +impl Default for MAC { + #[inline(always)] + fn default() -> MAC { + MAC(0) + } +} + +impl MAC { + #[inline(always)] + fn to_bytes(&self) -> [u8; 6] { + [(self.0 >> 40) as u8, (self.0 >> 32) as u8, (self.0 >> 24) as u8, (self.0 >> 16) as u8, (self.0 >> 8) as u8, self.0 as u8] + } +} + impl ToString for MAC { fn to_string(&self) -> String { let x = self.0; @@ -21,21 +35,49 @@ impl ToString for MAC { } } +impl From<&[u8]> for MAC { + #[inline(always)] + fn from(bytes: &[u8]) -> Self { + if bytes.len() >= 6 { + MAC(((bytes[0] as u64) << 40) | ((bytes[1] as u64) << 32) | ((bytes[2] as u64) << 24) | ((bytes[3] as u64) << 16) | ((bytes[4] as u64) << 8) | (bytes[5] as u64)) + } else { + MAC(0) + } + } +} + impl From<&str> for MAC { fn from(s: &str) -> MAC { - MAC(u64::from_str_radix(s.replace(":","").as_str(), 16).unwrap_or(0)) + MAC(u64::from_str_radix(s.replace(":", "").as_str(), 16).unwrap_or(0)) } } impl serde::Serialize for MAC { - fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { serializer.serialize_str(self.to_string().as_str()) } + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + if serializer.is_human_readable() { + serializer.serialize_str(self.to_string().as_str()) + } else { + let b = self.to_bytes(); + serializer.serialize_bytes(b.as_ref()) + } + } } + struct AddressVisitor; + impl<'de> serde::de::Visitor<'de> for AddressVisitor { type Value = MAC; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("Ethernet MAC address in string format (with or without : separators)") } fn visit_str(self, s: &str) -> Result where E: serde::de::Error { Ok(MAC::from(s)) } + fn visit_bytes(self, v: &[u8]) -> Result where E: serde::de::Error { Ok(MAC::from(v)) } } + impl<'de> serde::Deserialize<'de> for MAC { - fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { deserializer.deserialize_str(AddressVisitor) } + fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { + if deserializer.is_human_readable() { + deserializer.deserialize_str(AddressVisitor) + } else { + deserializer.deserialize_bytes(AddressVisitor) + } + } } diff --git a/rust-zerotier-core/src/networkid.rs b/rust-zerotier-core/src/networkid.rs index b1b71d26c..329f402c4 100644 --- a/rust-zerotier-core/src/networkid.rs +++ b/rust-zerotier-core/src/networkid.rs @@ -14,10 +14,10 @@ #[derive(PartialEq, Eq, Clone, Copy, PartialOrd, Ord)] pub struct NetworkId(pub u64); -impl NetworkId { +impl Default for NetworkId { #[inline(always)] - pub fn new_from_string(s: &str) -> NetworkId { - return NetworkId(u64::from_str_radix(s, 16).unwrap_or(0)); + fn default() -> NetworkId { + NetworkId(0) } } @@ -37,19 +37,35 @@ impl From for NetworkId { impl From<&str> for NetworkId { #[inline(always)] fn from(s: &str) -> Self { - NetworkId::new_from_string(s) + NetworkId(u64::from_str_radix(s, 16).unwrap_or(0)) } } impl serde::Serialize for NetworkId { - fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { serializer.serialize_str(self.to_string().as_str()) } + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + if serializer.is_human_readable() { + serializer.serialize_str(self.to_string().as_str()) + } else { + serializer.serialize_u64(self.0) + } + } } + struct NetworkIdVisitor; + impl<'de> serde::de::Visitor<'de> for NetworkIdVisitor { type Value = NetworkId; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("ZeroTier Address in string format") } - fn visit_str(self, s: &str) -> Result where E: serde::de::Error { Ok(NetworkId::new_from_string(s)) } + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("ZeroTier network ID") } + fn visit_str(self, s: &str) -> Result where E: serde::de::Error { Ok(NetworkId::from(s)) } + fn visit_u64(self, v: u64) -> Result where E: serde::de::Error { Ok(NetworkId(v)) } } + impl<'de> serde::Deserialize<'de> for NetworkId { - fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { deserializer.deserialize_str(NetworkIdVisitor) } + fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { + if deserializer.is_human_readable() { + deserializer.deserialize_str(NetworkIdVisitor) + } else { + deserializer.deserialize_u64(NetworkIdVisitor) + } + } } diff --git a/service/Cargo.toml b/service/Cargo.toml index 3e0ee603b..4f37c8612 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -9,15 +9,15 @@ build = "build.rs" [dependencies] zerotier-core = { path = "../rust-zerotier-core" } -num_cpus = "1" +num_cpus = "*" tokio = { version = "1", features = ["rt", "net", "time", "signal", "macros"] } serde = { version = "1", features = ["derive"] } serde_json = "1" futures = "0" clap = { version = "2", features = ["suggestions", "wrap_help"] } chrono = "0" -hex = "0" -lazy_static = "1" +hex = "*" +lazy_static = "*" num-traits = "0" num-derive = "0" hyper = { version = "0", features = ["http1", "runtime", "server", "client", "tcp", "stream"] } diff --git a/service/src/api.rs b/service/src/api.rs index b55783968..472614eeb 100644 --- a/service/src/api.rs +++ b/service/src/api.rs @@ -12,21 +12,28 @@ /****/ use crate::service::Service; -use hyper::{Request, Body, StatusCode, Response, Method}; +use hyper::{Request, Body, StatusCode, Method}; pub(crate) fn status(service: Service, req: Request) -> (StatusCode, Body) { if req.method() == Method::GET { - let status = service.status(); - if status.is_none() { + service.status().map_or_else(|| { (StatusCode::SERVICE_UNAVAILABLE, Body::from("node shutdown in progress")) - } else { - (StatusCode::OK, Body::from(serde_json::to_string(status.as_ref().unwrap()).unwrap())) - } + }, |status| { + (StatusCode::OK, Body::from(serde_json::to_string(&status).unwrap())) + }) } else { (StatusCode::METHOD_NOT_ALLOWED, Body::from("/status allows method(s): GET")) } } +pub(crate) fn config(service: Service, req: Request) -> (StatusCode, Body) { + let config = service.local_config(); + if req.method() == Method::POST || req.method() == Method::PUT { + // TODO: diff config + } + (StatusCode::OK, Body::from(serde_json::to_string(config.as_ref()).unwrap())) +} + pub(crate) fn peer(service: Service, req: Request) -> (StatusCode, Body) { (StatusCode::NOT_IMPLEMENTED, Body::from("")) } diff --git a/service/src/commands/cert.rs b/service/src/commands/cert.rs index 67704038f..6f31d3f5e 100644 --- a/service/src/commands/cert.rs +++ b/service/src/commands/cert.rs @@ -146,7 +146,7 @@ fn newcsr(cli_args: &ArgMatches) -> i32 { if nwid.len() != 16 { break; } - let nwid = NetworkId::new_from_string(nwid.as_str()); + let nwid = NetworkId::from(nwid.as_str()); let fingerprint: String = Input::with_theme(theme) .with_prompt(format!(" [{}] Fingerprint of primary controller (optional)", networks.len() + 1)) diff --git a/service/src/commands/identity.rs b/service/src/commands/identity.rs index a857f5a5e..e79ad458c 100644 --- a/service/src/commands/identity.rs +++ b/service/src/commands/identity.rs @@ -12,9 +12,10 @@ /****/ use clap::ArgMatches; + +use zerotier_core::{Identity, IdentityType}; + use crate::store::Store; -use zerotier_core::{IdentityType, Identity}; -use std::sync::Arc; fn new_(cli_args: &ArgMatches) -> i32 { let id_type = cli_args.value_of("type").map_or(IdentityType::Curve25519, |idt| { diff --git a/service/src/commands/locator.rs b/service/src/commands/locator.rs index ea4deadde..6a5e8b341 100644 --- a/service/src/commands/locator.rs +++ b/service/src/commands/locator.rs @@ -15,8 +15,6 @@ use clap::ArgMatches; use zerotier_core::*; -use crate::store::Store; - fn new_(cli_args: &ArgMatches) -> i32 { let timestamp = cli_args.value_of("timestamp").map_or(crate::utils::ms_since_epoch(), |ts| { if ts.is_empty() { diff --git a/service/src/commands/status.rs b/service/src/commands/status.rs index 36b5fc018..eeeab5a7f 100644 --- a/service/src/commands/status.rs +++ b/service/src/commands/status.rs @@ -12,8 +12,6 @@ /****/ use std::error::Error; -use std::rc::Rc; -use std::str::FromStr; use std::sync::Arc; use hyper::{Uri, Method, StatusCode}; diff --git a/service/src/fastudpsocket.rs b/service/src/fastudpsocket.rs index 4a4a13255..6f4ebcd74 100644 --- a/service/src/fastudpsocket.rs +++ b/service/src/fastudpsocket.rs @@ -82,12 +82,10 @@ fn bind_udp_socket(_device_name: &str, address: &InetAddress) -> Result, Infallible>(Response::builder().header("Content-Type", "application/json").status(status).body(body).unwrap()) } })) } diff --git a/service/src/localconfig.rs b/service/src/localconfig.rs index 309444d87..a1d2083a5 100644 --- a/service/src/localconfig.rs +++ b/service/src/localconfig.rs @@ -67,6 +67,14 @@ pub struct LocalConfigPhysicalPathConfig { pub blacklist: bool } +impl Default for LocalConfigPhysicalPathConfig { + fn default() -> Self { + LocalConfigPhysicalPathConfig { + blacklist: false + } + } +} + #[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] #[serde(default)] pub struct LocalConfigVirtualConfig { @@ -74,6 +82,14 @@ pub struct LocalConfigVirtualConfig { pub try_: Vec } +impl Default for LocalConfigVirtualConfig { + fn default() -> Self { + LocalConfigVirtualConfig { + try_: Vec::new() + } + } +} + #[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] #[serde(default)] pub struct LocalConfigNetworkSettings { @@ -89,6 +105,18 @@ pub struct LocalConfigNetworkSettings { pub allow_default_route_override: bool, } +impl Default for LocalConfigNetworkSettings { + fn default() -> Self { + LocalConfigNetworkSettings { + allow_managed_ips: true, + allow_global_ips: false, + allow_managed_routes: true, + allow_global_routes: false, + allow_default_route_override: false + } + } +} + #[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] #[serde(default)] pub struct LocalConfigLogSettings { @@ -105,6 +133,22 @@ pub struct LocalConfigLogSettings { pub stderr: bool, } +impl Default for LocalConfigLogSettings { + fn default() -> Self { + // TODO: change before release to saner defaults + LocalConfigLogSettings { + path: None, + max_size: 131072, + vl1: true, + vl2: true, + vl2_trace_rules: true, + vl2_trace_multicast: true, + debug: true, + stderr: true, + } + } +} + #[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] #[serde(default)] pub struct LocalConfigSettings { @@ -124,56 +168,22 @@ pub struct LocalConfigSettings { pub explicit_addresses: Vec, } -#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] -#[serde(default)] -pub struct LocalConfig { - pub physical: BTreeMap, - #[serde(rename = "virtual")] - pub virtual_: BTreeMap, - pub network: BTreeMap, - pub settings: LocalConfigSettings, -} - -impl Default for LocalConfigPhysicalPathConfig { +impl Default for LocalConfigSettings { fn default() -> Self { - LocalConfigPhysicalPathConfig { - blacklist: false + let mut bl: Vec = Vec::new(); + bl.reserve(LocalConfigSettings::DEFAULT_PREFIX_BLACKLIST.len()); + for n in LocalConfigSettings::DEFAULT_PREFIX_BLACKLIST.iter() { + bl.push(String::from(*n)); } - } -} -impl Default for LocalConfigVirtualConfig { - fn default() -> Self { - LocalConfigVirtualConfig { - try_: Vec::new() - } - } -} - -impl Default for LocalConfigNetworkSettings { - fn default() -> Self { - LocalConfigNetworkSettings { - allow_managed_ips: true, - allow_global_ips: false, - allow_managed_routes: true, - allow_global_routes: false, - allow_default_route_override: false - } - } -} - -impl Default for LocalConfigLogSettings { - fn default() -> Self { - // TODO: change before release to saner defaults - LocalConfigLogSettings { - path: None, - max_size: 131072, - vl1: true, - vl2: true, - vl2_trace_rules: true, - vl2_trace_multicast: true, - debug: true, - stderr: true, + LocalConfigSettings { + primary_port: zerotier_core::DEFAULT_PORT, + secondary_port: Some(zerotier_core::DEFAULT_SECONDARY_PORT), + auto_port_search: true, + port_mapping: true, + log: LocalConfigLogSettings::default(), + interface_prefix_blacklist: bl, + explicit_addresses: Vec::new() } } } @@ -198,24 +208,14 @@ impl LocalConfigSettings { } } -impl Default for LocalConfigSettings { - fn default() -> Self { - let mut bl: Vec = Vec::new(); - bl.reserve(LocalConfigSettings::DEFAULT_PREFIX_BLACKLIST.len()); - for n in LocalConfigSettings::DEFAULT_PREFIX_BLACKLIST.iter() { - bl.push(String::from(*n)); - } - - LocalConfigSettings { - primary_port: zerotier_core::DEFAULT_PORT, - secondary_port: Some(zerotier_core::DEFAULT_SECONDARY_PORT), - auto_port_search: true, - port_mapping: true, - log: LocalConfigLogSettings::default(), - interface_prefix_blacklist: bl, - explicit_addresses: Vec::new() - } - } +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] +#[serde(default)] +pub struct LocalConfig { + pub physical: BTreeMap, + #[serde(rename = "virtual")] + pub virtual_: BTreeMap, + pub network: BTreeMap, + pub settings: LocalConfigSettings, } impl Default for LocalConfig { diff --git a/service/src/service.rs b/service/src/service.rs index 71d904ebf..dee882b19 100644 --- a/service/src/service.rs +++ b/service/src/service.rs @@ -11,10 +11,8 @@ */ /****/ -use std::cell::Cell; use std::collections::BTreeMap; use std::net::{SocketAddr, Ipv4Addr, IpAddr, Ipv6Addr}; -use std::str::FromStr; use std::sync::{Arc, Mutex, Weak}; use std::sync::atomic::{AtomicBool, Ordering, AtomicPtr}; use std::time::Duration; @@ -41,7 +39,7 @@ const CONFIG_CHECK_INTERVAL: i64 = 5000; #[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct ServiceStatus { #[serde(rename = "objectType")] - pub object_type: &'static str, + pub object_type: String, pub address: Address, pub clock: i64, #[serde(rename = "startTime")] @@ -220,7 +218,7 @@ impl Service { let ver = zerotier_core::version(); self.node().map(|node| { ServiceStatus { - object_type: "status", + object_type: "status".to_owned(), address: node.address(), clock: ms_since_epoch(), start_time: self.intl.startup_time, diff --git a/service/src/utils.rs b/service/src/utils.rs index b6be57dba..e523cf277 100644 --- a/service/src/utils.rs +++ b/service/src/utils.rs @@ -11,7 +11,7 @@ */ /****/ -use std::error::Error; +use std::borrow::Borrow; use std::fs::File; use std::io::Read; use std::mem::MaybeUninit; @@ -20,6 +20,9 @@ use std::path::Path; use zerotier_core::{Identity, Locator}; +use serde::Serialize; +use serde::de::DeserializeOwned; + use crate::osdep; #[inline(always)] @@ -121,3 +124,44 @@ pub(crate) fn decrypt_http_auth_nonce(nonce: &str) -> u64 { nonce[0] } } + +/// Recursively patch a JSON object. +/// This is slightly different from a usual JSON merge. For objects in the target their fields +/// are updated by recursively calling json_patch if the same field is present in the source. +/// If the source tries to set an object to something other than another object, this is ignored. +/// Other fields are replaced. This is used for RESTful config object updates. The depth limit +/// field is to prevent stack overflows via the API. +pub(crate) fn json_patch(target: &mut serde_json::value::Value, source: &serde_json::value::Value, depth_limit: usize) { + if target.is_object() { + if source.is_object() { + let source = source.as_object().unwrap(); + for kv in target.as_object_mut().unwrap() { + let _ = source.get(kv.0).map(|new_value| { + if depth_limit > 0 { + json_patch(kv.1, new_value, depth_limit - 1) + } + }); + } + } + } else { + *target = source.clone(); + } +} + +/// Patch a serializable object with the fields present in a JSON object. +/// If there are no changes, None is returned. The depth limit is passed through to json_patch and +/// should be set to a sanity check value to prevent overflows. +pub(crate) fn json_patch_object(obj: O, patch: &str, depth_limit: usize) -> Result, serde_json::Error> { + serde_json::from_str::(patch).map_or_else(|e| Err(e), |patch| { + serde_json::value::to_value(obj.borrow()).map_or_else(|e| Err(e), |mut obj_value| { + json_patch(&mut obj_value, &patch, depth_limit); + serde_json::value::from_value::(obj_value).map_or_else(|e| Err(e), |obj_merged| { + if obj.eq(&obj_merged) { + Ok(None) + } else { + Ok(Some(obj_merged)) + } + }) + }) + }) +} diff --git a/service/src/vnic/common.rs b/service/src/vnic/common.rs index 3ab2e388b..78da34f54 100644 --- a/service/src/vnic/common.rs +++ b/service/src/vnic/common.rs @@ -12,10 +12,12 @@ /****/ use std::collections::BTreeSet; -use std::ptr::null_mut; +#[allow(unused_imports)] use zerotier_core::{MAC, MulticastGroup}; -use num_traits::cast::AsPrimitive; + +#[allow(unused_imports)] +use num_traits::AsPrimitive; use crate::osdep as osdep; @@ -25,7 +27,7 @@ pub(crate) fn get_l2_multicast_subscriptions(dev: &str) -> BTreeSet = BTreeSet::new(); let dev = dev.as_bytes(); unsafe { - let mut maddrs: *mut osdep::ifmaddrs = null_mut(); + let mut maddrs: *mut osdep::ifmaddrs = std::ptr::null_mut(); if osdep::getifmaddrs(&mut maddrs as *mut *mut osdep::ifmaddrs) == 0 { let mut i = maddrs; while !i.is_null() {