From aba212fd873bb5cad0c3311a47d88ce317d89e3d Mon Sep 17 00:00:00 2001
From: Adam Ierymenko
Date: Wed, 30 Mar 2022 15:46:17 -0400
Subject: [PATCH] Loads of syncwhole work.

---
 .nova/Configuration.json             |   3 +
 rustfmt.toml                         |   8 +
 syncwhole/rustfmt.toml               |   1 +
 syncwhole/src/datastore.rs           |  38 +-
 syncwhole/src/host.rs                |  89 ++---
 syncwhole/src/iblt.rs                | 233 +++++------
 syncwhole/src/main.rs                |  78 +++-
 syncwhole/src/node.rs                | 569 ++++++++++-----------------
 syncwhole/src/protocol.rs            | 233 +++++++----
 syncwhole/src/utils.rs               |  52 ++-
 syncwhole/src/varint.rs              |  56 +--
 zerotier-system-service/src/utils.rs |   1 -
 12 files changed, 647 insertions(+), 714 deletions(-)
 create mode 100644 .nova/Configuration.json
 create mode 100644 rustfmt.toml
 create mode 120000 syncwhole/rustfmt.toml

diff --git a/.nova/Configuration.json b/.nova/Configuration.json
new file mode 100644
index 000000000..2d92ea5ac
--- /dev/null
+++ b/.nova/Configuration.json
@@ -0,0 +1,3 @@
+{
+  "workspace.name" : "ZeroTier"
+}
diff --git a/rustfmt.toml b/rustfmt.toml
new file mode 100644
index 000000000..bf8d97024
--- /dev/null
+++ b/rustfmt.toml
@@ -0,0 +1,8 @@
+max_width = 300
+use_small_heuristics = "Max"
+tab_spaces = 4
+newline_style = "Unix"
+control_brace_style = "AlwaysSameLine"
+edition = "2021"
+imports_granularity = "Crate"
+group_imports = "StdExternalCrate"
diff --git a/syncwhole/rustfmt.toml b/syncwhole/rustfmt.toml
new file mode 120000
index 000000000..39f97b043
--- /dev/null
+++ b/syncwhole/rustfmt.toml
@@ -0,0 +1 @@
+../rustfmt.toml
\ No newline at end of file
diff --git a/syncwhole/src/datastore.rs b/syncwhole/src/datastore.rs
index 35d20686c..0890b513a 100644
--- a/syncwhole/src/datastore.rs
+++ b/syncwhole/src/datastore.rs
@@ -6,21 +6,35 @@
  * https://www.zerotier.com/
  */
 
+/// Size of keys, which is the size of a 512-bit hash. This is a protocol constant.
+pub const KEY_SIZE: usize = 64;
+
+/// Minimum possible key (all zero).
+pub const MIN_KEY: [u8; KEY_SIZE] = [0; KEY_SIZE];
+
+/// Maximum possible key (all 0xff).
+pub const MAX_KEY: [u8; KEY_SIZE] = [0xff; KEY_SIZE];
+
 /// Generate a range of SHA512 hashes from a prefix and a number of bits.
 /// The range will be inclusive and cover all keys under the prefix.
-pub fn range_from_prefix(prefix: &[u8], prefix_bits: usize) -> ([u8; 64], [u8; 64]) {
-    let prefix_bits = prefix_bits.min(prefix.len() * 8).min(64);
-    let mut start = [0_u8; 64];
-    let mut end = [0xff_u8; 64];
+pub fn range_from_prefix(prefix: &[u8], prefix_bits: usize) -> Option<([u8; KEY_SIZE], [u8; KEY_SIZE])> {
+    let mut start = [0_u8; KEY_SIZE];
+    let mut end = [0xff_u8; KEY_SIZE];
+    if prefix_bits > (KEY_SIZE * 8) {
+        return None;
+    }
     let whole_bytes = prefix_bits / 8;
     let remaining_bits = prefix_bits % 8;
+    if prefix.len() < (whole_bytes + ((remaining_bits != 0) as usize)) {
+        return None;
+    }
     start[0..whole_bytes].copy_from_slice(&prefix[0..whole_bytes]);
     end[0..whole_bytes].copy_from_slice(&prefix[0..whole_bytes]);
-    if remaining_bits != 0 && whole_bytes < prefix.len() {
+    if remaining_bits != 0 {
         start[whole_bytes] |= prefix[whole_bytes];
         end[whole_bytes] &= prefix[whole_bytes] | ((0xff_u8).wrapping_shr(remaining_bits as u32));
    }
-    (start, end)
+    return Some((start, end));
 }
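A quick sketch for reviewers of the new `range_from_prefix()` semantics; this is hypothetical and not part of the patch, and the prefix value is invented:

```rust
// A 12-bit prefix 0xABC selects the inclusive range of all 64-byte keys
// sharing those bits. The function now returns None on bad input instead
// of silently clamping.
let (start, end) = range_from_prefix(&[0xab, 0xc0], 12).unwrap();
assert_eq!(&start[..2], &[0xab, 0xc0]); // remaining bytes are 0x00
assert_eq!(&end[..2], &[0xab, 0xcf]);   // remaining bytes are 0xff
assert!(range_from_prefix(&[0xab], 12).is_none()); // prefix too short
```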
 
 /// Result returned by DataStore::load().
@@ -32,15 +46,13 @@ pub enum LoadResult<V: AsRef<[u8]> + Send> {
     NotFound,
 
     /// Supplied reference_time is outside what is available (usually too old).
-    TimeNotAvailable
+    TimeNotAvailable,
 }
 
 /// Result returned by DataStore::store().
 pub enum StoreResult {
     /// Entry was accepted.
-    /// The integer included with Ok is the reference time that should be advertised.
-    /// If this is not a temporally subjective data set then zero can be used.
-    Ok(i64),
+    Ok,
 
     /// Entry was a duplicate of one we already have but was otherwise valid.
     Duplicate,
@@ -49,7 +61,7 @@ pub enum StoreResult {
     Ignored,
 
     /// Entry was rejected as malformed or otherwise invalid (e.g. failed signature check).
-    Rejected
+    Rejected,
 }
 
 /// API to be implemented by the data set we want to replicate.
@@ -73,7 +85,7 @@ pub trait DataStore: Sync + Send {
     type LoadResultValueType: AsRef<[u8]> + Send;
 
     /// Key hash size, always 64 for SHA512.
-    const KEY_SIZE: usize = 64;
+    const KEY_SIZE: usize = KEY_SIZE;
 
     /// Maximum size of a value in bytes.
     const MAX_VALUE_SIZE: usize;
@@ -98,7 +110,7 @@ pub trait DataStore: Sync + Send {
     fn contains(&self, reference_time: i64, key: &[u8]) -> bool {
         match self.load(reference_time, key) {
             LoadResult::Ok(_) => true,
-            _ => false
+            _ => false,
         }
     }
 
diff --git a/syncwhole/src/host.rs b/syncwhole/src/host.rs
index 23c5f3309..eb5445589 100644
--- a/syncwhole/src/host.rs
+++ b/syncwhole/src/host.rs
@@ -6,70 +6,50 @@
  * https://www.zerotier.com/
  */
 
-use std::collections::HashSet;
 use std::net::SocketAddr;
 
 #[cfg(feature = "include_sha2_lib")]
 use sha2::digest::{Digest, FixedOutput};
 
+use serde::{Deserialize, Serialize};
+
 use crate::node::RemoteNodeInfo;
 
+/// Configuration settings for a syncwhole node.
+#[derive(Serialize, Deserialize, Clone, Eq, PartialEq)]
+pub struct Config {
+    /// A list of peer addresses to which we always want to stay connected.
+    /// The library will try to maintain connectivity to these regardless of connection limits.
+    pub anchors: Vec<SocketAddr>,
+
+    /// A list of peer addresses that we can try in order to achieve desired_connection_count.
+    pub seeds: Vec<SocketAddr>,
+
+    /// The maximum number of TCP connections we should allow.
+    pub max_connection_count: usize,
+
+    /// The desired number of peering links.
+    pub desired_connection_count: usize,
+
+    /// An optional name for this node to advertise to other nodes.
+    pub name: String,
+
+    /// An optional contact string for this node to advertise to other nodes.
+    /// Example: bighead@stanford.edu or https://www.piedpiper.com/
+    pub contact: String,
+}
+
 /// A trait that users of syncwhole implement to provide configuration information and listen for events.
 pub trait Host: Sync + Send {
-    /// Get a list of peer addresses to which we always want to try to stay connected.
-    ///
-    /// These are always contacted until a link is established regardless of anything else.
-    fn fixed_peers(&self) -> &[SocketAddr];
-
-    /// Get a random peer address not in the supplied set.
-    ///
-    /// The default implementation just returns None.
-    fn another_peer(&self, exclude: &HashSet<SocketAddr>) -> Option<SocketAddr> {
-        None
-    }
-
-    /// Get the maximum number of endpoints allowed.
-    ///
-    /// This is checked on incoming connect and incoming links are refused if the total is
-    /// over this count. Fixed endpoints will be contacted even if the total is over this limit.
-    ///
-    /// The default implementation returns 1024.
-    fn max_connection_count(&self) -> usize {
-        1024
-    }
-
-    /// Get the number of connections we ideally want.
-    ///
-    /// Attempts will be made to lazily contact remote endpoints if the total number of links
-    /// is under this amount. Note that fixed endpoints will still be contacted even if the
-    /// total is over the desired count.
-    ///
-    /// This should always be less than max_connection_count().
-    ///
-    /// The default implementation returns 128.
- fn desired_connection_count(&self) -> usize { - 128 - } - - /// Get an optional name that this node should advertise. - /// - /// The default implementation returns None. - fn name(&self) -> Option<&str> { - None - } - - /// Get an optional contact info string that this node should advertise. - /// - /// The default implementation returns None. - fn contact(&self) -> Option<&str> { - None - } + /// Get a copy of the current configuration for this syncwhole node. + fn node_config(&self) -> Config; /// Test whether an inbound connection should be allowed from an address. /// /// This is called on first incoming connection before any init is received. The authenticate() /// method is called once init has been received and is another decision point. The default /// implementation of this always returns true. + #[allow(unused_variables)] fn allow(&self, remote_address: &SocketAddr) -> bool { true } @@ -83,6 +63,10 @@ pub trait Host: Sync + Send { /// /// This actually gets called twice per link: once when Init is received to compute the /// response, and once when InitResponse is received to verify the response to our challenge. + /// + /// The default implementation authenticates with an all-zero key. Leave it this way if + /// you don't want authentication. + #[allow(unused_variables)] fn authenticate(&self, info: &RemoteNodeInfo, challenge: &[u8]) -> Option<[u8; 64]> { Some(Self::hmac_sha512(&[0_u8; 64], challenge)) } @@ -125,8 +109,8 @@ pub trait Host: Sync + Send { /// /// Supplied key will always be 64 bytes in length. /// - /// The default implementation is a basic HMAC implemented in terms of sha512() above. This - /// can be specialized if the user wishes to provide their own implementation. + /// The default implementation is HMAC implemented in terms of sha512() above. Specialize + /// to provide your own implementation. fn hmac_sha512(key: &[u8], msg: &[u8]) -> [u8; 64] { let mut opad = [0x5c_u8; 128]; let mut ipad = [0x36_u8; 128]; @@ -137,7 +121,6 @@ pub trait Host: Sync + Send { for i in 0..64 { ipad[i] ^= key[i]; } - let s1 = Self::sha512(&[&ipad, msg]); - Self::sha512(&[&opad, &s1]) + Self::sha512(&[&opad, &Self::sha512(&[&ipad, msg])]) } } diff --git a/syncwhole/src/iblt.rs b/syncwhole/src/iblt.rs index 1ab441b5d..f787fd5ad 100644 --- a/syncwhole/src/iblt.rs +++ b/syncwhole/src/iblt.rs @@ -6,35 +6,19 @@ * https://www.zerotier.com/ */ -use std::mem::size_of; -use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut, write_bytes}; - -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use crate::utils::*; +/// Called to get the next iteration index for each KEY_MAPPING_ITERATIONS table lookup. +/// (See IBLT papers, etc.) #[inline(always)] -fn next_iteration_index(prev_iteration_index: u64) -> u64 { - splitmix64(prev_iteration_index.wrapping_add(1)) -} - -#[derive(Clone, PartialEq, Eq)] -#[repr(C, packed)] -struct IBLTEntry { - key_sum: u64, - check_hash_sum: u64, - count: i32 -} - -impl IBLTEntry { - #[inline(always)] - fn is_singular(&self) -> bool { - if i32::from_le(self.count) == 1 || i32::from_le(self.count) == -1 { - xorshift64(self.key_sum) == self.check_hash_sum - } else { - false - } - } +fn next_iteration_index(mut x: u64) -> u64 { + x = x.wrapping_add(1); + x ^= x.wrapping_shr(30); + x = x.wrapping_mul(0xbf58476d1ce4e5b9); + x ^= x.wrapping_shr(27); + x = x.wrapping_mul(0x94d049bb133111eb); + x ^= x.wrapping_shr(31); + x } /// An Invertible Bloom Lookup Table for set reconciliation with 64-bit hashes. 
@@ -42,170 +26,151 @@ impl IBLTEntry {
 /// Usage inspired by this paper:
 ///
 /// https://dash.harvard.edu/bitstream/handle/1/14398536/GENTILI-SENIORTHESIS-2015.pdf
-#[derive(Clone, PartialEq, Eq)]
-pub struct IBLT {
-    map: Vec<IBLTEntry>
+#[repr(C, packed)]
+pub struct IBLT<const BUCKETS: usize> {
+    key: [u64; BUCKETS],
+    check_hash: [u64; BUCKETS],
+    count: [i8; BUCKETS],
 }
 
-impl IBLT {
+impl<const BUCKETS: usize> IBLT<BUCKETS> {
     /// This was determined to be effective via empirical testing with random keys. This
     /// is a protocol constant that can't be changed without upgrading all nodes in a domain.
     const KEY_MAPPING_ITERATIONS: usize = 2;
 
-    pub fn new(buckets: usize) -> Self {
-        assert!(buckets < i32::MAX as usize);
-        assert_eq!(size_of::<IBLTEntry>(), 20);
-        let mut iblt = Self { map: Vec::with_capacity(buckets) };
-        unsafe {
-            iblt.map.set_len(buckets);
-            iblt.reset();
+    /// Number of buckets in this IBLT.
+    pub const BUCKETS: usize = BUCKETS;
+
+    /// Size of this IBLT in bytes.
+    pub const SIZE_BYTES: usize = BUCKETS * (8 + 8 + 1);
+
+    #[inline(always)]
+    fn is_singular(&self, i: usize) -> bool {
+        let c = self.count[i];
+        if c == 1 || c == -1 {
+            xorshift64(self.key[i]) == self.check_hash[i]
+        } else {
+            false
         }
-        iblt
     }
 
-    /// Compute the size in bytes of an IBLT with the given number of buckets.
-    #[inline(always)]
-    pub fn size_bytes_with_buckets(buckets: usize) -> usize { buckets * size_of::<IBLTEntry>() }
+    /// Create a new zeroed IBLT.
+    pub fn new() -> Self {
+        assert_eq!(Self::SIZE_BYTES, std::mem::size_of::<Self>());
+        assert!(BUCKETS < (i32::MAX as usize));
+        unsafe { std::mem::zeroed() }
+    }
+
+    /// Cast a byte array to an IBLT if it is of the correct size.
+    pub fn ref_from_bytes(b: &[u8]) -> Option<&Self> {
+        if b.len() == Self::SIZE_BYTES {
+            Some(unsafe { &*b.as_ptr().cast() })
+        } else {
+            None
+        }
+    }
 
     /// Compute the IBLT size in buckets to reconcile a given set difference, or return 0 if no advantage.
     /// This returns zero if an IBLT would take up as much or more space than just sending local_set_size
     /// hashes of hash_size_bytes.
+    #[inline(always)]
     pub fn calc_iblt_parameters(hash_size_bytes: usize, local_set_size: u64, difference_size: u64) -> usize {
-        let hashes_would_be = (hash_size_bytes as f64) * (local_set_size as f64);
-        let buckets_should_be = (difference_size as f64) * 1.8; // factor determined experimentally for best bytes/item, can be tuned
-        let iblt_would_be = buckets_should_be * (size_of::<IBLTEntry>() as f64);
-        if iblt_would_be < hashes_would_be {
-            buckets_should_be.ceil() as usize
+        let b = (difference_size as f64) * 1.8; // factor determined experimentally for best bytes/item, can be tuned
+        if b > 64.0 && (b * (8.0 + 8.0 + 1.0)) < ((hash_size_bytes as f64) * (local_set_size as f64)) {
+            b.round() as usize
         } else {
             0
         }
     }
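For a sense of scale (numbers invented, not from the patch): each bucket costs 8 + 8 + 1 = 17 bytes, so with the experimentally chosen 1.8 buckets per differing record an IBLT beats a plain hash list whenever the difference is small relative to the set:

```rust
// 50,000 local keys of 64 bytes vs. an estimated difference of 1,000 records:
//   IBLT:      1,000 * 1.8 = 1,800 buckets * 17 bytes ≈ 30 KiB
//   hash list: 50,000 * 64 bytes ≈ 3.2 MB, so the IBLT wins easily.
// (The BUCKETS const parameter is irrelevant to this associated function.)
let buckets = IBLT::<64>::calc_iblt_parameters(64, 50_000, 1_000);
assert_eq!(buckets, 1800);
```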
 
-    /// Get the size of this IBLT in buckets.
-    #[inline(always)]
-    pub fn buckets(&self) -> usize { self.map.len() }
-
-    /// Get the size of this IBLT in bytes.
-    pub fn size_bytes(&self) -> usize { self.map.len() * size_of::<IBLTEntry>() }
-
     /// Zero this IBLT.
     #[inline(always)]
     pub fn reset(&mut self) {
-        unsafe { write_bytes(self.map.as_mut_ptr().cast::<u8>(), 0, self.map.len() * size_of::<IBLTEntry>()); }
+        unsafe {
+            std::ptr::write_bytes((self as *mut Self).cast::<u8>(), 0, std::mem::size_of::<Self>());
+        }
     }
 
     /// Get this IBLT as a byte slice in place.
     #[inline(always)]
     pub fn as_bytes(&self) -> &[u8] {
-        unsafe { &*slice_from_raw_parts(self.map.as_ptr().cast::<u8>(), self.map.len() * size_of::<IBLTEntry>()) }
+        unsafe { &*std::ptr::slice_from_raw_parts((self as *const Self).cast::<u8>(), std::mem::size_of::<Self>()) }
     }
 
-    /// Construct an IBLT from an input reader and a size in bytes.
-    pub async fn new_from_reader<R: AsyncReadExt + Unpin>(r: &mut R, bytes: usize) -> std::io::Result<Self> {
-        assert_eq!(size_of::<IBLTEntry>(), 20);
-        if (bytes % size_of::<IBLTEntry>()) != 0 {
-            Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "incomplete or invalid IBLT"))
-        } else {
-            let buckets = bytes / size_of::<IBLTEntry>();
-            let mut iblt = Self { map: Vec::with_capacity(buckets) };
-            unsafe {
-                iblt.map.set_len(buckets);
-                r.read_exact(&mut *slice_from_raw_parts_mut(iblt.map.as_mut_ptr().cast::<u8>(), bytes)).await?;
-            }
-            Ok(iblt)
-        }
-    }
-
-    /// Write this IBLT to a stream.
-    /// Note that the size of the IBLT in bytes must be stored separately. Use size_bytes() to get that.
-    #[inline(always)]
-    pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
-        w.write_all(self.as_bytes()).await
-    }
-
-    fn ins_rem(&mut self, mut key: u64, delta: i32) {
-        key = splitmix64(key);
+    fn ins_rem(&mut self, key: u64, delta: i8) {
         let check_hash = xorshift64(key);
         let mut iteration_index = u64::from_le(key);
-        let buckets = self.map.len();
         for _ in 0..Self::KEY_MAPPING_ITERATIONS {
             iteration_index = next_iteration_index(iteration_index);
-            let b = unsafe { self.map.get_unchecked_mut((iteration_index as usize) % buckets) };
-            b.key_sum ^= key;
-            b.check_hash_sum ^= check_hash;
-            b.count = (i32::from_le(b.count) + delta).to_le();
+            let i = (iteration_index as usize) % BUCKETS;
+            self.key[i] ^= key;
+            self.check_hash[i] ^= check_hash;
+            self.count[i] = self.count[i].wrapping_add(delta);
         }
     }
 
     /// Insert a 64-bit key.
-    /// Panics if the key is shorter than 64 bits. If longer, bits beyond 64 are ignored.
     #[inline(always)]
-    pub fn insert(&mut self, key: &[u8]) {
-        assert!(key.len() >= 8);
-        self.ins_rem(unsafe { u64::from_ne_bytes(*(key.as_ptr().cast::<[u8; 8]>())) }, 1);
+    pub fn insert(&mut self, key: u64) {
+        self.ins_rem(key, 1);
     }
 
     /// Remove a 64-bit key.
-    /// Panics if the key is shorter than 64 bits. If longer, bits beyond 64 are ignored.
     #[inline(always)]
-    pub fn remove(&mut self, key: &[u8]) {
-        assert!(key.len() >= 8);
-        self.ins_rem(unsafe { u64::from_ne_bytes(*(key.as_ptr().cast::<[u8; 8]>())) }, -1);
+    pub fn remove(&mut self, key: u64) {
+        self.ins_rem(key, -1);
     }
 
     /// Subtract another IBLT from this one to get a set difference.
-    ///
-    /// This returns true on success or false on error, which right now can only happen if the
-    /// other IBLT has a different number of buckets or if it contains so many entries
-    pub fn subtract(&mut self, other: &Self) -> bool {
-        if other.map.len() == self.map.len() {
-            for (s, o) in self.map.iter_mut().zip(other.map.iter()) {
-                s.key_sum ^= o.key_sum;
-                s.check_hash_sum ^= o.check_hash_sum;
-                s.count = (i32::from_le(s.count) - i32::from_le(o.count)).to_le();
-            }
-            return true;
+    pub fn subtract(&mut self, other: &Self) {
+        for i in 0..BUCKETS {
+            self.key[i] ^= other.key[i];
+        }
+        for i in 0..BUCKETS {
+            self.check_hash[i] ^= other.check_hash[i];
+        }
+        for i in 0..BUCKETS {
+            self.count[i] = self.count[i].wrapping_sub(other.count[i]);
         }
-        return false;
     }
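The intended set-reconciliation flow, as a hypothetical sketch (bucket count and key values invented; `splitmix64` is from this crate's utils module):

```rust
let mut a = IBLT::<1024>::new();
let mut b = IBLT::<1024>::new();
for k in 0..1000u64 { a.insert(splitmix64(k)); } // node A holds keys 0..1000
for k in 5..1000u64 { b.insert(splitmix64(k)); } // node B is missing five of them
a.subtract(&b);                                  // in practice B's sketch arrives over the wire
a.list(|k| println!("B is missing {:#018x}", k)); // extracts the five differing keys
```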
 
     /// List as many entries in this IBLT as can be extracted.
-    pub fn list<F: FnMut(&[u8; 8])>(mut self, mut f: F) {
-        let mut queue: Vec<u32> = Vec::with_capacity(self.map.len());
+    pub fn list<F: FnMut(u64)>(mut self, mut f: F) {
+        let mut queue: Vec<u32> = Vec::with_capacity(BUCKETS);
 
-        for bi in 0..self.map.len() {
-            if unsafe { self.map.get_unchecked(bi).is_singular() } {
-                queue.push(bi as u32);
+        for i in 0..BUCKETS {
+            if self.is_singular(i) {
+                queue.push(i as u32);
             }
         }
 
         loop {
-            let b = queue.pop();
-            let b = if b.is_some() {
-                unsafe { self.map.get_unchecked_mut(b.unwrap() as usize) }
+            let i = queue.pop();
+            let i = if i.is_some() {
+                i.unwrap() as usize
             } else {
                 break;
             };
-            if b.is_singular() {
-                let key = b.key_sum;
+
+            if self.is_singular(i) {
+                let key = self.key[i];
+
+                f(key);
+
                 let check_hash = xorshift64(key);
                 let mut iteration_index = u64::from_le(key);
-
-                f(&(splitmix64_inverse(key)).to_ne_bytes());
-
                 for _ in 0..Self::KEY_MAPPING_ITERATIONS {
                     iteration_index = next_iteration_index(iteration_index);
-                    let b_idx = iteration_index % (self.map.len() as u64);
-                    let b = unsafe { self.map.get_unchecked_mut(b_idx as usize) };
-                    b.key_sum ^= key;
-                    b.check_hash_sum ^= check_hash;
-                    b.count = (i32::from_le(b.count) - 1).to_le();
-
-                    if b.is_singular() {
-                        if queue.len() > self.map.len() { // sanity check for invalid IBLT
+                    let i = (iteration_index as usize) % BUCKETS;
+                    self.key[i] ^= key;
+                    self.check_hash[i] ^= check_hash;
+                    self.count[i] = self.count[i].wrapping_sub(1);
+                    if self.is_singular(i) {
+                        if queue.len() > BUCKETS {
+                            // sanity check, should be impossible
                             return;
                         }
-                        queue.push(b_idx as u32);
+                        queue.push(i as u32);
                     }
                 }
             }
         }
     }
 
@@ -234,19 +199,18 @@ mod tests {
         let mut count = 64;
         const CAPACITY: usize = 4096;
         while count <= CAPACITY {
-            let mut test: IBLT = IBLT::new(CAPACITY);
+            let mut test = IBLT::<CAPACITY>::new();
             expected.clear();
 
             for _ in 0..count {
                 let x = rn;
                 rn = splitmix64(rn);
                 expected.insert(x);
-                test.insert(&x.to_ne_bytes());
+                test.insert(x);
             }
 
             let mut list_count = 0;
             test.list(|x| {
-                let x = u64::from_ne_bytes(*x);
                 list_count += 1;
                 assert!(expected.contains(&x));
             });
@@ -266,16 +230,16 @@ mod tests {
         let mut missing: HashSet<u64> = HashSet::with_capacity(CAPACITY);
         while missing_count <= CAPACITY {
             missing.clear();
-            let mut local: IBLT = IBLT::new(CAPACITY);
-            let mut remote: IBLT = IBLT::new(CAPACITY);
+            let mut local = IBLT::<CAPACITY>::new();
+            let mut remote = IBLT::<CAPACITY>::new();
 
             for k in 0..REMOTE_SIZE {
                 if k >= missing_count {
-                    local.insert(&rn.to_ne_bytes());
+                    local.insert(rn);
                 } else {
                     missing.insert(rn);
                 }
-                remote.insert(&rn.to_ne_bytes());
+                remote.insert(rn);
                 rn = splitmix64(rn);
             }
 
@@ -283,7 +247,6 @@ mod tests {
             let bytes = local.as_bytes().len();
             let mut cnt = 0;
             local.list(|k| {
-                let k = u64::from_ne_bytes(*k);
                 assert!(missing.contains(&k));
                 cnt += 1;
             });
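Also worth noting for reviewers (hypothetical sketch, not in the patch): because the new IBLT is a flat `#[repr(C, packed)]` struct, it crosses the wire as raw bytes with no copy on receive:

```rust
let local = IBLT::<128>::new();
let wire: &[u8] = local.as_bytes();                    // 128 * 17 = 2176 bytes
let view = IBLT::<128>::ref_from_bytes(wire).unwrap(); // zero-copy view
assert!(IBLT::<64>::ref_from_bytes(wire).is_none());   // wrong size is rejected
```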
diff --git a/syncwhole/src/main.rs b/syncwhole/src/main.rs
index bcea03ad7..e8325a1e9 100644
--- a/syncwhole/src/main.rs
+++ b/syncwhole/src/main.rs
@@ -1,22 +1,40 @@
 extern crate core;
 
 use std::collections::BTreeMap;
+use std::io::{stdout, Write};
 use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
 use std::ops::Bound::Included;
 use std::sync::{Arc, Mutex};
-use std::time::{Duration, SystemTime};
+use std::time::{Duration, Instant, SystemTime};
 
 use sha2::digest::Digest;
 use sha2::Sha512;
 
 use syncwhole::datastore::{DataStore, LoadResult, StoreResult};
 use syncwhole::host::Host;
-use syncwhole::ms_since_epoch;
 use syncwhole::node::{Node, RemoteNodeInfo};
 use syncwhole::utils::*;
 
-const TEST_NODE_COUNT: usize = 16;
+const TEST_NODE_COUNT: usize = 8;
 const TEST_PORT_RANGE_START: u16 = 21384;
+const TEST_STARTING_RECORDS_PER_NODE: usize = 16;
+
+static mut RANDOM_CTR: u128 = 0;
+
+fn get_random_bytes(mut buf: &mut [u8]) {
+    // This is only for testing and is not really secure.
+    let mut ctr = unsafe { RANDOM_CTR };
+    if ctr == 0 {
+        ctr = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos() * (1 + Instant::now().elapsed().as_nanos());
+    }
+    while !buf.is_empty() {
+        let l = buf.len().min(64);
+        ctr = ctr.wrapping_add(1);
+        buf[..l].copy_from_slice(&Sha512::digest(&ctr.to_ne_bytes()).as_slice()[..l]);
+        buf = &mut buf[l..];
+    }
+    unsafe { RANDOM_CTR = ctr };
+}
 
 struct TestNodeHost {
     name: String,
@@ -34,22 +52,16 @@ impl Host for TestNodeHost {
     }
 
     fn on_connect(&self, info: &RemoteNodeInfo) {
-        println!("{:5}: connected to {} ({}, {})", self.name, info.remote_address.to_string(), info.node_name.as_ref().map_or("null", |s| s.as_str()), if info.inbound { "inbound" } else { "outbound" });
+        //println!("{:5}: connected to {} ({}, {})", self.name, info.remote_address.to_string(), info.node_name.as_ref().map_or("null", |s| s.as_str()), if info.inbound { "inbound" } else { "outbound" });
     }
 
     fn on_connection_closed(&self, info: &RemoteNodeInfo, reason: String) {
         println!("{:5}: closed connection to {}: {} ({}, {})", self.name, info.remote_address.to_string(), reason, if info.inbound { "inbound" } else { "outbound" }, if info.initialized { "initialized" } else { "not initialized" });
     }
 
-    fn get_secure_random(&self, mut buf: &mut [u8]) {
+    fn get_secure_random(&self, buf: &mut [u8]) {
         // This is only for testing and is not really secure.
-        let mut ctr = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos();
-        while !buf.is_empty() {
-            let l = buf.len().min(64);
-            ctr = ctr.wrapping_add(1);
-            buf[0..l].copy_from_slice(&Self::sha512(&[&ctr.to_ne_bytes()])[0..l]);
-            buf = &mut buf[l..];
-        }
+        get_random_bytes(buf);
     }
 }
 
@@ -82,13 +94,17 @@ impl DataStore for TestNodeHost {
     }
 
     fn count(&self, _: i64, key_range_start: &[u8], key_range_end: &[u8]) -> u64 {
-        self.db.lock().unwrap().range((Included(key_range_start.try_into().unwrap()), Included(key_range_end.try_into().unwrap()))).count() as u64
+        let s: [u8; 64] = key_range_start.try_into().unwrap();
+        let e: [u8; 64] = key_range_end.try_into().unwrap();
+        self.db.lock().unwrap().range((Included(s), Included(e))).count() as u64
     }
 
     fn total_count(&self) -> u64 { self.db.lock().unwrap().len() as u64 }
 
     fn for_each<F: FnMut(&[u8; 64], &[u8]) -> bool>(&self, _: i64, key_range_start: &[u8], key_range_end: &[u8], mut f: F) {
-        for (k, v) in self.db.lock().unwrap().range((Included(key_range_start.try_into().unwrap()), Included(key_range_end.try_into().unwrap()))) {
+        let s: [u8; 64] = key_range_start.try_into().unwrap();
+        let e: [u8; 64] = key_range_end.try_into().unwrap();
+        for (k, v) in self.db.lock().unwrap().range((Included(s), Included(e))) {
             if !f(k, v.as_ref()) {
                 break;
             }
@@ -101,6 +117,7 @@ fn main() {
     println!("Running syncwhole local self-test network with {} nodes starting at 127.0.0.1:{}", TEST_NODE_COUNT, TEST_PORT_RANGE_START);
     println!();
 
+    println!("Starting nodes on 127.0.0.1...");
     let mut nodes: Vec<Node<TestNodeHost, TestNodeHost>> = Vec::with_capacity(TEST_NODE_COUNT);
     for port in TEST_PORT_RANGE_START..(TEST_PORT_RANGE_START + (TEST_NODE_COUNT as u16)) {
         let mut peers: Vec<SocketAddr> = Vec::with_capacity(TEST_NODE_COUNT);
@@ -114,20 +131,43 @@ fn main() {
             peers,
             db: Mutex::new(BTreeMap::new())
         });
-        println!("Starting node on 127.0.0.1:{} with {} records in data store...", port, nh.db.lock().unwrap().len());
+        //println!("Starting node on 127.0.0.1:{}...", port, nh.db.lock().unwrap().len());
         nodes.push(Node::new(nh.clone(), nh.clone(), SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))).await.unwrap());
     }
 
-    println!();
+    print!("Waiting for all connections to be established...");
+    let _ = stdout().flush();
     loop {
         tokio::time::sleep(Duration::from_secs(1)).await;
-        /*
         let mut count = 0;
         for n in nodes.iter() {
             count += n.connection_count().await;
         }
-        println!("{}", count);
-        */
+        if count == (TEST_NODE_COUNT * (TEST_NODE_COUNT - 1)) {
+            println!(" {} connections up.", count);
+            break;
+        } else {
+            print!(".");
+            let _ = stdout().flush();
+        }
+    }
+
+    println!("Populating maps with data to be synchronized between nodes...");
+    let mut all_records = BTreeMap::new();
+    for n in nodes.iter_mut() {
+        for _ in 0..TEST_STARTING_RECORDS_PER_NODE {
+            let mut k = [0_u8; 64];
+            let mut v = [0_u8; 32];
+            get_random_bytes(&mut k);
+            get_random_bytes(&mut v);
+            let v: Arc<[u8]> = Arc::from(v);
+            all_records.insert(k.clone(), v.clone());
+            n.datastore().db.lock().unwrap().insert(k, v);
+        }
+    }
+
+    loop {
+        tokio::time::sleep(Duration::from_secs(1)).await;
     }
     });
 }
diff --git a/syncwhole/src/node.rs b/syncwhole/src/node.rs
index 0531a94f5..a42894fd4 100644
--- a/syncwhole/src/node.rs
+++ b/syncwhole/src/node.rs
@@ -11,13 +11,13 @@ use std::io::IoSlice;
 use std::mem::MaybeUninit;
 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
 use std::ops::Add;
-use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering};
+use std::sync::Arc;
 
 use serde::{Deserialize, Serialize};
-use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
-use tokio::net::{TcpListener, TcpSocket, TcpStream};
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
+use tokio::net::{TcpListener, TcpSocket, TcpStream};
 use tokio::sync::Mutex;
 use tokio::task::JoinHandle;
 use tokio::time::{Duration, Instant};
@@ -30,25 +30,19 @@ use crate::utils::*;
 use crate::varint;
 
 /// Inactivity timeout for connections in milliseconds.
-const CONNECTION_TIMEOUT: i64 = 120000;
-
-/// How often to send STATUS messages in milliseconds.
-const STATUS_INTERVAL: i64 = 10000;
+const CONNECTION_TIMEOUT: i64 = SYNC_STATUS_PERIOD * 4;
 
 /// How often to run the housekeeping task's loop in milliseconds.
-const HOUSEKEEPING_INTERVAL: i64 = STATUS_INTERVAL / 2;
-
-/// Size of read buffer, which is used to reduce the number of syscalls.
-const READ_BUFFER_SIZE: usize = 16384;
+const HOUSEKEEPING_INTERVAL: i64 = SYNC_STATUS_PERIOD;
 
 /// Information about a remote node to which we are connected.
 #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct RemoteNodeInfo {
     /// Optional name advertised by remote node (arbitrary).
-    pub node_name: Option<String>,
+    pub name: String,
 
     /// Optional contact information advertised by remote node (arbitrary).
-    pub node_contact: Option<String>,
+    pub contact: String,
 
     /// Actual remote endpoint address.
     pub remote_address: SocketAddr,
@@ -78,6 +72,10 @@ fn configure_tcp_socket(socket: &TcpSocket) -> std::io::Result<()> {
     }
 }
 
+fn decode_msgpack<'a, T: Deserialize<'a>>(b: &'a [u8]) -> std::io::Result<T> {
+    rmp_serde::from_slice(b).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("invalid msgpack object: {}", e.to_string())))
+}
+
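The new `decode_msgpack()` helper pairs with `send_obj()` further down; a hypothetical round trip (field values invented) looks like:

```rust
let body = rmp_serde::encode::to_vec_named(&msg::SyncStatus { record_count: 42, clock: 0 }).unwrap();
let status: msg::SyncStatus = decode_msgpack(&body).unwrap();
assert_eq!(status.record_count, 42);
```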
 /// An instance of the syncwhole data set synchronization engine.
 ///
 /// This holds a number of async tasks that are terminated or aborted if this object
@@ -85,7 +83,7 @@ fn configure_tcp_socket(socket: &TcpSocket) -> std::io::Result<()> {
 pub struct Node<D: DataStore + 'static, H: Host + 'static> {
     internal: Arc<NodeInternal<D, H>>,
     housekeeping_task: JoinHandle<()>,
-    listener_task: JoinHandle<()>
+    listener_task: JoinHandle<()>,
 }
 
 impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
@@ -105,18 +103,21 @@
             host: host.clone(),
             connections: Mutex::new(HashMap::with_capacity(64)),
             bind_address,
+            starting_instant: Instant::now(),
         });
 
-        Ok(Self {
-            internal: internal.clone(),
-            housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()),
-            listener_task: tokio::spawn(internal.listener_task_main(listener)),
-        })
+        Ok(Self { internal: internal.clone(), housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()), listener_task: tokio::spawn(internal.listener_task_main(listener)) })
     }
 
-    pub fn datastore(&self) -> &Arc<D> { &self.internal.datastore }
+    #[inline(always)]
+    pub fn datastore(&self) -> &Arc<D> {
+        &self.internal.datastore
+    }
 
-    pub fn host(&self) -> &Arc<H> { &self.internal.host }
+    #[inline(always)]
+    pub fn host(&self) -> &Arc<H> {
+        &self.internal.host
+    }
 
     pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
         self.internal.clone().connect(endpoint, Instant::now().add(Duration::from_millis(CONNECTION_TIMEOUT as u64))).await
@@ -126,7 +127,7 @@
         let connections = self.internal.connections.lock().await;
         let mut cl: Vec<RemoteNodeInfo> = Vec::with_capacity(connections.len());
         for (_, c) in connections.iter() {
-            cl.push(c.0.info.blocking_lock().clone());
+            cl.push(c.info.lock().await.clone());
         }
         cl
     }
@@ -152,16 +153,22 @@ pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
     host: Arc<H>,
 
     // Connections and their task join handles, by remote endpoint address.
-    connections: Mutex<HashMap<SocketAddr, (Arc<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
+    connections: Mutex<HashMap<SocketAddr, Arc<Connection>>>,
 
     // Local address to which this node is bound
     bind_address: SocketAddr,
+
+    // Instant this node started.
+    starting_instant: Instant,
 }
 
 impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
+    fn ms_monotonic(&self) -> i64 {
+        Instant::now().duration_since(self.starting_instant).as_millis() as i64
+    }
+
     /// Loop that constantly runs in the background to do cleanup and service things.
     async fn housekeeping_task_main(self: Arc<Self>) {
-        let mut last_status_sent = ms_monotonic();
         let mut tasks: Vec<JoinHandle<()>> = Vec::new();
         let mut connected_to_addresses: HashSet<SocketAddr> = HashSet::new();
         let mut sleep_until = Instant::now().add(Duration::from_millis(500));
@@ -171,50 +178,17 @@
             tasks.clear();
             connected_to_addresses.clear();
 
-            let now = ms_monotonic();
+            let now = self.ms_monotonic();
 
-            // Check connection timeouts, send status updates, and garbage collect from the connections map.
-            // Status message outputs are backgrounded since these can block if TCP buffers are nearly full.
-            // A timeout based on the service loop interval is used. Usually these sends will finish instantly
-            // but if they take too long this typically means the link is dead. We wait for all tasks at the
-            // end of the service loop. The on_connection_closed() method in 'host' is called in sub-tasks to
-            // prevent the possibility of deadlocks on self.connections.lock() if the Host implementation calls
-            // something that tries to lock it.
- let status = if (now - last_status_sent) >= STATUS_INTERVAL { - last_status_sent = now; - Some(msg::Status { - total_record_count: self.datastore.total_count(), - reference_time: self.datastore.clock() - }) - } else { - None - }; self.connections.lock().await.retain(|sa, c| { - let cc = &(c.0); - if !cc.closed.load(Ordering::Relaxed) { - if (now - cc.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT { + if !c.closed.load(Ordering::Relaxed) { + if (now - c.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT { connected_to_addresses.insert(sa.clone()); - let _ = status.as_ref().map(|status| { - let status = status.clone(); - let self2 = self.clone(); - let cc = cc.clone(); - let sa = sa.clone(); - tasks.push(tokio::spawn(async move { - if cc.info.lock().await.initialized { - // This almost always completes instantly due to queues, but add a timeout in case connection - // is stalled. In this case the result is a closed connection. - if !tokio::time::timeout_at(sleep_until, cc.send_obj(MESSAGE_TYPE_STATUS, &status, now)).await.map_or(false, |r| r.is_ok()) { - let _ = self2.connections.lock().await.remove(&sa).map(|c| c.1.map(|j| j.abort())); - self2.host.on_connection_closed(&*cc.info.lock().await, "write overflow (timeout)".to_string()); - } - } - })); - }); true // keep connection } else { - let _ = c.1.take().map(|j| j.abort()); + let _ = c.read_task.lock().unwrap().take().map(|j| j.abort()); let host = self.host.clone(); - let cc = cc.clone(); + let cc = c.clone(); tasks.push(tokio::spawn(async move { host.on_connection_closed(&*cc.info.lock().await, "timeout".to_string()); })); @@ -222,8 +196,8 @@ impl NodeInternal { } } else { let host = self.host.clone(); - let cc = cc.clone(); - let j = c.1.take(); + let cc = c.clone(); + let j = c.read_task.lock().unwrap().take(); tasks.push(tokio::spawn(async move { if j.is_some() { let e = j.unwrap().await; @@ -241,9 +215,10 @@ impl NodeInternal { } }); - // Always try to connect to fixed peers. - let fixed_peers = self.host.fixed_peers(); - for sa in fixed_peers.iter() { + let config = self.host.node_config(); + + // Always try to connect to anchor peers. + for sa in config.anchors.iter() { if !connected_to_addresses.contains(sa) { let sa = sa.clone(); let self2 = self.clone(); @@ -255,18 +230,18 @@ impl NodeInternal { } // Try to connect to more peers until desired connection count is reached. 
- let desired_connection_count = self.host.desired_connection_count().min(self.host.max_connection_count()); - while connected_to_addresses.len() < desired_connection_count { - let sa = self.host.another_peer(&connected_to_addresses); - if sa.is_some() { - let sa = sa.unwrap(); + let desired_connection_count = config.desired_connection_count.min(config.max_connection_count); + for sa in config.seeds.iter() { + if connected_to_addresses.len() >= desired_connection_count { + break; + } + if !connected_to_addresses.contains(sa) { + connected_to_addresses.insert(sa.clone()); let self2 = self.clone(); + let sa = sa.clone(); tasks.push(tokio::spawn(async move { let _ = self2.connect(&sa, sleep_until).await; })); - connected_to_addresses.insert(sa.clone()); - } else { - break; } } @@ -289,7 +264,8 @@ impl NodeInternal { if socket.is_ok() { let (stream, address) = socket.unwrap(); if self.host.allow(&address) { - if self.connections.lock().await.len() < self.host.max_connection_count() || self.host.fixed_peers().contains(&address) { + let config = self.host.node_config(); + if self.connections.lock().await.len() < config.max_connection_count || config.anchors.contains(&address) { Self::connection_start(&self, address, stream, true).await; } } @@ -316,33 +292,27 @@ impl NodeInternal { let mut ok = false; let _ = self.connections.lock().await.entry(address.clone()).or_insert_with(|| { ok = true; - let _ = stream.set_nodelay(true); + //let _ = stream.set_nodelay(true); let (reader, writer) = stream.into_split(); - let now = ms_monotonic(); + let now = self.ms_monotonic(); let connection = Arc::new(Connection { writer: Mutex::new(writer), last_send_time: AtomicI64::new(now), last_receive_time: AtomicI64::new(now), bytes_sent: AtomicU64::new(0), bytes_received: AtomicU64::new(0), - info: Mutex::new(RemoteNodeInfo { - node_name: None, - node_contact: None, - remote_address: address.clone(), - explicit_addresses: Vec::new(), - connect_time: ms_since_epoch(), - connect_instant: ms_monotonic(), - inbound, - initialized: false - }), - closed: AtomicBool::new(false) + info: Mutex::new(RemoteNodeInfo { name: String::new(), contact: String::new(), remote_address: address.clone(), explicit_addresses: Vec::new(), connect_time: ms_since_epoch(), connect_instant: now, inbound, initialized: false }), + read_task: std::sync::Mutex::new(None), + closed: AtomicBool::new(false), }); let self2 = self.clone(); - (connection.clone(), Some(tokio::spawn(async move { - let result = self2.connection_io_task_main(&connection, BufReader::with_capacity(READ_BUFFER_SIZE, reader)).await; - connection.closed.store(true, Ordering::Relaxed); + let c2 = connection.clone(); + connection.read_task.lock().unwrap().replace(tokio::spawn(async move { + let result = self2.connection_io_task_main(&c2, reader).await; + c2.closed.store(true, Ordering::Relaxed); result - }))) + })); + connection }); ok } @@ -351,53 +321,96 @@ impl NodeInternal { /// /// This handles reading from the connection and reacting to what it sends. Killing this /// task is done when the connection is closed. 
-    async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: BufReader<OwnedReadHalf>) -> std::io::Result<()> {
+    async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: OwnedReadHalf) -> std::io::Result<()> {
+        const BUF_CHUNK_SIZE: usize = 4096;
+        const READ_BUF_INITIAL_SIZE: usize = 65536; // should be a multiple of BUF_CHUNK_SIZE
+
+        let background_tasks = AsyncTaskReaper::new();
+        let mut write_buffer: Vec<u8> = Vec::with_capacity(BUF_CHUNK_SIZE);
+        let mut read_buffer: Vec<u8> = Vec::new();
+        read_buffer.resize(READ_BUF_INITIAL_SIZE, 0);
+
+        let config = self.host.node_config();
+
         let mut anti_loopback_challenge_sent = [0_u8; 64];
         let mut domain_challenge_sent = [0_u8; 64];
         let mut auth_challenge_sent = [0_u8; 64];
         self.host.get_secure_random(&mut anti_loopback_challenge_sent);
         self.host.get_secure_random(&mut domain_challenge_sent);
         self.host.get_secure_random(&mut auth_challenge_sent);
-        connection.send_obj(MESSAGE_TYPE_INIT, &msg::Init {
-            anti_loopback_challenge: &anti_loopback_challenge_sent,
-            domain_challenge: &domain_challenge_sent,
-            auth_challenge: &auth_challenge_sent,
-            node_name: self.host.name().map(|n| n.to_string()),
-            node_contact: self.host.contact().map(|c| c.to_string()),
-            locally_bound_port: self.bind_address.port(),
-            explicit_ipv4: None,
-            explicit_ipv6: None
-        }, ms_monotonic()).await?;
+        connection
+            .send_obj(
+                &mut write_buffer,
+                MessageType::Init,
+                &msg::Init {
+                    anti_loopback_challenge: &anti_loopback_challenge_sent,
+                    domain_challenge: &domain_challenge_sent,
+                    auth_challenge: &auth_challenge_sent,
+                    node_name: config.name.as_str(),
+                    node_contact: config.contact.as_str(),
+                    locally_bound_port: self.bind_address.port(),
+                    explicit_ipv4: None,
+                    explicit_ipv6: None,
+                },
+                self.ms_monotonic(),
+            )
+            .await?;
+        drop(config);
+
+        let max_message_size = ((D::MAX_VALUE_SIZE * 8) + (D::KEY_SIZE * 1024) + 65536) as u64; // sanity limit
 
         let mut initialized = false;
-        let background_tasks = AsyncTaskReaper::new();
         let mut init_received = false;
-        let mut buf: Vec<u8> = Vec::new();
-        buf.resize(4096, 0);
+        let mut buffer_fill = 0_usize;
         loop {
-            let message_type = reader.read_u8().await?;
-            let message_size = varint::read_async(&mut reader).await?;
-            let header_size = 1 + message_size.1;
-            let message_size = message_size.0;
-            if message_size > (D::MAX_VALUE_SIZE + ((D::KEY_SIZE + 10) * 256) + 65536) as u64 {
-                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
-            }
+            let message_type: MessageType;
+            let message_size: usize;
+            let header_size: usize;
+            let total_size: usize;
+            loop {
+                buffer_fill += reader.read(&mut read_buffer.as_mut_slice()[buffer_fill..]).await?;
+                if buffer_fill >= 2 {
+                    // type and at least one byte of varint
+                    let ms = varint::decode(&read_buffer.as_slice()[1..]);
+                    if ms.1 > 0 {
+                        // varint is all there and parsed correctly
+                        if ms.0 > max_message_size {
+                            return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
+                        }
 
-            let now = ms_monotonic();
+                        message_type = MessageType::from(*read_buffer.get(0).unwrap());
+                        message_size = ms.0 as usize;
+                        header_size = 1 + ms.1;
+                        total_size = header_size + message_size;
+
+                        if read_buffer.len() < total_size {
+                            read_buffer.resize(((total_size / BUF_CHUNK_SIZE) + 1) * BUF_CHUNK_SIZE, 0);
+                        }
+                        while buffer_fill < total_size {
+                            buffer_fill += reader.read(&mut read_buffer.as_mut_slice()[buffer_fill..]).await?;
+                        }
+
+                        break;
+                    }
+                }
+            }
+            let message = &read_buffer.as_slice()[header_size..total_size];
+
+            let now = self.ms_monotonic();
connection.last_receive_time.store(now, Ordering::Relaxed); match message_type { + MessageType::Nop => {} - MESSAGE_TYPE_INIT => { + MessageType::Init => { if init_received { return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init")); } + init_received = true; - let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?; + let msg: msg::Init = decode_msgpack(message)?; let (anti_loopback_response, domain_challenge_response, auth_challenge_response) = { let mut info = connection.info.lock().await; - info.node_name = msg.node_name.clone(); - info.node_contact = msg.node_contact.clone(); + info.name = msg.node_name.to_string(); + info.contact = msg.node_contact.to_string(); let _ = msg.explicit_ipv4.map(|pv4| { info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port))); }); @@ -409,260 +422,121 @@ impl NodeInternal { if auth_challenge_response.is_none() { return Err(std::io::Error::new(std::io::ErrorKind::Other, "authenticate() returned None, connection dropped")); } - ( - H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge), - H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), msg.domain_challenge), - auth_challenge_response.unwrap() - ) + + (H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge), H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), msg.domain_challenge), auth_challenge_response.unwrap()) }; - connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse { - anti_loopback_response: &anti_loopback_response, - domain_response: &domain_challenge_response, - auth_response: &auth_challenge_response - }, now).await?; + connection.send_obj(&mut write_buffer, MessageType::InitResponse, &msg::InitResponse { anti_loopback_response: &anti_loopback_response, domain_response: &domain_challenge_response, auth_response: &auth_challenge_response }, now).await?; + } - init_received = true; - }, - - MESSAGE_TYPE_INIT_RESPONSE => { - let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?; + MessageType::InitResponse => { + let msg: msg::InitResponse = decode_msgpack(message)?; let mut info = connection.info.lock().await; if info.initialized { return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init response")); } - info.initialized = true; - let info = info.clone(); if msg.anti_loopback_response.eq(&H::hmac_sha512(&self.anti_loopback_secret, &anti_loopback_challenge_sent)) { return Err(std::io::Error::new(std::io::ErrorKind::Other, "rejected connection to self")); } - if msg.domain_response.eq(&H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), &domain_challenge_sent)) { + if !msg.domain_response.eq(&H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), &domain_challenge_sent)) { return Err(std::io::Error::new(std::io::ErrorKind::Other, "domain mismatch")); } if !self.host.authenticate(&info, &auth_challenge_sent).map_or(false, |cr| msg.auth_response.eq(&cr)) { return Err(std::io::Error::new(std::io::ErrorKind::Other, "challenge/response authentication failed")); } - self.host.on_connect(&info); + info.initialized = true; initialized = true; - }, + let info = info.clone(); + self.host.on_connect(&info); + } + + // Handle messages other than INIT and INIT_RESPONSE after checking 'initialized' flag. 
_ => { if !initialized { return Err(std::io::Error::new(std::io::ErrorKind::Other, "init exchange must be completed before other messages are sent")); } match message_type { + _ => {} - MESSAGE_TYPE_STATUS => { - let msg: msg::Status = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?; - self.connection_request_summary(connection, msg.total_record_count, now, msg.reference_time).await?; - }, + MessageType::HaveRecords => { + let msg: msg::HaveRecords = decode_msgpack(message)?; + } - MESSAGE_TYPE_GET_SUMMARY => { - //let msg: msg::GetSummary = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?; - }, + MessageType::GetRecords => { + let msg: msg::GetRecords = decode_msgpack(message)?; + } - MESSAGE_TYPE_SUMMARY => { - let mut remaining = message_size as isize; - - // Read summary header. - let summary_header_size = varint::read_async(&mut reader).await?; - remaining -= summary_header_size.1 as isize; - let summary_header_size = summary_header_size.0; - if (summary_header_size as i64) > (remaining as i64) { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid summary header")); - } - let summary_header: msg::SummaryHeader = connection.read_obj(&mut reader, &mut buf, summary_header_size as usize).await?; - remaining -= summary_header_size as isize; - - // Read and evaluate summary that we were sent. - match summary_header.summary_type { - SUMMARY_TYPE_KEYS => { - self.connection_receive_and_process_remote_hash_list( - connection, - remaining, - &mut reader, - now, - summary_header.reference_time, - &summary_header.prefix[0..summary_header.prefix.len().min((summary_header.prefix_bits / 8) as usize)]).await?; - }, - SUMMARY_TYPE_IBLT => { - //let summary = IBLT::new_from_reader(&mut reader, remaining as usize).await?; - }, - _ => {} // ignore unknown summary types - } - - // Request another summary if needed, keeping a ping-pong game going in a tight loop until synced. - self.connection_request_summary(connection, summary_header.total_record_count, now, summary_header.reference_time).await?; - }, - - MESSAGE_TYPE_HAVE_RECORDS => { - let mut remaining = message_size as isize; - let reference_time = varint::read_async(&mut reader).await?; - remaining -= reference_time.1 as isize; - if remaining <= 0 { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid message")); - } - self.connection_receive_and_process_remote_hash_list(connection, remaining, &mut reader, now, reference_time.0 as i64, &[]).await? 
-                },
-
-                MESSAGE_TYPE_GET_RECORDS => {
-                },
-
-                MESSAGE_TYPE_RECORD => {
-                    let value = connection.read_msg(&mut reader, &mut buf, message_size as usize).await?;
-                    if value.len() > D::MAX_VALUE_SIZE {
-                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "value larger than MAX_VALUE_SIZE"));
-                    }
-                    let key = H::sha512(&[value]);
-                    match self.datastore.store(&key, value) {
-                        StoreResult::Ok(reference_time) => {
-                            let mut have_records_msg = [0_u8; 2 + 10 + ANNOUNCE_HASH_BYTES];
-                            let mut msg_len = varint::encode(&mut have_records_msg, reference_time as u64);
-                            have_records_msg[msg_len] = ANNOUNCE_HASH_BYTES as u8;
-                            msg_len += 1;
-                            have_records_msg[msg_len..(msg_len + ANNOUNCE_HASH_BYTES)].copy_from_slice(&key[..ANNOUNCE_HASH_BYTES]);
-                            msg_len += ANNOUNCE_HASH_BYTES;
-
-                            let self2 = self.clone();
-                            let connection2 = connection.clone();
-                            background_tasks.spawn(async move {
-                                let connections = self2.connections.lock().await;
-                                let mut recipients = Vec::with_capacity(connections.len());
-                                for (_, c) in connections.iter() {
-                                    if !Arc::ptr_eq(&(c.0), &connection2) {
-                                        recipients.push(c.0.clone());
+                MessageType::Record => {
+                    let key = H::sha512(&[message]);
+                    match self.datastore.store(&key, message) {
+                        StoreResult::Ok => {
+                            // TODO: probably should not announce if way out of sync
+                            let connections = self.connections.lock().await;
+                            let mut announce_to: Vec<Arc<Connection>> = Vec::with_capacity(connections.len());
+                            for (_, c) in connections.iter() {
+                                if !Arc::ptr_eq(&connection, c) {
+                                    announce_to.push(c.clone());
                                 }
-                                drop(connections); // release lock
+                            }
+                            drop(connections); // release lock
 
-                                for c in recipients.iter() {
-                                    // This typically completes instantly due to buffering, as this message is small.
-                                    // Add a small timeout in the case that some connections are stalled. Misses will
-                                    // not impact the overall network much.
-                                    let _ = tokio::time::timeout(Duration::from_millis(250), c.send_msg(MESSAGE_TYPE_HAVE_RECORDS, &have_records_msg[..msg_len], now)).await;
+                            background_tasks.spawn(async move {
+                                for c in announce_to.iter() {
+                                    let _ = c.send_msg(MessageType::HaveRecord, &key[0..ANNOUNCE_KEY_LEN], now).await;
                                 }
                             });
-                        },
+                        }
                         StoreResult::Rejected => {
-                            return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid datum received"));
-                        },
-                        _ => {} // duplicate or ignored values are just... ignored
-                    }
-                },
-
-                _ => {
-                    // Skip messages that aren't recognized or don't need to be parsed.
-                    let mut remaining = message_size as usize;
-                    while remaining > 0 {
-                        let s = remaining.min(buf.len());
-                        reader.read_exact(&mut buf.as_mut_slice()[0..s]).await?;
-                        remaining -= s;
+                            return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("record rejected by data store: {}", to_hex_string(&key))));
+                        }
+                        _ => {}
                     }
                 }
+
+                MessageType::SyncStatus => {
+                    let msg: msg::SyncStatus = decode_msgpack(message)?;
+                }
+
+                MessageType::SyncRequest => {
+                    let msg: msg::SyncRequest = decode_msgpack(message)?;
+                }
+
+                MessageType::SyncResponse => {
+                    let msg: msg::SyncResponse = decode_msgpack(message)?;
+                }
                 }
             }
 
-            connection.bytes_received.fetch_add((header_size as u64) + message_size, Ordering::Relaxed);
+            read_buffer.copy_within(total_size..buffer_fill, 0);
+            buffer_fill -= total_size;
+
+            connection.bytes_received.fetch_add(total_size as u64, Ordering::Relaxed);
         }
     }
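For reference, the HaveRecords/GetRecords messages introduced in protocol.rs pack several fixed-length key prefixes into one flat byte array; a hypothetical sender-side sketch (k0/k1/k2 are invented 64-byte keys):

```rust
let (k0, k1, k2) = ([0u8; 64], [1u8; 64], [2u8; 64]);
let mut keys = Vec::with_capacity(3 * ANNOUNCE_KEY_LEN);
for key in [&k0, &k1, &k2] {
    keys.extend_from_slice(&key[0..ANNOUNCE_KEY_LEN]);
}
let have = msg::HaveRecords { key_length: ANNOUNCE_KEY_LEN, keys: keys.as_slice() };
```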
-
-    /// Request a summary if needed, or do nothing if not.
-    ///
-    /// This is where all the logic lives that determines whether to request summaries, choosing a
-    /// prefix, etc. It's called when the remote node tells us its total record count with an
-    /// associated reference time, which happens in status announcements and in summaries.
-    async fn connection_request_summary(&self, connection: &Arc<Connection>, total_record_count: u64, now: i64, reference_time: i64) -> std::io::Result<()> {
-        let my_total_record_count = self.datastore.total_count();
-        if my_total_record_count < total_record_count {
-            // Figure out how many bits need to be in a randomly chosen prefix to choose a slice of
-            // the data set such that the set difference should be around 4096 records. This assumes
-            // random distribution, which should be mostly maintained by probing prefixes at random.
-            let prefix_bits = ((total_record_count - my_total_record_count) as f64) / 4096.0;
-            let prefix_bits = if prefix_bits > 1.0 {
-                (prefix_bits.log2().ceil() as usize).min(64)
-            } else {
-                0 as usize
-            };
-            let prefix_bytes = (prefix_bits / 8) + (((prefix_bits % 8) != 0) as usize);
-
-            // Generate a random prefix of this many bits (to the nearest byte).
-            let mut prefix = [0_u8; 64];
-            self.host.get_secure_random(&mut prefix[..prefix_bytes]);
-
-            // Request a set summary for this prefix, providing our own count for this prefix so
-            // the remote can decide whether to send something like an IBLT or just hashes.
-            let (local_range_start, local_range_end) = range_from_prefix(&prefix, prefix_bits);
-            connection.send_obj(MESSAGE_TYPE_GET_SUMMARY, &msg::GetSummary {
-                reference_time,
-                prefix: &prefix[..prefix_bytes],
-                prefix_bits: prefix_bits as u8,
-                record_count: self.datastore.count(reference_time, &local_range_start, &local_range_end)
-            }, now).await
-        } else {
-            Ok(())
-        }
-    }
-
-    /// Read a stream of record hashes (or hash prefixes) from a connection and request records we don't have.
-    async fn connection_receive_and_process_remote_hash_list(&self, connection: &Arc<Connection>, mut remaining: isize, reader: &mut BufReader<OwnedReadHalf>, now: i64, reference_time: i64, common_prefix: &[u8]) -> std::io::Result<()> {
-        if remaining > 0 {
-            // Hash list is prefaced by the number of bytes in each hash, since whole 64 byte hashes do not have to be sent.
-            let prefix_entry_size = reader.read_u8().await? as usize;
-            let total_prefix_size = common_prefix.len() + prefix_entry_size;
-
-            if prefix_entry_size > 0 && total_prefix_size <= 64 {
-                remaining -= 1;
-                if remaining >= (prefix_entry_size as isize) {
-                    let mut get_records_msg: Vec<u8> = Vec::with_capacity(((remaining as usize) / prefix_entry_size) * total_prefix_size);
-                    varint::write(&mut get_records_msg, reference_time as u64)?;
-                    get_records_msg.push(total_prefix_size as u8);
-
-                    let mut key_prefix_buf = [0_u8; 64];
-                    key_prefix_buf[..common_prefix.len()].copy_from_slice(common_prefix);
-
-                    while remaining >= (prefix_entry_size as isize) {
-                        remaining -= prefix_entry_size as isize;
-                        reader.read_exact(&mut key_prefix_buf[common_prefix.len()..total_prefix_size]).await?;
-
-                        if if total_prefix_size < 64 {
-                            let (s, e) = range_from_prefix(&key_prefix_buf[..total_prefix_size], total_prefix_size * 8);
-                            self.datastore.count(reference_time, &s, &e) == 0
-                        } else {
-                            !self.datastore.contains(reference_time, &key_prefix_buf)
-                        } {
-                            let _ = get_records_msg.write_all(&key_prefix_buf[..total_prefix_size]);
-                        }
-                    }
-
-                    if remaining == 0 {
-                        return connection.send_msg(MESSAGE_TYPE_GET_RECORDS, get_records_msg.as_slice(), now).await;
-                    }
-                }
-            }
-        }
-        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hash list"));
-    }
 }
 
 impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
     fn drop(&mut self) {
-        let _ = tokio::runtime::Handle::try_current().map_or_else(|_| {
-            for (_, c) in self.connections.blocking_lock().drain() {
-                c.1.map(|c| c.abort());
-            }
-        }, |h| {
-            let _ = h.block_on(async {
-                for (_, c) in self.connections.lock().await.drain() {
-                    c.1.map(|c| c.abort());
+        let _ = tokio::runtime::Handle::try_current().map_or_else(
+            |_| {
+                for (_, c) in self.connections.blocking_lock().drain() {
+                    c.read_task.lock().unwrap().as_mut().map(|c| c.abort());
                 }
-            });
-        });
+            },
+            |h| {
+                let _ = h.block_on(async {
+                    for (_, c) in self.connections.lock().await.drain() {
+                        c.read_task.lock().unwrap().as_mut().map(|c| c.abort());
+                    }
+                });
+            },
+        );
     }
 }
 
@@ -673,53 +547,30 @@ struct Connection {
     bytes_sent: AtomicU64,
     bytes_received: AtomicU64,
     info: Mutex<RemoteNodeInfo>,
+    read_task: std::sync::Mutex<Option<JoinHandle<std::io::Result<()>>>>,
     closed: AtomicBool,
 }
 
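The wire framing written by `send_msg()` below (and parsed by the read loop above) is a one-byte type, a varint payload length, then the payload; a hypothetical standalone encoder for illustration:

```rust
// Frame layout: [type: u8][length: varint][payload].
let payload = b"example record";
let mut frame: Vec<u8> = Vec::with_capacity(1 + 10 + payload.len());
frame.push(MessageType::Record as u8);
let mut len_tmp = [0u8; 10];
let len_size = varint::encode(&mut len_tmp, payload.len() as u64);
frame.extend_from_slice(&len_tmp[..len_size]);
frame.extend_from_slice(payload);
```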
 impl Connection {
-    async fn send_msg(&self, message_type: u8, data: &[u8], now: i64) -> std::io::Result<()> {
-        let mut type_and_size = [0_u8; 16];
-        type_and_size[0] = message_type;
-        let tslen = 1 + varint::encode(&mut type_and_size[1..], data.len() as u64) as usize;
-        let total_size = tslen + data.len();
-        if self.writer.lock().await.write_vectored(&[IoSlice::new(&type_and_size[..tslen]), IoSlice::new(data)]).await? == total_size {
+    async fn send_msg(&self, message_type: MessageType, data: &[u8], now: i64) -> std::io::Result<()> {
+        let mut header: [u8; 16] = unsafe { MaybeUninit::uninit().assume_init() };
+        header[0] = message_type as u8;
+        let header_size = 1 + varint::encode(&mut header[1..], data.len() as u64);
+        if self.writer.lock().await.write_vectored(&[IoSlice::new(&header[0..header_size]), IoSlice::new(data)]).await? == (data.len() + header_size) {
             self.last_send_time.store(now, Ordering::Relaxed);
-            self.bytes_sent.fetch_add(total_size as u64, Ordering::Relaxed);
+            self.bytes_sent.fetch_add((header_size + data.len()) as u64, Ordering::Relaxed);
             Ok(())
         } else {
             Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "write error"))
         }
     }
 
-    async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> {
-        let data = rmp_serde::encode::to_vec_named(&obj);
-        if data.is_ok() {
-            let data = data.unwrap();
-            let mut header: [u8; 16] = unsafe { MaybeUninit::uninit().assume_init() };
-            header[0] = message_type;
-            let header_size = 1 + varint::encode(&mut header[1..], data.len() as u64);
-            if self.writer.lock().await.write_vectored(&[IoSlice::new(&header[0..header_size]), IoSlice::new(data.as_slice())]).await? == (data.len() + header_size) {
-                self.last_send_time.store(now, Ordering::Relaxed);
-                self.bytes_sent.fetch_add((header_size + data.len()) as u64, Ordering::Relaxed);
-                Ok(())
-            } else {
-                Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "write error"))
-            }
+    async fn send_obj<O: Serialize>(&self, write_buf: &mut Vec<u8>, message_type: MessageType, obj: &O, now: i64) -> std::io::Result<()> {
+        write_buf.clear();
+        if rmp_serde::encode::write_named(write_buf, obj).is_ok() {
+            self.send_msg(message_type, write_buf.as_slice(), now).await
         } else {
             Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure (internal error)"))
         }
     }
-
-    async fn read_msg<'a, R: AsyncReadExt + Unpin>(&self, reader: &mut R, buf: &'a mut Vec<u8>, message_size: usize) -> std::io::Result<&'a [u8]> {
-        if message_size > buf.len() {
-            buf.resize(((message_size / 4096) + 1) * 4096, 0);
-        }
-        let b = &mut buf.as_mut_slice()[0..message_size];
-        reader.read_exact(b).await?;
-        Ok(b)
-    }
-
-    async fn read_obj<'a, R: AsyncReadExt + Unpin, O: Deserialize<'a>>(&self, reader: &mut R, buf: &'a mut Vec<u8>, message_size: usize) -> std::io::Result<O> {
-        rmp_serde::from_read_ref(self.read_msg(reader, buf, message_size).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("invalid msgpack: {}", e.to_string())))
-    }
 }
diff --git a/syncwhole/src/protocol.rs b/syncwhole/src/protocol.rs
index 5e50fbd05..355c99e5a 100644
--- a/syncwhole/src/protocol.rs
+++ b/syncwhole/src/protocol.rs
@@ -6,55 +6,108 @@
  * https://www.zerotier.com/
  */
 
-/// No operation, payload ignored.
-pub const MESSAGE_TYPE_NOP: u8 = 0;
+/// Number of bytes of SHA512 to announce, should be high enough to make collisions virtually impossible.
+pub const ANNOUNCE_KEY_LEN: usize = 24;
 
-/// Sent by both sides of a TCP link when it's established.
-pub const MESSAGE_TYPE_INIT: u8 = 1;
+/// Send SyncStatus this frequently, in milliseconds.
+pub const SYNC_STATUS_PERIOD: i64 = 5000;
 
-/// Reply sent to INIT.
-pub const MESSAGE_TYPE_INIT_RESPONSE: u8 = 2;
+#[derive(Clone, Copy, Eq, PartialEq)]
+#[repr(u8)]
+pub enum MessageType {
+    Nop = 0_u8,
+    Init = 1_u8,
+    InitResponse = 2_u8,
+    HaveRecord = 3_u8,
+    HaveRecords = 4_u8,
+    GetRecords = 5_u8,
+    Record = 6_u8,
+    SyncStatus = 7_u8,
+    SyncRequest = 8_u8,
+    SyncResponse = 9_u8,
+}
 
-/// Sent every few seconds to notify peers of number of records, clock, etc.
-pub const MESSAGE_TYPE_STATUS: u8 = 3;
+impl From<u8> for MessageType {
+    /// Get a type from a byte, returning the Nop type if the byte is out of range.
+    #[inline(always)]
+    fn from(b: u8) -> Self {
+        if b <= 9 {
+            unsafe { std::mem::transmute(b) }
+        } else {
+            Self::Nop
+        }
+    }
+}
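A quick sanity sketch of the byte-to-enum conversion (hypothetical, not in the patch); with the range check covering all ten variants, unknown bytes degrade to Nop:

```rust
assert!(MessageType::from(6u8) == MessageType::Record);
assert!(MessageType::from(9u8) == MessageType::SyncResponse);
assert!(MessageType::from(0xffu8) == MessageType::Nop); // out of range falls back to Nop
```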
-pub const MESSAGE_TYPE_GET_SUMMARY: u8 = 4;
+impl MessageType {
+    pub fn name(&self) -> &'static str {
+        match *self {
+            Self::Nop => "NOP",
+            Self::Init => "INIT",
+            Self::InitResponse => "INIT_RESPONSE",
+            Self::HaveRecord => "HAVE_RECORD",
+            Self::HaveRecords => "HAVE_RECORDS",
+            Self::GetRecords => "GET_RECORDS",
+            Self::Record => "RECORD",
+            Self::SyncStatus => "SYNC_STATUS",
+            Self::SyncRequest => "SYNC_REQUEST",
+            Self::SyncResponse => "SYNC_RESPONSE",
+        }
+    }
+}
 
-/// Set summary of a prefix.
-pub const MESSAGE_TYPE_SUMMARY: u8 = 5;
+#[derive(Clone, Copy, Eq, PartialEq)]
+#[repr(u8)]
+pub enum SyncResponseType {
+    /// No response, do nothing.
+    None = 0_u8,
 
-/// Payload is a list of keys of records. Usually sent to advertise recently received new records.
-pub const MESSAGE_TYPE_HAVE_RECORDS: u8 = 6;
+    /// Response is a msgpack-encoded HaveRecords message.
+    HaveRecords = 1_u8,
 
-/// Payload is a list of keys of records the sending node wants.
-pub const MESSAGE_TYPE_GET_RECORDS: u8 = 7;
+    /// Response is a series of records prefixed by varint record sizes.
+    Records = 2_u8,
 
-/// Payload is a record, with key being omitted if the data store's KEY_IS_COMPUTED constant is true.
-pub const MESSAGE_TYPE_RECORD: u8 = 8;
+    /// Response is an IBLT set summary.
+    IBLT = 3_u8,
+}
 
-/// Summary type: simple array of keys under the given prefix.
-pub const SUMMARY_TYPE_KEYS: u8 = 0;
+impl From<u8> for SyncResponseType {
+    /// Get response type from a byte, returning None if the byte is out of range.
+    #[inline(always)]
+    fn from(b: u8) -> Self {
+        if b <= 3 {
+            unsafe { std::mem::transmute(b) }
+        } else {
+            Self::None
+        }
+    }
+}
 
-/// An IBLT set summary.
-pub const SUMMARY_TYPE_IBLT: u8 = 1;
-
-/// Number of bytes of each SHA512 hash to announce, request, etc. This is okay to change but 16 is plenty.
-pub const ANNOUNCE_HASH_BYTES: usize = 16;
+impl SyncResponseType {
+    pub fn as_str(&self) -> &'static str {
+        match *self {
+            SyncResponseType::None => "NONE",
+            SyncResponseType::HaveRecords => "HAVE_RECORDS",
+            SyncResponseType::Records => "RECORDS",
+            SyncResponseType::IBLT => "IBLT",
+        }
+    }
+}
 
 pub mod msg {
-    use serde::{Serialize, Deserialize};
+    use serde::{Deserialize, Serialize};
 
     #[derive(Serialize, Deserialize)]
     pub struct IPv4 {
         pub ip: [u8; 4],
-        pub port: u16
+        pub port: u16,
     }
 
     #[derive(Serialize, Deserialize)]
     pub struct IPv6 {
         pub ip: [u8; 16],
-        pub port: u16
+        pub port: u16,
     }
 
     #[derive(Serialize, Deserialize)]
@@ -72,10 +125,10 @@ pub mod msg {
         pub auth_challenge: &'a [u8],
 
         /// Optional name to advertise for this node.
-        pub node_name: Option<String>,
+        pub node_name: &'a str,
 
         /// Optional contact information for this node, such as a URL or an e-mail address.
-        pub node_contact: Option<String>,
+        pub node_contact: &'a str,
 
         /// Port to which this node has locally bound.
         /// This is used to try to auto-detect whether a NAT is in the way.
@@ -105,63 +158,100 @@ pub mod msg {
         pub auth_response: &'a [u8],
     }
 
-    #[derive(Serialize, Deserialize, Clone)]
-    pub struct Status {
-        /// Total number of records in data set.
-        #[serde(rename = "c")]
-        pub total_record_count: u64,
+    #[derive(Serialize, Deserialize)]
+    pub struct HaveRecords<'a> {
+        /// Length of each key, chosen to ensure uniqueness.
+        #[serde(rename = "l")]
+        pub key_length: usize,
 
-        /// Reference wall clock time in milliseconds since Unix epoch.
-        #[serde(rename = "t")]
-        pub reference_time: i64,
+        /// Keys whose existence is being announced, of 'key_length' length.
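+        /// Keys are packed end to end with no separator, so the length of this
+        /// slice should be an exact multiple of 'key_length'.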
+        #[serde(with = "serde_bytes")]
+        #[serde(rename = "k")]
+        pub keys: &'a [u8],
     }
 
-    #[derive(Serialize, Deserialize, Clone)]
-    pub struct GetSummary<'a> {
-        /// Reference wall clock time in milliseconds since Unix epoch.
-        #[serde(rename = "t")]
-        pub reference_time: i64,
+    #[derive(Serialize, Deserialize)]
+    pub struct GetRecords<'a> {
+        /// Length of each key, chosen to ensure uniqueness.
+        #[serde(rename = "l")]
+        pub key_length: usize,
 
-        /// Prefix within key space.
-        #[serde(rename = "p")]
+        /// Keys to retrieve, each 'key_length' bytes in length.
         #[serde(with = "serde_bytes")]
-        pub prefix: &'a [u8],
+        #[serde(rename = "k")]
+        pub keys: &'a [u8],
+    }
 
-        /// Length of prefix in bits (trailing bits in byte array are ignored).
-        #[serde(rename = "b")]
-        pub prefix_bits: u8,
-
-        /// Number of records in this range the requesting node already has, used to choose summary type.
-        #[serde(rename = "r")]
+    #[derive(Serialize, Deserialize)]
+    pub struct SyncStatus {
+        /// Total number of records this node has in its data store.
+        #[serde(rename = "c")]
         pub record_count: u64,
+
+        /// Sending node's system clock.
+        #[serde(rename = "t")]
+        pub clock: u64,
     }
 
-    #[derive(Serialize, Deserialize, Clone)]
-    pub struct SummaryHeader<'a> {
-        /// Total number of records in data set, for easy rapid generation of next query.
-        #[serde(rename = "c")]
-        pub total_record_count: u64,
-
-        /// Reference wall clock time in milliseconds since Unix epoch.
-        #[serde(rename = "t")]
-        pub reference_time: i64,
-
-        /// Random salt value used by some summary types.
-        #[serde(rename = "x")]
+    #[derive(Serialize, Deserialize)]
+    pub struct SyncRequest<'a> {
+        /// Query mask, a random string of KEY_SIZE bytes.
         #[serde(with = "serde_bytes")]
-        pub salt: &'a [u8],
+        #[serde(rename = "q")]
+        pub query_mask: &'a [u8],
 
-        /// Prefix within key space.
-        #[serde(rename = "p")]
-        #[serde(with = "serde_bytes")]
-        pub prefix: &'a [u8],
-
-        /// Length of prefix in bits (trailing bits in byte array are ignored).
+        /// Number of bits to match as a prefix in query_mask (0 for the entire data set).
         #[serde(rename = "b")]
-        pub prefix_bits: u8,
+        pub query_mask_bits: u8,
 
-        /// Type of summary that follows this header.
+        /// Number of records the requesting node already holds under the query mask prefix.
+        #[serde(rename = "c")]
+        pub record_count: u64,
+
+        /// Sender's reference time.
+        #[serde(rename = "t")]
+        pub reference_time: u64,
+
+        /// Random salt.
         #[serde(rename = "s")]
-        pub summary_type: u8,
+        pub salt: u64,
+    }
+
+    #[derive(Serialize, Deserialize)]
+    pub struct SyncResponse<'a> {
+        /// Query mask, a random string of KEY_SIZE bytes.
+        #[serde(with = "serde_bytes")]
+        #[serde(rename = "q")]
+        pub query_mask: &'a [u8],
+
+        /// Number of bits to match as a prefix in query_mask (0 for the entire data set).
+        #[serde(rename = "b")]
+        pub query_mask_bits: u8,
+
+        /// Number of records the sender has under this prefix.
+        #[serde(rename = "c")]
+        pub record_count: u64,
+
+        /// Sender's reference time.
+        #[serde(rename = "t")]
+        pub reference_time: u64,
+
+        /// Random salt.
+        #[serde(rename = "s")]
+        pub salt: u64,
+
+        /// SyncResponseType (as u8) determining the content of 'data'.
+        #[serde(rename = "r")]
+        pub response_type: u8,
+
+        /// Data whose meaning depends on the response type.
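+        /// For SyncResponseType::HaveRecords this is a msgpack-encoded HaveRecords message,
+        /// for SyncResponseType::Records a series of records each prefixed by a varint size,
+        /// and for SyncResponseType::IBLT a binary IBLT set summary.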
+        #[serde(with = "serde_bytes")]
+        #[serde(rename = "d")]
+        pub data: &'a [u8],
     }
 }
diff --git a/syncwhole/src/utils.rs b/syncwhole/src/utils.rs
index 49f50d5b9..0a9b789f6 100644
--- a/syncwhole/src/utils.rs
+++ b/syncwhole/src/utils.rs
@@ -8,8 +8,10 @@
 use std::collections::HashMap;
 use std::future::Future;
-use std::sync::Arc;
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::time::SystemTime;
+
 use tokio::task::JoinHandle;
 
 /// Get the real time clock in milliseconds since Unix epoch.
@@ -17,11 +19,6 @@ pub fn ms_since_epoch() -> i64 {
     std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as i64
 }
 
-/// Get the current monotonic clock in milliseconds.
-pub fn ms_monotonic() -> i64 {
-    std::time::Instant::now().elapsed().as_millis() as i64
-}
-
 /// Encode a byte slice to a hexadecimal string.
 pub fn to_hex_string(b: &[u8]) -> String {
     const HEX_CHARS: [u8; 16] = [b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'a', b'b', b'c', b'd', b'e', b'f'];
@@ -66,6 +63,26 @@ pub fn splitmix64_inverse(mut x: u64) -> u64 {
     x.to_le()
 }
 
+static RANDOM_STATE_0: AtomicU64 = AtomicU64::new(0);
+static RANDOM_STATE_1: AtomicU64 = AtomicU64::new(0);
+
+/// Get a non-cryptographic pseudorandom number.
+pub fn random() -> u64 {
+    // Relaxed atomics are fine here: this is not cryptographic, and an occasional
+    // duplicated value under concurrent calls is acceptable.
+    let mut s0 = RANDOM_STATE_0.load(Ordering::Relaxed);
+    let mut s1 = RANDOM_STATE_1.load(Ordering::Relaxed);
+    if s0 == 0 {
+        s0 = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos() as u64;
+    }
+    if s1 == 0 {
+        s1 = splitmix64(std::process::id() as u64);
+    }
+    let s1_new = xorshift64(s1);
+    s0 = splitmix64(s0.wrapping_add(s1));
+    RANDOM_STATE_0.store(s0, Ordering::Relaxed);
+    RANDOM_STATE_1.store(s1_new, Ordering::Relaxed);
+    s0
+}
+
 /// Wrapper for tokio::spawn() that aborts tasks not yet completed when it is dropped.
 pub struct AsyncTaskReaper {
     ctr: AtomicUsize,
@@ -74,22 +93,31 @@ pub struct AsyncTaskReaper {
 
 impl AsyncTaskReaper {
     pub fn new() -> Self {
-        Self {
-            ctr: AtomicUsize::new(0),
-            handles: Arc::new(std::sync::Mutex::new(HashMap::new()))
-        }
+        Self { ctr: AtomicUsize::new(0), handles: Arc::new(std::sync::Mutex::new(HashMap::new())) }
    }
 
     /// Spawn a new task.
-    /// Note that currently any output is ignored. This is primarily for background tasks
-    /// that are used similarly to goroutines in Go.
+    ///
+    /// Note that any task output is currently ignored. This is for fire-and-forget
+    /// background tasks that should be aborted when the reaper goes out of scope.
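+    ///
+    /// A usage sketch (illustrative only; requires a running tokio runtime):
+    /// ```ignore
+    /// let reaper = AsyncTaskReaper::new();
+    /// reaper.spawn(async {
+    ///     // background work; aborted if the reaper is dropped before completion
+    /// });
+    /// ```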
     pub fn spawn<F: Future + Send + 'static>(&self, future: F) {
         let id = self.ctr.fetch_add(1, Ordering::Relaxed);
         let handles = self.handles.clone();
-        self.handles.lock().unwrap().insert(id, tokio::spawn(async move {
-            let _ = future.await;
-            let _ = handles.lock().unwrap().remove(&id);
-        }));
+        self.handles.lock().unwrap().insert(
+            id,
+            tokio::spawn(async move {
+                let _ = future.await;
+                let _ = handles.lock().unwrap().remove(&id);
+            }),
+        );
     }
 }
diff --git a/syncwhole/src/varint.rs b/syncwhole/src/varint.rs
index b94f41c94..1cb21524c 100644
--- a/syncwhole/src/varint.rs
+++ b/syncwhole/src/varint.rs
@@ -6,10 +6,7 @@
  * https://www.zerotier.com/
  */
 
-use std::io::{Read, Write};
-use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
-
-const VARINT_MAX_SIZE_BYTES: usize = 10;
+pub const VARINT_MAX_SIZE_BYTES: usize = 10;
 
 pub fn encode(b: &mut [u8], mut v: u64) -> usize {
     let mut i = 0;
@@ -27,50 +24,21 @@ pub fn encode(b: &mut [u8], mut v: u64) -> usize {
     i
 }
 
-#[inline(always)]
-pub async fn write_async<W: AsyncWrite + Unpin>(w: &mut W, v: u64) -> std::io::Result<()> {
-    let mut b = [0_u8; VARINT_MAX_SIZE_BYTES];
-    let i = encode(&mut b, v);
-    w.write_all(&b[0..i]).await
-}
-
-pub async fn read_async<R: AsyncRead + Unpin>(r: &mut R) -> std::io::Result<(u64, usize)> {
+pub fn decode(b: &[u8]) -> (u64, usize) {
     let mut v = 0_u64;
     let mut pos = 0;
-    let mut count = 0;
-    loop {
-        count += 1;
-        let b = r.read_u8().await?;
-        if b <= 0x7f {
-            v |= (b as u64).wrapping_shl(pos);
+    let mut l = 0;
+    let bl = b.len();
+    while l < bl {
+        let x = b[l];
+        l += 1;
+        if x <= 0x7f {
+            v |= (x as u64).wrapping_shl(pos);
             pos += 7;
         } else {
-            v |= ((b & 0x7f) as u64).wrapping_shl(pos);
-            return Ok((v, count));
-        }
-    }
-}
-
-#[inline(always)]
-pub fn write<W: Write>(w: &mut W, v: u64) -> std::io::Result<()> {
-    let mut b = [0_u8; VARINT_MAX_SIZE_BYTES];
-    let i = encode(&mut b, v);
-    w.write_all(&b[0..i])
-}
-
-pub fn read<R: Read>(r: &mut R) -> std::io::Result<u64> {
-    let mut v = 0_u64;
-    let mut buf = [0_u8; 1];
-    let mut pos = 0;
-    loop {
-        let _ = r.read_exact(&mut buf)?;
-        let b = buf[0];
-        if b <= 0x7f {
-            v |= (b as u64).wrapping_shl(pos);
-            pos += 7;
-        } else {
-            v |= ((b & 0x7f) as u64).wrapping_shl(pos);
-            return Ok(v);
+            v |= ((x & 0x7f) as u64).wrapping_shl(pos);
+            return (v, l);
         }
     }
+    return (0, 0);
 }
diff --git a/zerotier-system-service/src/utils.rs b/zerotier-system-service/src/utils.rs
index 89ec96412..61cb3bde2 100644
--- a/zerotier-system-service/src/utils.rs
+++ b/zerotier-system-service/src/utils.rs
@@ -39,7 +39,10 @@ pub fn ms_monotonic() -> i64 {
 
 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
 pub fn ms_monotonic() -> i64 {
-    std::time::Instant::now().elapsed().as_millis() as i64
+    // A freshly created Instant's elapsed() is always ~0, so anchor one Instant at
+    // first call and measure against it. (OnceLock needs Rust 1.70 or newer.)
+    static EPOCH: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
+    EPOCH.get_or_init(std::time::Instant::now).elapsed().as_millis() as i64
 }
 
 pub fn parse_bool(v: &str) -> Result<bool, String> {