diff --git a/syncwhole/src/datastore.rs b/syncwhole/src/datastore.rs
index c2d59fedb..35d20686c 100644
--- a/syncwhole/src/datastore.rs
+++ b/syncwhole/src/datastore.rs
@@ -6,12 +6,24 @@
  * https://www.zerotier.com/
  */
 
-use std::ops::Bound::Included;
-use std::collections::BTreeMap;
-use std::sync::{Arc, Mutex};
-use crate::ms_since_epoch;
+/// Generate a range of SHA512 hashes from a prefix and a number of bits.
+/// The range will be inclusive and cover all keys under the prefix.
+pub fn range_from_prefix(prefix: &[u8], prefix_bits: usize) -> ([u8; 64], [u8; 64]) {
+    let prefix_bits = prefix_bits.min(prefix.len() * 8).min(64);
+    let mut start = [0_u8; 64];
+    let mut end = [0xff_u8; 64];
+    let whole_bytes = prefix_bits / 8;
+    let remaining_bits = prefix_bits % 8;
+    start[0..whole_bytes].copy_from_slice(&prefix[0..whole_bytes]);
+    end[0..whole_bytes].copy_from_slice(&prefix[0..whole_bytes]);
+    if remaining_bits != 0 && whole_bytes < prefix.len() {
+        start[whole_bytes] |= prefix[whole_bytes];
+        end[whole_bytes] &= prefix[whole_bytes] | ((0xff_u8).wrapping_shr(remaining_bits as u32));
+    }
+    (start, end)
+}
 
-/// Result returned by DB::load().
+/// Result returned by DataStore::load().
 pub enum LoadResult<V: AsRef<[u8]> + Send> {
     /// Object was found.
     Ok(V),
@@ -23,10 +35,12 @@ pub enum LoadResult<V: AsRef<[u8]> + Send> {
     TimeNotAvailable
 }
 
-/// Result returned by DB::store().
+/// Result returned by DataStore::store().
 pub enum StoreResult {
-    /// Entry was accepted (whether or not an old value was replaced).
-    Ok,
+    /// Entry was accepted.
+    /// The integer included with Ok is the reference time that should be advertised.
+    /// If this is not a temporally subjective data set then zero can be used.
+    Ok(i64),
 
     /// Entry was a duplicate of one we already have but was otherwise valid.
     Duplicate,
@@ -40,19 +54,16 @@ pub enum StoreResult {
 
 /// API to be implemented by the data set we want to replicate.
 ///
+/// Keys as understood by syncwhole are SHA512 hashes of values. The user can of course
+/// have their own concept of a "key" separate from this, but that would not be used
+/// for data set replication. Replication is content identity based.
+///
 /// The API specified here supports temporally subjective data sets. These are data sets
 /// where the existence or non-existence of a record may depend on the (real world) time.
 /// A parameter for reference time allows a remote querying node to send its own "this is
 /// what time I think it is" value to be considered locally so that data can be replicated
 /// as of any given time.
 ///
-/// The constants KEY_SIZE, MAX_VALUE_SIZE, and KEY_IS_COMPUTED are protocol constants
-/// for your replication domain. They can't be changed once defined unless all nodes
-/// are upgraded at once.
-///
-/// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a function of
-/// values. If this is true, key_from_value() must be implemented.
-///
 /// The implementation must be thread safe and may be called concurrently.
 pub trait DataStore: Sync + Send {
     /// Type to be enclosed in the Ok() enum value in LoadResult.
@@ -61,32 +72,13 @@ pub trait DataStore: Sync + Send {
     /// your implementation. Examples include Box<[u8]>, Arc<[u8]>, Vec<u8>, etc.
     type LoadResultValueType: AsRef<[u8]> + Send;
 
-    /// Size of keys, which must be fixed in length. These are typically hashes.
-    const KEY_SIZE: usize;
+    /// Key hash size, always 64 for SHA512.
+    const KEY_SIZE: usize = 64;
 
     /// Maximum size of a value in bytes.
     const MAX_VALUE_SIZE: usize;
 
-    /// This should be true if the key is computed, such as by hashing the value.
-    ///
-    /// If this is true then keys do not have to be sent over the wire. Instead they
-    /// are computed by calling get_key(). If this is false keys are assumed not to
-    /// be computable from values and are explicitly sent.
-    const KEY_IS_COMPUTED: bool;
-
-    /// Compute the key corresponding to a value.
-    ///
-    /// If KEY_IS_COMPUTED is true this must be implemented. The default implementation
-    /// panics to indicate this. If KEY_IS_COMPUTED is false this is never called.
-    #[allow(unused_variables)]
-    fn key_from_value(&self, value: &[u8], key_buffer: &mut [u8]) {
-        panic!("key_from_value() must be implemented if KEY_IS_COMPUTED is true");
-    }
-
     /// Get the current wall clock in milliseconds since Unix epoch.
-    ///
-    /// This is delegated to the data store to support scenarios where you want to fix
-    /// the clock or snapshot at a given time.
     fn clock(&self) -> i64;
 
     /// Get the domain of this data store.
@@ -99,144 +91,56 @@ pub trait DataStore: Sync + Send {
     /// Get an item if it exists as of a given reference time.
     fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType>;
 
+    /// Check whether this data store contains a key.
+    ///
+    /// The default implementation just uses load(). Override if you can provide a faster
+    /// version.
+    fn contains(&self, reference_time: i64, key: &[u8]) -> bool {
+        match self.load(reference_time, key) {
+            LoadResult::Ok(_) => true,
+            _ => false
+        }
+    }
+
     /// Store an item in the data store and return its status.
     ///
     /// Note that no time is supplied here. The data store must determine this in an implementation
     /// dependent manner if this is a temporally subjective data store. It could be determined by
     /// the wall clock, from the object itself, etc.
     ///
-    /// The implementation is responsible for validating inputs and returning 'Rejected' if they
-    /// are invalid. A return value of Rejected can be used to do things like drop connections
-    /// to peers that send invalid data, so it should only be returned if the data is malformed
-    /// or something like a signature check fails. The 'Ignored' enum value should be returned
-    /// for inputs that are valid but were not stored for some other reason, such as being
-    /// expired. It's important to return 'Ok' for accepted values to hint to the replicator
-    /// that they should be aggressively advertised to other peers.
+    /// The key supplied here will always be the SHA512 hash of the value. There is no need to
+    /// re-compute and check the key, but the value must be validated.
     ///
-    /// If KEY_IS_COMPUTED is true, the key supplied here can be assumed to be correct. It will
-    /// have been computed via get_key().
+    /// Validation of the value and returning the appropriate StoreResult is important to the
+    /// operation of the synchronization algorithm:
+    ///
+    /// StoreResult::Ok - Value was valid and was accepted and saved.
+    ///
+    /// StoreResult::Duplicate - Value was valid but is a duplicate of one we already have.
+    ///
+    /// StoreResult::Ignored - Value was valid but for some other reason was not saved.
+    ///
+    /// StoreResult::Rejected - Value was not valid; this causes the link to the peer to be dropped.
+    ///
+    /// Rejected should only be returned if the value actually fails a validity check, signature
+    /// verification, proof of work check, or some other required criteria. Ignored must be
+    /// returned if the value is valid but is too old or was rejected for some other normal reason.
     fn store(&self, key: &[u8], value: &[u8]) -> StoreResult;
 
-    /// Get the number of items under a prefix as of a given reference time.
-    ///
-    /// The default implementation uses for_each_key(). This can be specialized if it can
-    /// be done more efficiently than that.
-    fn count(&self, reference_time: i64, key_prefix: &[u8]) -> u64 {
-        let mut cnt: u64 = 0;
-        self.for_each_key(reference_time, key_prefix, |_| {
-            cnt += 1;
-            true
-        });
-        cnt
-    }
+    /// Get the number of items in a range.
+    fn count(&self, reference_time: i64, key_range_start: &[u8], key_range_end: &[u8]) -> u64;
 
     /// Get the total number of records in this data store.
     fn total_count(&self) -> u64;
 
-    /// Iterate through keys beneath a key prefix, stopping if the function returns false.
+    /// Iterate through keys, stopping if the function returns false.
     ///
     /// The default implementation uses for_each(). This can be specialized if it's faster to
     /// only load keys.
-    fn for_each_key<F: FnMut(&[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], mut f: F) {
-        self.for_each(reference_time, key_prefix, |k, _| f(k));
+    fn for_each_key<F: FnMut(&[u8]) -> bool>(&self, reference_time: i64, key_range_start: &[u8], key_range_end: &[u8], mut f: F) {
+        self.for_each(reference_time, key_range_start, key_range_end, |k, _| f(k));
     }
 
-    /// Iterate through keys and values beneath a key prefix, stopping if the function returns false.
-    fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], f: F);
-}
-
-/// A simple in-memory data store backed by a BTreeMap.
-pub struct MemoryDataStore<const KEY_SIZE: usize> {
-    max_age: i64,
-    domain: String,
-    db: Mutex<BTreeMap<[u8; KEY_SIZE], (i64, Arc<[u8]>)>>
-}
-
-impl<const KEY_SIZE: usize> MemoryDataStore<KEY_SIZE> {
-    pub fn new(max_age: i64, domain: String) -> Self {
-        Self {
-            max_age: if max_age > 0 { max_age } else { i64::MAX },
-            domain,
-            db: Mutex::new(BTreeMap::new())
-        }
-    }
-}
-
-impl<const KEY_SIZE: usize> DataStore for MemoryDataStore<KEY_SIZE> {
-    type LoadResultValueType = Arc<[u8]>;
-    const KEY_SIZE: usize = KEY_SIZE;
-    const MAX_VALUE_SIZE: usize = 65536;
-    const KEY_IS_COMPUTED: bool = false;
-
-    fn clock(&self) -> i64 { ms_since_epoch() }
-
-    fn domain(&self) -> &str { self.domain.as_str() }
-
-    fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType> {
-        let db = self.db.lock().unwrap();
-        let e = db.get(key);
-        if e.is_some() {
-            let e = e.unwrap();
-            if (reference_time - e.0) <= self.max_age {
-                LoadResult::Ok(e.1.clone())
-            } else {
-                LoadResult::NotFound
-            }
-        } else {
-            LoadResult::NotFound
-        }
-    }
-
-    fn store(&self, key: &[u8], value: &[u8]) -> StoreResult {
-        let ts = crate::ms_since_epoch();
-        let mut isdup = false;
-        self.db.lock().unwrap().entry(key.try_into().unwrap()).and_modify(|e| {
-            if e.1.as_ref().eq(value) {
-                isdup = true;
-            } else {
-                *e = (ts, Arc::from(value));
-            }
-        }).or_insert_with(|| {
-            (ts, Arc::from(value))
-        });
-        if isdup {
-            StoreResult::Duplicate
-        } else {
-            StoreResult::Ok
-        }
-    }
-
-    fn total_count(&self) -> u64 { self.db.lock().unwrap().len() as u64 }
-
-    fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], mut f: F) {
-        let mut r_start = [0_u8; KEY_SIZE];
-        let mut r_end = [0xff_u8; KEY_SIZE];
-        (&mut r_start[0..key_prefix.len()]).copy_from_slice(key_prefix);
-        (&mut r_end[0..key_prefix.len()]).copy_from_slice(key_prefix);
-        for (k, v) in self.db.lock().unwrap().range((Included(r_start), Included(r_end))) {
-            if (reference_time - v.0) <= self.max_age {
-                if !f(k, &v.1) {
-                    break;
-                }
-            }
-        }
-    }
-}
-
-impl<const KEY_SIZE: usize> PartialEq for MemoryDataStore<KEY_SIZE> {
-    fn eq(&self, other: &Self) -> bool {
-        self.max_age == other.max_age && self.domain == other.domain && self.db.lock().unwrap().eq(&*other.db.lock().unwrap())
-    }
-}
-
-impl<const KEY_SIZE: usize> Eq for MemoryDataStore<KEY_SIZE> {}
-
-impl<const KEY_SIZE: usize> Clone for MemoryDataStore<KEY_SIZE> {
-    fn clone(&self) -> Self {
-        Self {
-            max_age: self.max_age,
-            domain: self.domain.clone(),
-            db: Mutex::new(self.db.lock().unwrap().clone())
-        }
-    }
+    /// Iterate through keys and values, stopping if the function returns false.
+    fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, reference_time: i64, key_range_start: &[u8], key_range_end: &[u8], f: F);
 }
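An aside on the range-based API that replaces key prefixes above: range_from_prefix() turns a bit prefix into an inclusive [start, end] pair of 64-byte SHA512 keys, which maps directly onto an ordered-map range scan. A minimal sketch of a count built on it, assuming a BTreeMap-backed store (the map type and helper name here are illustrative, not part of this patch):

    use std::collections::BTreeMap;
    use std::ops::Bound::Included;
    use syncwhole::datastore::range_from_prefix;

    /// Count records whose SHA512 keys fall under the given bit prefix.
    fn count_under_prefix(db: &BTreeMap<[u8; 64], Vec<u8>>, prefix: &[u8], prefix_bits: usize) -> u64 {
        let (start, end) = range_from_prefix(prefix, prefix_bits);
        db.range((Included(start), Included(end))).count() as u64
    }

The same pair of endpoints can be fed straight to the trait's count(), for_each(), and for_each_key() methods.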
diff --git a/syncwhole/src/iblt.rs b/syncwhole/src/iblt.rs
index 30df97036..1ab441b5d 100644
--- a/syncwhole/src/iblt.rs
+++ b/syncwhole/src/iblt.rs
@@ -6,43 +6,12 @@
  * https://www.zerotier.com/
  */
 
-use std::alloc::{alloc_zeroed, dealloc, Layout};
 use std::mem::size_of;
-use std::ptr::write_bytes;
+use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut, write_bytes};
+
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
-use crate::varint;
-
-#[inline(always)]
-fn xorshift64(mut x: u64) -> u64 {
-    x = u64::from_le(x);
-    x ^= x.wrapping_shl(13);
-    x ^= x.wrapping_shr(7);
-    x ^= x.wrapping_shl(17);
-    x.to_le()
-}
-
-#[inline(always)]
-fn splitmix64(mut x: u64) -> u64 {
-    x = u64::from_le(x);
-    x ^= x.wrapping_shr(30);
-    x = x.wrapping_mul(0xbf58476d1ce4e5b9);
-    x ^= x.wrapping_shr(27);
-    x = x.wrapping_mul(0x94d049bb133111eb);
-    x ^= x.wrapping_shr(31);
-    x.to_le()
-}
-
-#[inline(always)]
-fn splitmix64_inverse(mut x: u64) -> u64 {
-    x = u64::from_le(x);
-    x ^= x.wrapping_shr(31) ^ x.wrapping_shr(62);
-    x = x.wrapping_mul(0x319642b2d24d8ec3);
-    x ^= x.wrapping_shr(27) ^ x.wrapping_shr(54);
-    x = x.wrapping_mul(0x96de1b173f119089);
-    x ^= x.wrapping_shr(30) ^ x.wrapping_shr(60);
-    x.to_le()
-}
+use crate::utils::*;
 
 #[inline(always)]
 fn next_iteration_index(prev_iteration_index: u64) -> u64 {
@@ -50,16 +19,17 @@ fn next_iteration_index(prev_iteration_index: u64) -> u64 {
 }
 
 #[derive(Clone, PartialEq, Eq)]
+#[repr(C, packed)]
 struct IBLTEntry {
     key_sum: u64,
     check_hash_sum: u64,
-    count: i64
+    count: i32
 }
 
 impl IBLTEntry {
     #[inline(always)]
     fn is_singular(&self) -> bool {
-        if self.count == 1 || self.count == -1 {
+        if i32::from_le(self.count) == 1 || i32::from_le(self.count) == -1 {
             xorshift64(self.key_sum) == self.check_hash_sum
         } else {
             false
@@ -68,84 +38,102 @@ impl IBLTEntry {
 }
 
 /// An Invertible Bloom Lookup Table for set reconciliation with 64-bit hashes.
+///
+/// Usage inspired by this paper:
+///
+/// https://dash.harvard.edu/bitstream/handle/1/14398536/GENTILI-SENIORTHESIS-2015.pdf
 #[derive(Clone, PartialEq, Eq)]
-pub struct IBLT<const B: usize> {
-    map: *mut [IBLTEntry; B]
+pub struct IBLT {
+    map: Vec<IBLTEntry>
 }
 
-impl<const B: usize> IBLT<B> {
-    /// Number of buckets (capacity) of this IBLT.
-    pub const BUCKETS: usize = B;
-
+impl IBLT {
     /// This was determined to be effective via empirical testing with random keys. This
     /// is a protocol constant that can't be changed without upgrading all nodes in a domain.
     const KEY_MAPPING_ITERATIONS: usize = 2;
 
-    pub fn new() -> Self {
-        assert!(B < u32::MAX as usize); // sanity check
-        Self {
-            map: unsafe { alloc_zeroed(Layout::new::<[IBLTEntry; B]>()).cast() }
+    pub fn new(buckets: usize) -> Self {
+        assert!(buckets < i32::MAX as usize);
+        assert_eq!(size_of::<IBLTEntry>(), 20);
+        let mut iblt = Self { map: Vec::with_capacity(buckets) };
+        unsafe {
+            iblt.map.set_len(buckets);
+            iblt.reset();
         }
+        iblt
     }
 
-    pub fn reset(&mut self) {
-        unsafe { write_bytes(self.map.cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) };
-    }
+    /// Compute the size in bytes of an IBLT with the given number of buckets.
+    #[inline(always)]
+    pub fn size_bytes_with_buckets(buckets: usize) -> usize { buckets * size_of::<IBLTEntry>() }
 
-    pub async fn read<R: AsyncReadExt + Unpin>(&mut self, r: &mut R) -> std::io::Result<()> {
-        let mut prev_c = 0_i64;
-        for b in unsafe { (*self.map).iter_mut() } {
-            let _ = r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() }).await?;
-            let _ = r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() }).await?;
-            let mut c = varint::read_async(r).await?.0 as i64;
-            if (c & 1) == 0 {
-                c = c.wrapping_shr(1);
-            } else {
-                c = -c.wrapping_shr(1);
-            }
-            b.count = c + prev_c;
-            prev_c = b.count;
-        }
-        Ok(())
-    }
-
-    pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
-        let mut prev_c = 0_i64;
-        for b in unsafe { (*self.map).iter() } {
-            let _ = w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() }).await?;
-            let _ = w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() }).await?;
-            let mut c = (b.count - prev_c).wrapping_shl(1);
-            prev_c = b.count;
-            if c < 0 {
-                c = -c | 1;
-            }
-            let _ = varint::write_async(w, c as u64).await?;
-        }
-        Ok(())
-    }
-
-    /// Get this IBLT as a byte array.
-    pub fn to_bytes(&self) -> Vec<u8> {
-        let mut out = std::io::Cursor::new(Vec::<u8>::with_capacity(B * 20));
-        let current = tokio::runtime::Handle::try_current();
-        if current.is_ok() {
-            assert!(current.unwrap().block_on(self.write(&mut out)).is_ok());
+    /// Compute the IBLT size in buckets to reconcile a given set difference, or return 0 if no advantage.
+    /// This returns zero if an IBLT would take up as much or more space than just sending local_set_size
+    /// hashes of hash_size_bytes.
+    pub fn calc_iblt_parameters(hash_size_bytes: usize, local_set_size: u64, difference_size: u64) -> usize {
+        let hashes_would_be = (hash_size_bytes as f64) * (local_set_size as f64);
+        let buckets_should_be = (difference_size as f64) * 1.8; // factor determined experimentally for best bytes/item, can be tuned
+        let iblt_would_be = buckets_should_be * (size_of::<IBLTEntry>() as f64);
+        if iblt_would_be < hashes_would_be {
+            buckets_should_be.ceil() as usize
         } else {
-            assert!(tokio::runtime::Builder::new_current_thread().build().unwrap().block_on(self.write(&mut out)).is_ok());
+            0
         }
-        out.into_inner()
     }
 
-    fn ins_rem(&mut self, mut key: u64, delta: i64) {
+    /// Get the size of this IBLT in buckets.
+    #[inline(always)]
+    pub fn buckets(&self) -> usize { self.map.len() }
+
+    /// Get the size of this IBLT in bytes.
+    pub fn size_bytes(&self) -> usize { self.map.len() * size_of::<IBLTEntry>() }
+
+    /// Zero this IBLT.
+    #[inline(always)]
+    pub fn reset(&mut self) {
+        unsafe { write_bytes(self.map.as_mut_ptr().cast::<u8>(), 0, self.map.len() * size_of::<IBLTEntry>()); }
+    }
+
+    /// Get this IBLT as a byte slice in place.
+    #[inline(always)]
+    pub fn as_bytes(&self) -> &[u8] {
+        unsafe { &*slice_from_raw_parts(self.map.as_ptr().cast::<u8>(), self.map.len() * size_of::<IBLTEntry>()) }
+    }
+
+    /// Construct an IBLT from an input reader and a size in bytes.
+    pub async fn new_from_reader<R: AsyncReadExt + Unpin>(r: &mut R, bytes: usize) -> std::io::Result<Self> {
+        assert_eq!(size_of::<IBLTEntry>(), 20);
+        if (bytes % size_of::<IBLTEntry>()) != 0 {
+            Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "incomplete or invalid IBLT"))
+        } else {
+            let buckets = bytes / size_of::<IBLTEntry>();
+            let mut iblt = Self { map: Vec::with_capacity(buckets) };
+            unsafe {
+                iblt.map.set_len(buckets);
+                r.read_exact(&mut *slice_from_raw_parts_mut(iblt.map.as_mut_ptr().cast::<u8>(), bytes)).await?;
+            }
+            Ok(iblt)
+        }
+    }
+
+    /// Write this IBLT to a stream.
+    /// Note that the size of the IBLT in bytes must be stored separately. Use size_bytes() to get that.
+    #[inline(always)]
+    pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
+        w.write_all(self.as_bytes()).await
+    }
+
+    fn ins_rem(&mut self, mut key: u64, delta: i32) {
         key = splitmix64(key);
         let check_hash = xorshift64(key);
         let mut iteration_index = u64::from_le(key);
+        let buckets = self.map.len();
         for _ in 0..Self::KEY_MAPPING_ITERATIONS {
             iteration_index = next_iteration_index(iteration_index);
-            let b = unsafe { (*self.map).get_unchecked_mut((iteration_index as usize) % B) };
+            let b = unsafe { self.map.get_unchecked_mut((iteration_index as usize) % buckets) };
             b.key_sum ^= key;
             b.check_hash_sum ^= check_hash;
-            b.count += delta;
+            b.count = (i32::from_le(b.count) + delta).to_le();
         }
     }
 
@@ -166,27 +154,35 @@ impl IBLT {
     }
 
     /// Subtract another IBLT from this one to get a set difference.
-    pub fn subtract(&mut self, other: &Self) {
-        for (s, o) in unsafe { (*self.map).iter_mut().zip((*other.map).iter()) } {
-            s.key_sum ^= o.key_sum;
-            s.check_hash_sum ^= o.check_hash_sum;
-            s.count -= o.count;
+    ///
+    /// This returns true on success or false on error, which right now can only happen if the
+    /// other IBLT has a different number of buckets.
+    pub fn subtract(&mut self, other: &Self) -> bool {
+        if other.map.len() == self.map.len() {
+            for (s, o) in self.map.iter_mut().zip(other.map.iter()) {
+                s.key_sum ^= o.key_sum;
+                s.check_hash_sum ^= o.check_hash_sum;
+                s.count = (i32::from_le(s.count) - i32::from_le(o.count)).to_le();
+            }
+            return true;
         }
+        return false;
     }
 
-    pub fn list<F: FnMut(&[u8; 8]) -> bool>(self, mut f: F) -> bool {
-        let mut queue: Vec<u32> = Vec::with_capacity(B);
+    /// List as many entries in this IBLT as can be extracted.
+    pub fn list<F: FnMut(&[u8; 8])>(mut self, mut f: F) {
+        let mut queue: Vec<u32> = Vec::with_capacity(self.map.len());
 
-        for b in 0..B {
-            if unsafe { (*self.map).get_unchecked(b).is_singular() } {
-                queue.push(b as u32);
+        for bi in 0..self.map.len() {
+            if unsafe { self.map.get_unchecked(bi).is_singular() } {
+                queue.push(bi as u32);
             }
         }
 
         loop {
             let b = queue.pop();
             let b = if b.is_some() {
-                unsafe { (*self.map).get_unchecked_mut(b.unwrap() as usize) }
+                unsafe { self.map.get_unchecked_mut(b.unwrap() as usize) }
             } else {
                 break;
             };
@@ -199,29 +195,21 @@ impl IBLT {
 
             for _ in 0..Self::KEY_MAPPING_ITERATIONS {
                 iteration_index = next_iteration_index(iteration_index);
-                let b_idx = iteration_index % (B as u64);
-                let b = unsafe { (*self.map).get_unchecked_mut(b_idx as usize) };
+                let b_idx = iteration_index % (self.map.len() as u64);
+                let b = unsafe { self.map.get_unchecked_mut(b_idx as usize) };
                 b.key_sum ^= key;
                 b.check_hash_sum ^= check_hash;
-                b.count -= 1;
+                b.count = (i32::from_le(b.count) - 1).to_le();
                 if b.is_singular() {
-                    if queue.len() >= (B * 2) { // sanity check for invalid IBLT
-                        return false;
+                    if queue.len() > self.map.len() { // sanity check for invalid IBLT
+                        return;
                     }
                     queue.push(b_idx as u32);
                 }
             }
         }
-
-        return true;
-    }
-}
-
-impl<const B: usize> Drop for IBLT<B> {
-    fn drop(&mut self) {
-        unsafe { dealloc(self.map.cast(), Layout::new::<[IBLTEntry; B]>()) };
     }
 }
 
@@ -229,6 +217,7 @@ impl Drop for IBLT {
 mod tests {
     use std::collections::HashSet;
     use std::time::SystemTime;
+
     use crate::iblt::*;
 
     #[test]
@@ -245,7 +234,7 @@ mod tests {
         let mut count = 64;
         const CAPACITY: usize = 4096;
         while count <= CAPACITY {
-            let mut test: IBLT<CAPACITY> = IBLT::new();
+            let mut test: IBLT = IBLT::new(CAPACITY);
             expected.clear();
 
             for _ in 0..count {
@@ -277,8 +266,8 @@ mod tests {
         let mut missing: HashSet<u64> = HashSet::with_capacity(CAPACITY);
         while missing_count <= CAPACITY {
             missing.clear();
-            let mut local: IBLT<CAPACITY> = IBLT::new();
-            let mut remote: IBLT<CAPACITY> = IBLT::new();
+            let mut local: IBLT = IBLT::new(CAPACITY);
+            let mut remote: IBLT = IBLT::new(CAPACITY);
 
             for k in 0..REMOTE_SIZE {
                 if k >= missing_count {
@@ -291,7 +280,7 @@ mod tests {
             }
 
             local.subtract(&mut remote);
-            let bytes = local.to_bytes().len();
+            let bytes = local.as_bytes().len();
             let mut cnt = 0;
             local.list(|k| {
                 let k = u64::from_ne_bytes(*k);
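Usage sketch for the reworked IBLT above. This assumes a public insert() that wraps ins_rem(key, 1); the insertion method itself lies outside the hunks shown here, so treat that name as an assumption:

    // Each side summarizes its 8-byte key hashes, one subtracts the other's
    // summary, and list() then decodes keys present in exactly one set.
    let mut local = IBLT::new(2048);
    let mut remote = IBLT::new(2048);
    for k in 0..1000_u64 { local.insert(&k.to_ne_bytes()); }   // assumed API
    for k in 10..1010_u64 { remote.insert(&k.to_ne_bytes()); }
    if local.subtract(&remote) {
        // Buckets with count +1/-1 and a matching check hash are "singular"
        // and can be peeled off one at a time.
        local.list(|key| println!("in exactly one set: {:?}", key));
    }

Note that list() consumes the IBLT, since peeling destroys its contents.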
diff --git a/syncwhole/src/lib.rs b/syncwhole/src/lib.rs
index af7ec975b..0da564aee 100644
--- a/syncwhole/src/lib.rs
+++ b/syncwhole/src/lib.rs
@@ -13,11 +13,4 @@ pub(crate) mod iblt;
 pub mod datastore;
 pub mod node;
 pub mod host;
-
-pub fn ms_since_epoch() -> i64 {
-    std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as i64
-}
-
-pub fn ms_monotonic() -> i64 {
-    std::time::Instant::now().elapsed().as_millis() as i64
-}
+pub mod utils;
diff --git a/syncwhole/src/main.rs b/syncwhole/src/main.rs
index 1c8f33a1e..bcea03ad7 100644
--- a/syncwhole/src/main.rs
+++ b/syncwhole/src/main.rs
@@ -13,6 +13,6 @@ use syncwhole::datastore::{DataStore, LoadResult, StoreResult};
 use syncwhole::host::Host;
-use syncwhole::ms_since_epoch;
 use syncwhole::node::{Node, RemoteNodeInfo};
+use syncwhole::utils::*;
 
 const TEST_NODE_COUNT: usize = 16;
 const TEST_PORT_RANGE_START: u16 = 21384;
@@ -55,13 +56,7 @@ impl Host for TestNodeHost {
 
 impl DataStore for TestNodeHost {
     type LoadResultValueType = Arc<[u8]>;
-    const KEY_SIZE: usize = 64;
     const MAX_VALUE_SIZE: usize = 1024;
-    const KEY_IS_COMPUTED: bool = true;
-
-    fn key_from_value(&self, value: &[u8], key_buffer: &mut [u8]) {
-        key_buffer.copy_from_slice(Sha512::digest(value).as_slice());
-    }
 
     fn clock(&self) -> i64 { ms_since_epoch() }
 
@@ -73,7 +68,7 @@ impl DataStore for TestNodeHost {
 
     fn store(&self, key: &[u8], value: &[u8]) -> StoreResult {
         assert_eq!(key.len(), 64);
-        let mut res = StoreResult::Ok;
+        let mut res = StoreResult::Ok(0);
         self.db.lock().unwrap().entry(key.try_into().unwrap()).and_modify(|e| {
             if e.as_ref().eq(value) {
                 res = StoreResult::Duplicate;
@@ -86,14 +81,14 @@ impl DataStore for TestNodeHost {
         res
     }
 
+    fn count(&self, _: i64, key_range_start: &[u8], key_range_end: &[u8]) -> u64 {
+        self.db.lock().unwrap().range((Included(key_range_start.try_into().unwrap()), Included(key_range_end.try_into().unwrap()))).count() as u64
+    }
+
     fn total_count(&self) -> u64 { self.db.lock().unwrap().len() as u64 }
 
-    fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, _: i64, key_prefix: &[u8], mut f: F) {
-        let mut r_start = [0_u8; Self::KEY_SIZE];
-        let mut r_end = [0xff_u8; Self::KEY_SIZE];
-        (&mut r_start[0..key_prefix.len()]).copy_from_slice(key_prefix);
-        (&mut r_end[0..key_prefix.len()]).copy_from_slice(key_prefix);
-        for (k, v) in self.db.lock().unwrap().range((Included(r_start), Included(r_end))) {
+    fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, _: i64, key_range_start: &[u8], key_range_end: &[u8], mut f: F) {
+        for (k, v) in self.db.lock().unwrap().range((Included(key_range_start.try_into().unwrap()), Included(key_range_end.try_into().unwrap()))) {
             if !f(k, v.as_ref()) {
                 break;
             }
@@ -126,11 +121,13 @@ fn main() {
 
         loop {
             tokio::time::sleep(Duration::from_secs(1)).await;
+            /*
             let mut count = 0;
             for n in nodes.iter() {
                 count += n.connection_count().await;
             }
             println!("{}", count);
+            */
         }
     });
 }
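One property of the new range-based DataStore methods worth spelling out: the endpoints are inclusive 64-byte keys, so querying the whole key space must agree with total_count() for a store holding no expired records. A quick illustration (ds stands for any DataStore implementation; hypothetical, not code from this patch):

    let all_start = [0_u8; 64];
    let all_end = [0xff_u8; 64];
    assert_eq!(ds.count(ds.clock(), &all_start, &all_end), ds.total_count());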
diff --git a/syncwhole/src/node.rs b/syncwhole/src/node.rs
index c2d88c6ac..0531a94f5 100644
--- a/syncwhole/src/node.rs
+++ b/syncwhole/src/node.rs
@@ -13,7 +13,6 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
 use std::ops::Add;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering};
-use std::time::SystemTime;
 
 use serde::{Deserialize, Serialize};
 use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
@@ -21,12 +20,13 @@ use tokio::net::{TcpListener, TcpSocket, TcpStream};
 use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
 use tokio::sync::Mutex;
 use tokio::task::JoinHandle;
-use tokio::time::{Instant, Duration};
+use tokio::time::{Duration, Instant};
 
-use crate::datastore::DataStore;
+use crate::datastore::*;
 use crate::host::Host;
-use crate::ms_monotonic;
+use crate::iblt::IBLT;
 use crate::protocol::*;
+use crate::utils::*;
 use crate::varint;
 
 /// Inactivity timeout for connections in milliseconds.
@@ -56,8 +56,11 @@ pub struct RemoteNodeInfo {
     /// Explicitly advertised remote addresses supplied by remote node (not necessarily verified).
     pub explicit_addresses: Vec<SocketAddr>,
 
-    /// Time TCP connection was established.
-    pub connect_time: SystemTime,
+    /// Time TCP connection was established (ms since epoch).
+    pub connect_time: i64,
+
+    /// Time TCP connection was established (ms, monotonic).
+    pub connect_instant: i64,
 
     /// True if this is an inbound TCP connection.
     pub inbound: bool,
@@ -101,7 +104,7 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
             datastore: db.clone(),
             host: host.clone(),
             connections: Mutex::new(HashMap::with_capacity(64)),
-            bind_address
+            bind_address,
         });
 
         Ok(Self {
@@ -115,7 +118,6 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
 
     pub fn host(&self) -> &Arc<H> { &self.internal.host }
 
-    #[inline(always)]
     pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
         self.internal.clone().connect(endpoint, Instant::now().add(Duration::from_millis(CONNECTION_TIMEOUT as u64))).await
     }
@@ -142,14 +144,22 @@ impl<D: DataStore + 'static, H: Host + 'static> Drop for Node<D, H> {
 }
 
 pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
+    // Secret used to perform HMAC to detect and drop loopback connections to self.
     anti_loopback_secret: [u8; 64],
+
+    // Outside code implementations of DataStore and Host traits.
     datastore: Arc<D>,
     host: Arc<H>,
+
+    // Connections and their task join handles, by remote endpoint address.
     connections: Mutex<HashMap<SocketAddr, (Arc<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
+
+    // Local address to which this node is bound
     bind_address: SocketAddr,
 }
 
 impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
+    /// Loop that constantly runs in the background to do cleanup and service things.
     async fn housekeeping_task_main(self: Arc<Self>) {
         let mut last_status_sent = ms_monotonic();
         let mut tasks: Vec<JoinHandle<()>> = Vec::new();
@@ -173,8 +183,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
                 let status = if (now - last_status_sent) >= STATUS_INTERVAL {
                     last_status_sent = now;
                     Some(msg::Status {
-                        record_count: self.datastore.total_count(),
-                        clock: self.datastore.clock() as u64
+                        total_record_count: self.datastore.total_count(),
+                        reference_time: self.datastore.clock()
                     })
                 } else {
                     None
@@ -191,6 +201,8 @@
                     let sa = sa.clone();
                     tasks.push(tokio::spawn(async move {
                         if cc.info.lock().await.initialized {
+                            // This almost always completes instantly due to queues, but add a timeout in case connection
+                            // is stalled. In this case the result is a closed connection.
                             if !tokio::time::timeout_at(sleep_until, cc.send_obj(MESSAGE_TYPE_STATUS, &status, now)).await.map_or(false, |r| r.is_ok()) {
                                 let _ = self2.connections.lock().await.remove(&sa).map(|c| c.1.map(|j| j.abort()));
                                 self2.host.on_connection_closed(&*cc.info.lock().await, "write overflow (timeout)".to_string());
                             }
                         }
                     }));
                 });
-                true
+                true // keep connection
             } else {
                 let _ = c.1.take().map(|j| j.abort());
                 let host = self.host.clone();
@@ -206,7 +218,7 @@
                 tasks.push(tokio::spawn(async move {
                     host.on_connection_closed(&*cc.info.lock().await, "timeout".to_string());
                 }));
-                false
+                false // discard connection
             }
         } else {
             let host = self.host.clone();
@@ -225,7 +237,7 @@
                     host.on_connection_closed(&*cc.info.lock().await, "remote host closed connection".to_string());
                 }
             }));
-            false
+            false // discard connection
         }
     });
 
@@ -243,7 +255,7 @@
         }
 
         // Try to connect to more peers until desired connection count is reached.
-        let desired_connection_count = self.host.desired_connection_count();
+        let desired_connection_count = self.host.desired_connection_count().min(self.host.max_connection_count());
         while connected_to_addresses.len() < desired_connection_count {
             let sa = self.host.another_peer(&connected_to_addresses);
             if sa.is_some() {
@@ -270,133 +282,36 @@
         }
     }
 
+    /// Incoming TCP acceptor task.
     async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
         loop {
             let socket = listener.accept().await;
             if socket.is_ok() {
-                let (stream, endpoint) = socket.unwrap();
-                let num_conn = self.connections.lock().await.len();
-                if num_conn < self.host.max_connection_count() || self.host.fixed_peers().contains(&endpoint) {
-                    Self::connection_start(&self, endpoint, stream, true).await;
+                let (stream, address) = socket.unwrap();
+                if self.host.allow(&address) {
+                    if self.connections.lock().await.len() < self.host.max_connection_count() || self.host.fixed_peers().contains(&address) {
+                        Self::connection_start(&self, address, stream, true).await;
+                    }
                 }
             }
         }
     }
 
-    #[inline(always)]
-    async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: BufReader<OwnedReadHalf>) -> std::io::Result<()> {
-        let mut buf: Vec<u8> = Vec::new();
-        buf.resize(4096, 0);
-
-        let mut anti_loopback_challenge_sent = [0_u8; 64];
-        let mut challenge_sent = [0_u8; 64];
-        self.host.get_secure_random(&mut anti_loopback_challenge_sent);
-        self.host.get_secure_random(&mut challenge_sent);
-        connection.send_obj(MESSAGE_TYPE_INIT, &msg::Init {
-            anti_loopback_challenge: &anti_loopback_challenge_sent,
-            challenge: &challenge_sent,
-            domain: self.datastore.domain().to_string(),
-            key_size: D::KEY_SIZE as u16,
-            max_value_size: D::MAX_VALUE_SIZE as u64,
-            node_name: self.host.name().map(|n| n.to_string()),
-            node_contact: self.host.contact().map(|c| c.to_string()),
-            locally_bound_port: self.bind_address.port(),
-            explicit_ipv4: None,
-            explicit_ipv6: None
-        }, ms_monotonic()).await?;
-
-        let mut init_received = false;
-        loop {
-            reader.read_exact(&mut buf.as_mut_slice()[0..1]).await?;
-            let message_type = unsafe { *buf.get_unchecked(0) };
-            let message_size = varint::read_async(&mut reader).await?;
-            let header_size = 1 + message_size.1;
-            let message_size = message_size.0;
-            if message_size > (D::MAX_VALUE_SIZE + ((D::KEY_SIZE + 10) * 256) + 65536) as u64 {
-                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
-            }
-
-            let now = ms_monotonic();
-            connection.last_receive_time.store(now, Ordering::Relaxed);
-
-            match message_type {
-
-                MESSAGE_TYPE_INIT => {
-                    if init_received {
-                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init"));
-                    }
-
-                    let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
-
-                    if !msg.domain.as_str().eq(self.datastore.domain()) {
-                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("data set domain mismatch: '{}' != '{}'", msg.domain, self.datastore.domain())));
-                    }
-                    if msg.key_size != D::KEY_SIZE as u16 || msg.max_value_size > D::MAX_VALUE_SIZE as u64 {
-                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set key/value sizing mismatch"));
-                    }
-
-                    let (anti_loopback_response, challenge_response) = {
-                        let mut info = connection.info.lock().await;
-                        info.node_name = msg.node_name.clone();
-                        info.node_contact = msg.node_contact.clone();
-                        let _ = msg.explicit_ipv4.map(|pv4| {
-                            info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port)));
-                        });
-                        let _ = msg.explicit_ipv6.map(|pv6| {
-                            info.explicit_addresses.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
-                        });
-
-                        let challenge_response = self.host.authenticate(&info, msg.challenge);
-                        if challenge_response.is_none() {
-                            return Err(std::io::Error::new(std::io::ErrorKind::Other, "authenticate() returned None, connection dropped"));
-                        }
-
-                        (H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge), challenge_response.unwrap())
-                    };
-
-                    connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse {
-                        anti_loopback_response: &anti_loopback_response,
-                        challenge_response: &challenge_response
-                    }, now).await?;
-
-                    init_received = true;
-                },
-
-                MESSAGE_TYPE_INIT_RESPONSE => {
-                    let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
-
-                    let mut info = connection.info.lock().await;
-                    if info.initialized {
-                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init response"));
-                    }
-                    info.initialized = true;
-                    let info = info.clone();
-
-                    if msg.anti_loopback_response.eq(&H::hmac_sha512(&self.anti_loopback_secret, &anti_loopback_challenge_sent)) {
-                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "rejected connection to self"));
-                    }
-                    if !self.host.authenticate(&info, &challenge_sent).map_or(false, |cr| msg.challenge_response.eq(&cr)) {
-                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "challenge/response authentication failed"));
-                    }
-
-                    self.host.on_connect(&info);
-                },
-
-                _ => {
-                    // Skip messages that aren't recognized or don't need to be parsed.
-                    let mut remaining = message_size as usize;
-                    while remaining > 0 {
-                        let s = remaining.min(buf.len());
-                        reader.read_exact(&mut buf.as_mut_slice()[0..s]).await?;
-                        remaining -= s;
-                    }
-                }
-
-            }
-
-            connection.bytes_received.fetch_add((header_size as u64) + message_size, Ordering::Relaxed);
-        }
-    }
-
+    /// Initiate an outgoing connection with a deadline based timeout.
+    async fn connect(self: Arc<Self>, address: &SocketAddr, deadline: Instant) -> std::io::Result<bool> {
+        self.host.on_connect_attempt(address);
+        let stream = if address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
+        configure_tcp_socket(&stream)?;
+        stream.bind(self.bind_address.clone())?;
+        let stream = tokio::time::timeout_at(deadline, stream.connect(address.clone())).await;
+        if stream.is_ok() {
+            Ok(self.connection_start(address.clone(), stream.unwrap()?, false).await)
+        } else {
+            Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "connect timed out"))
+        }
+    }
+
+    /// Sets up and spawns the task for a new TCP connection whether inbound or outbound.
     async fn connection_start(self: &Arc<Self>, address: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
         let mut ok = false;
         let _ = self.connections.lock().await.entry(address.clone()).or_insert_with(|| {
@@ -415,7 +330,8 @@
                 node_contact: None,
                 remote_address: address.clone(),
                 explicit_addresses: Vec::new(),
-                connect_time: SystemTime::now(),
+                connect_time: ms_since_epoch(),
+                connect_instant: ms_monotonic(),
                 inbound,
                 initialized: false
             }),
@@ -431,25 +347,322 @@
         ok
     }
 
-    async fn connect(self: Arc<Self>, address: &SocketAddr, deadline: Instant) -> std::io::Result<bool> {
-        self.host.on_connect_attempt(address);
-        let stream = if address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
-        configure_tcp_socket(&stream)?;
-        stream.bind(self.bind_address.clone())?;
-        let stream = tokio::time::timeout_at(deadline, stream.connect(address.clone())).await;
-        if stream.is_ok() {
-            Ok(self.connection_start(address.clone(), stream.unwrap()?, false).await)
-        } else {
-            Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "connect timed out"))
-        }
-    }
-
+    /// Main I/O task launched for each connection.
+    ///
+    /// This handles reading from the connection and reacting to what it sends. Killing this
+    /// task is done when the connection is closed.
+    async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: BufReader<OwnedReadHalf>) -> std::io::Result<()> {
+        let mut anti_loopback_challenge_sent = [0_u8; 64];
+        let mut domain_challenge_sent = [0_u8; 64];
+        let mut auth_challenge_sent = [0_u8; 64];
+        self.host.get_secure_random(&mut anti_loopback_challenge_sent);
+        self.host.get_secure_random(&mut domain_challenge_sent);
+        self.host.get_secure_random(&mut auth_challenge_sent);
+        connection.send_obj(MESSAGE_TYPE_INIT, &msg::Init {
+            anti_loopback_challenge: &anti_loopback_challenge_sent,
+            domain_challenge: &domain_challenge_sent,
+            auth_challenge: &auth_challenge_sent,
+            node_name: self.host.name().map(|n| n.to_string()),
+            node_contact: self.host.contact().map(|c| c.to_string()),
+            locally_bound_port: self.bind_address.port(),
+            explicit_ipv4: None,
+            explicit_ipv6: None
+        }, ms_monotonic()).await?;
+
+        let mut initialized = false;
+        let background_tasks = AsyncTaskReaper::new();
+        let mut init_received = false;
+        let mut buf: Vec<u8> = Vec::new();
+        buf.resize(4096, 0);
+        loop {
+            let message_type = reader.read_u8().await?;
+            let message_size = varint::read_async(&mut reader).await?;
+            let header_size = 1 + message_size.1;
+            let message_size = message_size.0;
+            if message_size > (D::MAX_VALUE_SIZE + ((D::KEY_SIZE + 10) * 256) + 65536) as u64 {
+                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
+            }
+
+            let now = ms_monotonic();
+            connection.last_receive_time.store(now, Ordering::Relaxed);
+
+            match message_type {
+
+                MESSAGE_TYPE_INIT => {
+                    if init_received {
+                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init"));
+                    }
+
+                    let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
+                    let (anti_loopback_response, domain_challenge_response, auth_challenge_response) = {
+                        let mut info = connection.info.lock().await;
+                        info.node_name = msg.node_name.clone();
+                        info.node_contact = msg.node_contact.clone();
+                        let _ = msg.explicit_ipv4.map(|pv4| {
+                            info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port)));
+                        });
+                        let _ = msg.explicit_ipv6.map(|pv6| {
+                            info.explicit_addresses.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
+                        });
+
+                        let auth_challenge_response = self.host.authenticate(&info, msg.auth_challenge);
+                        if auth_challenge_response.is_none() {
+                            return Err(std::io::Error::new(std::io::ErrorKind::Other, "authenticate() returned None, connection dropped"));
+                        }
+                        (
+                            H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge),
+                            H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), msg.domain_challenge),
+                            auth_challenge_response.unwrap()
+                        )
+                    };
+
+                    connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse {
+                        anti_loopback_response: &anti_loopback_response,
+                        domain_response: &domain_challenge_response,
+                        auth_response: &auth_challenge_response
+                    }, now).await?;
+
+                    init_received = true;
+                },
+
+                MESSAGE_TYPE_INIT_RESPONSE => {
+                    let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
+
+                    let mut info = connection.info.lock().await;
+                    if info.initialized {
+                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init response"));
+                    }
+                    info.initialized = true;
+                    let info = info.clone();
+
+                    if msg.anti_loopback_response.eq(&H::hmac_sha512(&self.anti_loopback_secret, &anti_loopback_challenge_sent)) {
+                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "rejected connection to self"));
+                    }
+                    if !msg.domain_response.eq(&H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), &domain_challenge_sent)) {
+                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "domain mismatch"));
+                    }
+                    if !self.host.authenticate(&info, &auth_challenge_sent).map_or(false, |cr| msg.auth_response.eq(&cr)) {
+                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "challenge/response authentication failed"));
+                    }
+
+                    self.host.on_connect(&info);
+                    initialized = true;
+                },
+
+                _ => {
+                    if !initialized {
+                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "init exchange must be completed before other messages are sent"));
+                    }
+
+                    match message_type {
+
+                        MESSAGE_TYPE_STATUS => {
+                            let msg: msg::Status = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
+                            self.connection_request_summary(connection, msg.total_record_count, now, msg.reference_time).await?;
+                        },
+
+                        MESSAGE_TYPE_GET_SUMMARY => {
+                            //let msg: msg::GetSummary = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
+                        },
+
+                        MESSAGE_TYPE_SUMMARY => {
+                            let mut remaining = message_size as isize;
+
+                            // Read summary header.
+                            let summary_header_size = varint::read_async(&mut reader).await?;
+                            remaining -= summary_header_size.1 as isize;
+                            let summary_header_size = summary_header_size.0;
+                            if (summary_header_size as i64) > (remaining as i64) {
+                                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid summary header"));
+                            }
+                            let summary_header: msg::SummaryHeader = connection.read_obj(&mut reader, &mut buf, summary_header_size as usize).await?;
+                            remaining -= summary_header_size as isize;
+
+                            // Read and evaluate summary that we were sent.
+                            match summary_header.summary_type {
+                                SUMMARY_TYPE_KEYS => {
+                                    self.connection_receive_and_process_remote_hash_list(
+                                        connection,
+                                        remaining,
+                                        &mut reader,
+                                        now,
+                                        summary_header.reference_time,
+                                        &summary_header.prefix[0..summary_header.prefix.len().min((summary_header.prefix_bits / 8) as usize)]).await?;
+                                },
+                                SUMMARY_TYPE_IBLT => {
+                                    //let summary = IBLT::new_from_reader(&mut reader, remaining as usize).await?;
+                                },
+                                _ => {} // ignore unknown summary types
+                            }
+
+                            // Request another summary if needed, keeping a ping-pong game going in a tight loop until synced.
+                            self.connection_request_summary(connection, summary_header.total_record_count, now, summary_header.reference_time).await?;
+                        },
+
+                        MESSAGE_TYPE_HAVE_RECORDS => {
+                            let mut remaining = message_size as isize;
+                            let reference_time = varint::read_async(&mut reader).await?;
+                            remaining -= reference_time.1 as isize;
+                            if remaining <= 0 {
+                                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid message"));
+                            }
+                            self.connection_receive_and_process_remote_hash_list(connection, remaining, &mut reader, now, reference_time.0 as i64, &[]).await?
+                        },
+
+                        MESSAGE_TYPE_GET_RECORDS => {
+                        },
+
+                        MESSAGE_TYPE_RECORD => {
+                            let value = connection.read_msg(&mut reader, &mut buf, message_size as usize).await?;
+                            if value.len() > D::MAX_VALUE_SIZE {
+                                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "value larger than MAX_VALUE_SIZE"));
+                            }
+                            let key = H::sha512(&[value]);
+                            match self.datastore.store(&key, value) {
+                                StoreResult::Ok(reference_time) => {
+                                    let mut have_records_msg = [0_u8; 2 + 10 + ANNOUNCE_HASH_BYTES];
+                                    let mut msg_len = varint::encode(&mut have_records_msg, reference_time as u64);
+                                    have_records_msg[msg_len] = ANNOUNCE_HASH_BYTES as u8;
+                                    msg_len += 1;
+                                    have_records_msg[msg_len..(msg_len + ANNOUNCE_HASH_BYTES)].copy_from_slice(&key[..ANNOUNCE_HASH_BYTES]);
+                                    msg_len += ANNOUNCE_HASH_BYTES;
+
+                                    let self2 = self.clone();
+                                    let connection2 = connection.clone();
+                                    background_tasks.spawn(async move {
+                                        let connections = self2.connections.lock().await;
+                                        let mut recipients = Vec::with_capacity(connections.len());
+                                        for (_, c) in connections.iter() {
+                                            if !Arc::ptr_eq(&(c.0), &connection2) {
+                                                recipients.push(c.0.clone());
+                                            }
+                                        }
+                                        drop(connections); // release lock
+
+                                        for c in recipients.iter() {
+                                            // This typically completes instantly due to buffering, as this message is small.
+                                            // Add a small timeout in the case that some connections are stalled. Misses will
+                                            // not impact the overall network much.
+                                            let _ = tokio::time::timeout(Duration::from_millis(250), c.send_msg(MESSAGE_TYPE_HAVE_RECORDS, &have_records_msg[..msg_len], now)).await;
+                                        }
+                                    });
+                                },
+                                StoreResult::Rejected => {
+                                    return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid datum received"));
+                                },
+                                _ => {} // duplicate or ignored values are just... ignored
+                            }
+                        },
+
+                        _ => {
+                            // Skip messages that aren't recognized or don't need to be parsed.
+                            let mut remaining = message_size as usize;
+                            while remaining > 0 {
+                                let s = remaining.min(buf.len());
+                                reader.read_exact(&mut buf.as_mut_slice()[0..s]).await?;
+                                remaining -= s;
+                            }
+                        }
+
+                    }
+                }
+            }
+
+            connection.bytes_received.fetch_add((header_size as u64) + message_size, Ordering::Relaxed);
+        }
+    }
+
+    /// Request a summary if needed, or do nothing if not.
+    ///
+    /// This is where all the logic lives that determines whether to request summaries, choosing a
+    /// prefix, etc. It's called when the remote node tells us its total record count with an
+    /// associated reference time, which happens in status announcements and in summaries.
+    async fn connection_request_summary(&self, connection: &Arc<Connection>, total_record_count: u64, now: i64, reference_time: i64) -> std::io::Result<()> {
+        let my_total_record_count = self.datastore.total_count();
+        if my_total_record_count < total_record_count {
+            // Figure out how many bits need to be in a randomly chosen prefix to choose a slice of
+            // the data set such that the set difference should be around 4096 records. This assumes
+            // random distribution, which should be mostly maintained by probing prefixes at random.
+            let prefix_bits = ((total_record_count - my_total_record_count) as f64) / 4096.0;
+            let prefix_bits = if prefix_bits > 1.0 {
+                (prefix_bits.log2().ceil() as usize).min(64)
+            } else {
+                0 as usize
+            };
+            let prefix_bytes = (prefix_bits / 8) + (((prefix_bits % 8) != 0) as usize);
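+
+            // Worked example (illustrative numbers only): if the remote reports
+            // 1,048,576 records and we hold 917,504, the difference is 131,072.
+            // 131072 / 4096 = 32 and ceil(log2(32)) = 5, so a random 5-bit prefix
+            // selects about 1/32 of the key space, putting the expected set
+            // difference within that slice near the 4096-record target.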
+            // Generate a random prefix of this many bits (to the nearest byte).
+            let mut prefix = [0_u8; 64];
+            self.host.get_secure_random(&mut prefix[..prefix_bytes]);
+
+            // Request a set summary for this prefix, providing our own count for this prefix so
+            // the remote can decide whether to send something like an IBLT or just hashes.
+            let (local_range_start, local_range_end) = range_from_prefix(&prefix, prefix_bits);
+            connection.send_obj(MESSAGE_TYPE_GET_SUMMARY, &msg::GetSummary {
+                reference_time,
+                prefix: &prefix[..prefix_bytes],
+                prefix_bits: prefix_bits as u8,
+                record_count: self.datastore.count(reference_time, &local_range_start, &local_range_end)
+            }, now).await
+        } else {
+            Ok(())
+        }
+    }
+
+    /// Read a stream of record hashes (or hash prefixes) from a connection and request records we don't have.
+    async fn connection_receive_and_process_remote_hash_list(&self, connection: &Arc<Connection>, mut remaining: isize, reader: &mut BufReader<OwnedReadHalf>, now: i64, reference_time: i64, common_prefix: &[u8]) -> std::io::Result<()> {
+        if remaining > 0 {
+            // Hash list is prefaced by the number of bytes in each hash, since whole 64 byte hashes do not have to be sent.
+            let prefix_entry_size = reader.read_u8().await? as usize;
+            let total_prefix_size = common_prefix.len() + prefix_entry_size;
+
+            if prefix_entry_size > 0 && total_prefix_size <= 64 {
+                remaining -= 1;
+                if remaining >= (prefix_entry_size as isize) {
+                    let mut get_records_msg: Vec<u8> = Vec::with_capacity(((remaining as usize) / prefix_entry_size) * total_prefix_size);
+                    varint::write(&mut get_records_msg, reference_time as u64)?;
+                    get_records_msg.push(total_prefix_size as u8);
+
+                    let mut key_prefix_buf = [0_u8; 64];
+                    key_prefix_buf[..common_prefix.len()].copy_from_slice(common_prefix);
+
+                    while remaining >= (prefix_entry_size as isize) {
+                        remaining -= prefix_entry_size as isize;
+                        reader.read_exact(&mut key_prefix_buf[common_prefix.len()..total_prefix_size]).await?;
+
+                        if if total_prefix_size < 64 {
+                            let (s, e) = range_from_prefix(&key_prefix_buf[..total_prefix_size], total_prefix_size * 8);
+                            self.datastore.count(reference_time, &s, &e) == 0
+                        } else {
+                            !self.datastore.contains(reference_time, &key_prefix_buf)
+                        } {
+                            let _ = get_records_msg.write_all(&key_prefix_buf[..total_prefix_size]);
+                        }
+                    }
+
+                    if remaining == 0 {
+                        return connection.send_msg(MESSAGE_TYPE_GET_RECORDS, get_records_msg.as_slice(), now).await;
+                    }
+                }
+            }
+        }
+        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hash list"));
+    }
 }
 
 impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
     fn drop(&mut self) {
-        for (_, c) in self.connections.blocking_lock().drain() {
-            c.1.map(|c| c.abort());
-        }
+        let _ = tokio::runtime::Handle::try_current().map_or_else(|_| {
+            for (_, c) in self.connections.blocking_lock().drain() {
+                c.1.map(|c| c.abort());
+            }
+        }, |h| {
+            let _ = h.block_on(async {
+                for (_, c) in self.connections.lock().await.drain() {
+                    c.1.map(|c| c.abort());
+                }
+            });
+        });
     }
 }
 
@@ -464,11 +677,18 @@ struct Connection {
 }
 
 impl Connection {
-    async fn send(&self, data: &[u8], now: i64) -> std::io::Result<()> {
-        self.writer.lock().await.write_all(data).await?;
-        self.last_send_time.store(now, Ordering::Relaxed);
-        self.bytes_sent.fetch_add(data.len() as u64, Ordering::Relaxed);
-        Ok(())
+    async fn send_msg(&self, message_type: u8, data: &[u8], now: i64) -> std::io::Result<()> {
+        let mut type_and_size = [0_u8; 16];
+        type_and_size[0] = message_type;
+        let tslen = 1 + varint::encode(&mut type_and_size[1..], data.len() as u64) as usize;
+        let total_size = tslen + data.len();
+        if self.writer.lock().await.write_vectored(&[IoSlice::new(&type_and_size[..tslen]), IoSlice::new(data)]).await? == total_size {
+            self.last_send_time.store(now, Ordering::Relaxed);
+            self.bytes_sent.fetch_add(total_size as u64, Ordering::Relaxed);
+            Ok(())
+        } else {
+            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "write error"))
+        }
     }
 
     async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> {
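Before the protocol definitions below, it may help to spell out the framing that send_msg() and connection_io_task_main() above agree on: a one-byte message type, a varint payload size, then the payload. A minimal sketch of the receive side, assuming r is any tokio AsyncReadExt + Unpin reader (error handling and size limits elided; not code from this patch):

    let message_type = r.read_u8().await?;
    let (message_size, _size_byte_count) = varint::read_async(&mut r).await?;
    let mut payload = vec![0_u8; message_size as usize];
    r.read_exact(&mut payload).await?;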
diff --git a/syncwhole/src/protocol.rs b/syncwhole/src/protocol.rs
index 8da4d3194..5e50fbd05 100644
--- a/syncwhole/src/protocol.rs
+++ b/syncwhole/src/protocol.rs
@@ -18,63 +18,63 @@ pub const MESSAGE_TYPE_INIT_RESPONSE: u8 = 2;
 /// Sent every few seconds to notify peers of number of records, clock, etc.
 pub const MESSAGE_TYPE_STATUS: u8 = 3;
 
+/// Get a set summary of a prefix in the data set.
+pub const MESSAGE_TYPE_GET_SUMMARY: u8 = 4;
+
+/// Set summary of a prefix.
+pub const MESSAGE_TYPE_SUMMARY: u8 = 5;
+
 /// Payload is a list of keys of records. Usually sent to advertise recently received new records.
-pub const MESSAGE_TYPE_HAVE_RECORDS: u8 = 4;
+pub const MESSAGE_TYPE_HAVE_RECORDS: u8 = 6;
 
 /// Payload is a list of keys of records the sending node wants.
-pub const MESSAGE_TYPE_GET_RECORDS: u8 = 5;
+pub const MESSAGE_TYPE_GET_RECORDS: u8 = 7;
 
-/// Payload is a record, with key being omitted if the data store's KEY_IS_COMPUTED constant is true.
-pub const MESSAGE_TYPE_RECORD: u8 = 6;
+/// Payload is a record. Its key is never sent; the receiver computes it as the SHA512 hash of the value.
+pub const MESSAGE_TYPE_RECORD: u8 = 8;
+
+/// Summary type: simple array of keys under the given prefix.
+pub const SUMMARY_TYPE_KEYS: u8 = 0;
+
+/// An IBLT set summary.
+pub const SUMMARY_TYPE_IBLT: u8 = 1;
+
+/// Number of bytes of each SHA512 hash to announce, request, etc. This is okay to change but 16 is plenty.
+pub const ANNOUNCE_HASH_BYTES: usize = 16;
 
 pub mod msg {
     use serde::{Serialize, Deserialize};
 
     #[derive(Serialize, Deserialize)]
     pub struct IPv4 {
-        #[serde(rename = "i")]
         pub ip: [u8; 4],
-        #[serde(rename = "p")]
         pub port: u16
     }
 
     #[derive(Serialize, Deserialize)]
     pub struct IPv6 {
-        #[serde(rename = "i")]
         pub ip: [u8; 16],
-        #[serde(rename = "p")]
         pub port: u16
     }
 
     #[derive(Serialize, Deserialize)]
     pub struct Init<'a> {
         /// A random challenge to be hashed with a secret to detect and drop connections to self.
-        #[serde(rename = "alc")]
         #[serde(with = "serde_bytes")]
         pub anti_loopback_challenge: &'a [u8],
 
+        /// A random challenge for checking the data set domain.
+        #[serde(with = "serde_bytes")]
+        pub domain_challenge: &'a [u8],
+
         /// A random challenge for login/authentication.
         #[serde(with = "serde_bytes")]
-        pub challenge: &'a [u8],
-
-        /// An arbitrary name for this data set to avoid connecting to peers not replicating it.
-        #[serde(rename = "d")]
-        pub domain: String,
-
-        /// Size of keys in this data set in bytes.
-        #[serde(rename = "ks")]
-        pub key_size: u16,
-
-        /// Maximum allowed size of values in this data set in bytes.
-        #[serde(rename = "mvs")]
-        pub max_value_size: u64,
+        pub auth_challenge: &'a [u8],
 
         /// Optional name to advertise for this node.
-        #[serde(rename = "nn")]
        pub node_name: Option<String>,
 
         /// Optional contact information for this node, such as a URL or an e-mail address.
-        #[serde(rename = "nc")]
         pub node_contact: Option<String>,
 
         /// Port to which this node has locally bound.
@@ -83,35 +83,85 @@ pub mod msg {
 
         /// An IPv4 address where this node can be reached.
         /// If both explicit_ipv4 and explicit_ipv6 are omitted the physical source IP:port may be used.
-        #[serde(rename = "ei4")]
         pub explicit_ipv4: Option<IPv4>,
 
         /// An IPv6 address where this node can be reached.
         /// If both explicit_ipv4 and explicit_ipv6 are omitted the physical source IP:port may be used.
-        #[serde(rename = "ei6")]
         pub explicit_ipv6: Option<IPv6>,
     }
 
     #[derive(Serialize, Deserialize)]
     pub struct InitResponse<'a> {
         /// HMAC-SHA512(local secret, anti_loopback_challenge) to detect and drop loops.
-        #[serde(rename = "alr")]
         #[serde(with = "serde_bytes")]
         pub anti_loopback_response: &'a [u8],
 
+        /// HMAC-SHA512(SHA512(domain), domain_challenge) to check that the data set domain matches.
+        #[serde(with = "serde_bytes")]
+        pub domain_response: &'a [u8],
+
         /// HMAC-SHA512(secret, challenge) for authentication. (If auth is not enabled, an all-zero secret is used.)
         #[serde(with = "serde_bytes")]
-        pub challenge_response: &'a [u8],
+        pub auth_response: &'a [u8],
     }
 
     #[derive(Serialize, Deserialize, Clone)]
     pub struct Status {
         /// Total number of records in data set.
-        #[serde(rename = "rc")]
-        pub record_count: u64,
-
-        /// Local wall clock time in milliseconds since Unix epoch.
         #[serde(rename = "c")]
-        pub clock: u64,
+        pub total_record_count: u64,
+
+        /// Reference wall clock time in milliseconds since Unix epoch.
+        #[serde(rename = "t")]
+        pub reference_time: i64,
+    }
+
+    #[derive(Serialize, Deserialize, Clone)]
+    pub struct GetSummary<'a> {
+        /// Reference wall clock time in milliseconds since Unix epoch.
+        #[serde(rename = "t")]
+        pub reference_time: i64,
+
+        /// Prefix within key space.
+        #[serde(rename = "p")]
+        #[serde(with = "serde_bytes")]
+        pub prefix: &'a [u8],
+
+        /// Length of prefix in bits (trailing bits in byte array are ignored).
+        #[serde(rename = "b")]
+        pub prefix_bits: u8,
+
+        /// Number of records in this range the requesting node already has, used to choose summary type.
+        #[serde(rename = "r")]
+        pub record_count: u64,
+    }
+
+    #[derive(Serialize, Deserialize, Clone)]
+    pub struct SummaryHeader<'a> {
+        /// Total number of records in data set, for easy rapid generation of next query.
+        #[serde(rename = "c")]
+        pub total_record_count: u64,
+
+        /// Reference wall clock time in milliseconds since Unix epoch.
+        #[serde(rename = "t")]
+        pub reference_time: i64,
+
+        /// Random salt value used by some summary types.
+        #[serde(rename = "x")]
+        #[serde(with = "serde_bytes")]
+        pub salt: &'a [u8],
+
+        /// Prefix within key space.
+        #[serde(rename = "p")]
+        #[serde(with = "serde_bytes")]
+        pub prefix: &'a [u8],
+
+        /// Length of prefix in bits (trailing bits in byte array are ignored).
+        #[serde(rename = "b")]
+        pub prefix_bits: u8,
+
+        /// Type of summary that follows this header.
+        #[serde(rename = "s")]
+        pub summary_type: u8,
+    }
 }
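The constants above define SUMMARY_TYPE_KEYS and SUMMARY_TYPE_IBLT, but the responder-side choice between them is not part of this diff. A plausible sketch using IBLT::calc_iblt_parameters() from iblt.rs — the function below, its arguments, and the minimum-difference floor are assumptions for illustration:

    fn choose_summary_type(my_count_in_range: u64, requester_count_in_range: u64) -> u8 {
        // Estimate the set difference in this range, then ask whether an IBLT
        // sized for that difference would be smaller than sending raw hashes.
        let difference = my_count_in_range.max(requester_count_in_range)
            - my_count_in_range.min(requester_count_in_range);
        if IBLT::calc_iblt_parameters(ANNOUNCE_HASH_BYTES, my_count_in_range, difference.max(16)) > 0 {
            SUMMARY_TYPE_IBLT
        } else {
            SUMMARY_TYPE_KEYS
        }
    }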
diff --git a/syncwhole/src/utils.rs b/syncwhole/src/utils.rs
new file mode 100644
index 000000000..49f50d5b9
--- /dev/null
+++ b/syncwhole/src/utils.rs
@@ -0,0 +1,102 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/.
+ *
+ * (c)2022 ZeroTier, Inc.
+ * https://www.zerotier.com/
+ */
+
+use std::collections::HashMap;
+use std::future::Future;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
+
+use tokio::task::JoinHandle;
+
+/// Get the real time clock in milliseconds since Unix epoch.
+pub fn ms_since_epoch() -> i64 {
+    std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as i64
+}
+
+/// Get the current monotonic clock in milliseconds.
+pub fn ms_monotonic() -> i64 {
+    // Measure from a fixed process-wide reference instant; calling elapsed() on a
+    // freshly created Instant would always return ~0.
+    static EPOCH: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
+    EPOCH.get_or_init(std::time::Instant::now).elapsed().as_millis() as i64
+}
+
+/// Encode a byte slice to a hexadecimal string.
+pub fn to_hex_string(b: &[u8]) -> String {
+    const HEX_CHARS: [u8; 16] = [b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'a', b'b', b'c', b'd', b'e', b'f'];
+    let mut s = String::new();
+    s.reserve(b.len() * 2);
+    for c in b {
+        let x = *c as usize;
+        s.push(HEX_CHARS[x >> 4] as char);
+        s.push(HEX_CHARS[x & 0xf] as char);
+    }
+    s
+}
+
+#[inline(always)]
+pub fn xorshift64(mut x: u64) -> u64 {
+    x = u64::from_le(x);
+    x ^= x.wrapping_shl(13);
+    x ^= x.wrapping_shr(7);
+    x ^= x.wrapping_shl(17);
+    x.to_le()
+}
+
+#[inline(always)]
+pub fn splitmix64(mut x: u64) -> u64 {
+    x = u64::from_le(x);
+    x ^= x.wrapping_shr(30);
+    x = x.wrapping_mul(0xbf58476d1ce4e5b9);
+    x ^= x.wrapping_shr(27);
+    x = x.wrapping_mul(0x94d049bb133111eb);
+    x ^= x.wrapping_shr(31);
+    x.to_le()
+}
+
+#[inline(always)]
+pub fn splitmix64_inverse(mut x: u64) -> u64 {
+    x = u64::from_le(x);
+    x ^= x.wrapping_shr(31) ^ x.wrapping_shr(62);
+    x = x.wrapping_mul(0x319642b2d24d8ec3);
+    x ^= x.wrapping_shr(27) ^ x.wrapping_shr(54);
+    x = x.wrapping_mul(0x96de1b173f119089);
+    x ^= x.wrapping_shr(30) ^ x.wrapping_shr(60);
+    x.to_le()
+}
+
+/// Wrapper for tokio::spawn() that aborts tasks not yet completed when it is dropped.
+pub struct AsyncTaskReaper {
+    ctr: AtomicUsize,
+    handles: Arc<std::sync::Mutex<HashMap<usize, JoinHandle<()>>>>,
+}
+
+impl AsyncTaskReaper {
+    pub fn new() -> Self {
+        Self {
+            ctr: AtomicUsize::new(0),
+            handles: Arc::new(std::sync::Mutex::new(HashMap::new()))
+        }
+    }
+
+    /// Spawn a new task.
+    ///
+    /// Note that currently any output is ignored. This is primarily for background tasks
+    /// that are used similarly to goroutines in Go.
+    pub fn spawn<F: Future + Send + 'static>(&self, future: F) {
+        let id = self.ctr.fetch_add(1, Ordering::Relaxed);
+        let handles = self.handles.clone();
+        self.handles.lock().unwrap().insert(id, tokio::spawn(async move {
+            let _ = future.await;
+            let _ = handles.lock().unwrap().remove(&id);
+        }));
+    }
+}
+
+impl Drop for AsyncTaskReaper {
+    fn drop(&mut self) {
+        for (_, h) in self.handles.lock().unwrap().iter() {
+            h.abort();
+        }
+    }
+}
diff --git a/syncwhole/src/varint.rs b/syncwhole/src/varint.rs
index 89cd4acdf..b94f41c94 100644
--- a/syncwhole/src/varint.rs
+++ b/syncwhole/src/varint.rs
@@ -36,17 +36,11 @@ pub async fn write_async<W: AsyncWriteExt + Unpin>(w: &mut W, v: u64) -> std::io::R
 pub async fn read_async<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<(u64, usize)> {
     let mut v = 0_u64;
-    let mut buf = [0_u8; 1];
     let mut pos = 0;
     let mut count = 0;
     loop {
-        loop {
-            if r.read(&mut buf).await? == 1 {
-                break;
-            }
-        }
         count += 1;
-        let b = buf[0];
+        let b = r.read_u8().await?;
         if b <= 0x7f {
             v |= (b as u64).wrapping_shl(pos);
             pos += 7;