diff --git a/syncwhole/Cargo.lock b/syncwhole/Cargo.lock
index 721bb5dfc..764ded32b 100644
--- a/syncwhole/Cargo.lock
+++ b/syncwhole/Cargo.lock
@@ -14,6 +14,15 @@ version = "1.3.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
 
+[[package]]
+name = "block-buffer"
+version = "0.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324"
+dependencies = [
+ "generic-array",
+]
+
 [[package]]
 name = "byteorder"
 version = "1.4.3"
@@ -32,6 +41,45 @@ version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
+[[package]]
+name = "cpufeatures"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469"
+dependencies = [
+ "libc",
+]
+
+[[package]]
+name = "crypto-common"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8"
+dependencies = [
+ "generic-array",
+ "typenum",
+]
+
+[[package]]
+name = "digest"
+version = "0.10.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506"
+dependencies = [
+ "block-buffer",
+ "crypto-common",
+]
+
+[[package]]
+name = "generic-array"
+version = "0.14.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803"
+dependencies = [
+ "typenum",
+ "version_check",
+]
+
 [[package]]
 name = "libc"
 version = "0.2.119"
@@ -205,6 +253,17 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "sha2"
+version = "0.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "55deaec60f81eefe3cce0dc50bda92d6d8e88f2a27df7c5033b42afeb1ed2676"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "digest",
+]
+
 [[package]]
 name = "smallvec"
 version = "1.8.0"
@@ -239,6 +298,7 @@ dependencies = [
  "rmp",
  "rmp-serde",
  "serde",
+ "sha2",
  "tokio",
 ]
 
@@ -258,12 +318,24 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "typenum"
+version = "1.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
+
 [[package]]
 name = "unicode-xid"
 version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
 
+[[package]]
+name = "version_check"
+version = "0.9.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
+
 [[package]]
 name = "winapi"
 version = "0.3.9"
diff --git a/syncwhole/Cargo.toml b/syncwhole/Cargo.toml
index e3873a8e8..b1f618261 100644
--- a/syncwhole/Cargo.toml
+++ b/syncwhole/Cargo.toml
@@ -10,3 +10,6 @@ tokio = { version = "^1", features = ["net", "rt", "parking_lot", "time", "io-st
 serde = { version = "^1", features = ["derive"], default-features = false }
 rmp = "^0"
 rmp-serde = "^1"
+
+[dev-dependencies]
+sha2 = "^0"
diff --git a/syncwhole/src/datastore.rs b/syncwhole/src/datastore.rs
index f282d3021..99f9a16bf 100644
--- a/syncwhole/src/datastore.rs
+++ b/syncwhole/src/datastore.rs
@@ -27,13 +27,13 @@ pub enum StoreResult {
     /// Entry was accepted (whether or not an old value was replaced).
     Ok,
 
-    /// Entry was rejected as a duplicate but was otherwise valid.
+    /// Entry was a duplicate of one we already have but was otherwise valid.
     Duplicate,
 
-    /// Entry was rejected as invalid.
-    ///
-    /// An invalid entry is one that is malformed, fails a signature check, etc., and returning
-    /// this causes the synchronization service to drop the link to the node that sent it.
+    /// Entry was valid but was ignored for an unspecified reason.
+    Ignored,
+
+    /// Entry was rejected as malformed or otherwise invalid (e.g. failed signature check).
     Rejected
 }
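The 'Ignored' variant is the new case in this enum. As an illustrative sketch (not code from this PR), a caller on the synchronization side might react to each variant roughly like this, following the semantics documented above:

```rust
// Hypothetical consumer of StoreResult; the reactions follow the doc comments
// in the enum above, not any specific function in this diff.
fn handle_store_result(result: StoreResult) {
    match result {
        // Accepted: advertise the new entry aggressively to other peers.
        StoreResult::Ok => { /* replicate to peers */ }
        // Valid but not stored (already known, expired, etc.): no action.
        StoreResult::Duplicate | StoreResult::Ignored => {}
        // Malformed or failed a signature check: safe to drop the sender's link.
        StoreResult::Rejected => { /* close connection to sender */ }
    }
}
```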
@@ -45,12 +45,19 @@ pub enum StoreResult {
 /// what time I think it is" value to be considered locally so that data can be replicated
 /// as of any given time.
 ///
+/// The constants KEY_SIZE, MAX_VALUE_SIZE, and KEY_IS_COMPUTED are protocol constants
+/// for your replication domain. They can't be changed once defined unless all nodes
+/// are upgraded at once.
+///
 /// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a function of
 /// values. If this is true, get_key() must be implemented.
 ///
-/// The implementation must be thread safe.
+/// The implementation must be thread safe and may be called concurrently.
 pub trait DataStore: Sync + Send {
     /// Type to be enclosed in the Ok() enum value in LoadResult.
+    ///
+    /// Allowing this type to be defined lets you use any type that makes sense with
+    /// your implementation. Examples include Box<[u8]>, Arc<[u8]>, Vec<u8>, etc.
     type LoadResultValueType: AsRef<[u8]> + Send;
 
     /// Size of keys, which must be fixed in length. These are typically hashes.
@@ -61,11 +68,14 @@ pub trait DataStore: Sync + Send {
 
     /// This should be true if the key is computed, such as by hashing the value.
     ///
-    /// If this is true only values are sent over the wire and get_key() is used to compute
-    /// keys from values. If this is false both keys and values are replicated.
+    /// If this is true then keys do not have to be sent over the wire. Instead they
+    /// are computed by calling get_key(). If this is false keys are assumed not to
+    /// be computable from values and are explicitly sent.
     const KEY_IS_COMPUTED: bool;
 
-    /// Get the key corresponding to a value.
+    /// Compute the key corresponding to a value.
+    ///
+    /// The 'key' slice must be KEY_SIZE bytes in length.
     ///
     /// If KEY_IS_COMPUTED is true this must be implemented. The default implementation
     /// panics to indicate this. If KEY_IS_COMPUTED is false this is never called.
@@ -74,44 +84,56 @@ pub trait DataStore: Sync + Send {
         panic!("get_key() must be implemented if KEY_IS_COMPUTED is true");
     }
 
-    /// Get the domain of this data store, which is just an arbitrary unique identifier.
+    /// Get the domain of this data store.
+    ///
+    /// This is an arbitrary unique identifier that must be the same for all nodes that
+    /// are replicating the same data. It's checked on connect to avoid trying to share
+    /// data across data sets if this is not desired.
     fn domain(&self) -> &str;
 
     /// Get an item if it exists as of a given reference time.
-    ///
-    /// The supplied key must be of length KEY_SIZE or this may panic.
     fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType>;
 
-    /// Store an item in the data store and return Ok, Duplicate, or Rejected.
-    ///
-    /// The supplied key must be of length KEY_SIZE or this may panic.
+    /// Store an item in the data store and return its status.
     ///
     /// Note that no time is supplied here. The data store must determine this in an implementation
    /// dependent manner if this is a temporally subjective data store. It could be determined by
    /// the wall clock, from the object itself, etc.
+    ///
+    /// The implementation is responsible for validating inputs and returning 'Rejected' if they
+    /// are invalid. A return value of Rejected can be used to do things like drop connections
+    /// to peers that send invalid data, so it should only be returned if the data is malformed
+    /// or something like a signature check fails. The 'Ignored' enum value should be returned
+    /// for inputs that are valid but were not stored for some other reason, such as being
+    /// expired. It's important to return 'Ok' for accepted values to hint to the replicator
+    /// that they should be aggressively advertised to other peers.
+    ///
+    /// If KEY_IS_COMPUTED is true, the key supplied here can be assumed to be correct. It will
+    /// have been computed via get_key().
     fn store(&self, key: &[u8], value: &[u8]) -> StoreResult;
 
     /// Get the number of items under a prefix as of a given reference time.
     ///
-    /// The default implementation uses for_each(). This can be specialized if it can be done
-    /// more efficiently than that.
+    /// The default implementation uses for_each_key(). This can be specialized if it can
+    /// be done more efficiently than that.
     fn count(&self, reference_time: i64, key_prefix: &[u8]) -> u64 {
         let mut cnt: u64 = 0;
-        self.for_each(reference_time, key_prefix, |_, _| {
+        self.for_each_key(reference_time, key_prefix, |_| {
            cnt += 1;
            true
        });
        cnt
    }
 
-    /// Iterate through keys beneath a key prefix, stopping at the end or if the function returns false.
+    /// Iterate through keys beneath a key prefix, stopping if the function returns false.
     ///
-    /// The default implementation uses for_each().
+    /// The default implementation uses for_each(). This can be specialized if it's faster to
+    /// only load keys.
     fn for_each_key<F: FnMut(&[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], mut f: F) {
         self.for_each(reference_time, key_prefix, |k, _| f(k));
     }
 
-    /// Iterate through keys and values beneath a key prefix, stopping at the end or if the function returns false.
+    /// Iterate through keys and values beneath a key prefix, stopping if the function returns false.
     fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], f: F);
 }
@@ -196,4 +218,3 @@ impl PartialEq for MemoryDatabase {
 }
 
 impl Eq for MemoryDatabase {}
-
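For a content-addressed store (KEY_IS_COMPUTED = true) the natural get_key() is a hash of the value, which is presumably why this PR adds sha2 as a dev-dependency. A minimal sketch under the assumptions that KEY_SIZE is 64 and that get_key() receives the value and a KEY_SIZE-byte output slice (shown here as a free function rather than the trait method):

```rust
use sha2::{Digest, Sha512};

// Hypothetical content-addressed keying: keys are SHA-512 hashes of values,
// so only values ever need to cross the wire and keys are recomputed locally.
fn get_key(value: &[u8], key: &mut [u8]) {
    // The 'key' slice must be KEY_SIZE (here 64) bytes in length.
    key.copy_from_slice(Sha512::digest(value).as_slice());
}
```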
diff --git a/syncwhole/src/host.rs b/syncwhole/src/host.rs
index 5085096b5..bb4337100 100644
--- a/syncwhole/src/host.rs
+++ b/syncwhole/src/host.rs
@@ -14,32 +14,51 @@ use crate::node::RemoteNodeInfo;
 
 /// A trait that users of syncwhole implement to provide configuration information and listen for events.
 pub trait Host: Sync + Send {
-    /// Compute SHA512.
-    fn sha512(msg: &[u8]) -> [u8; 64];
-
-    /// Get a list of endpoints to which we always want to try to stay connected.
+    /// Get a list of peer addresses to which we always want to try to stay connected.
     ///
-    /// The node will repeatedly try to connect to these until a link is established and
-    /// reconnect on link failure. They should be stable well known nodes for this domain.
-    fn get_static_endpoints(&self) -> &[SocketAddr];
+    /// These are always contacted until a link is established regardless of anything else.
+    fn fixed_peers(&self) -> &[SocketAddr];
 
-    /// Get additional endpoints to try.
-    ///
-    /// This should return any endpoints not in the supplied endpoint set if the size
-    /// of the set is less than the minimum active link count the host wishes to maintain.
-    fn get_more_endpoints(&self, current_endpoints: &HashSet<SocketAddr>) -> Vec<SocketAddr>;
+    /// Get a random peer address not in the supplied set.
+    fn another_peer(&self, exclude: &HashSet<SocketAddr>) -> Option<SocketAddr>;
 
     /// Get the maximum number of endpoints allowed.
     ///
-    /// This is checked on incoming connect and incoming links are refused if the total is over this count.
-    fn max_endpoints(&self) -> usize;
+    /// This is checked on incoming connect and incoming links are refused if the total is
+    /// over this count. Fixed endpoints will be contacted even if the total is over this limit.
+    fn max_connection_count(&self) -> usize;
 
-    /// Called whenever we have successfully connected to a remote node (after connection is initialized).
+    /// Get the number of connections we ideally want.
+    ///
+    /// Attempts will be made to lazily contact remote endpoints if the total number of links
+    /// is under this amount. Note that fixed endpoints will still be contacted even if the
+    /// total is over the desired count.
+    ///
+    /// This should always be less than max_connection_count().
+    fn desired_connection_count(&self) -> usize;
+
+    /// Test whether an inbound connection should be allowed from an address.
+    fn allow(&self, remote_address: &SocketAddr) -> bool;
+
+    /// Called when an attempt is made to connect to a remote address.
+    fn on_connect_attempt(&self, address: &SocketAddr);
+
+    /// Called when a connection has been successfully established.
+    ///
+    /// Hosts are encouraged to learn endpoints when a successful outbound connection is made. Check the
+    /// inbound flag in the remote node info structure.
     fn on_connect(&self, info: &RemoteNodeInfo);
 
-    /// Called when an open connection is closed.
-    fn on_connection_closed(&self, endpoint: &SocketAddr, reason: Option<Box<dyn Error>>);
+    /// Called when an open connection is closed for any reason.
+    fn on_connection_closed(&self, address: &SocketAddr, reason: Option<Box<dyn Error>>);
 
     /// Fill a buffer with secure random bytes.
+    ///
+    /// This is supplied to reduce inherent dependencies and allow the user to choose the implementation.
     fn get_secure_random(&self, buf: &mut [u8]);
+
+    /// Compute a SHA512 digest of the input.
+    ///
+    /// This is supplied to reduce inherent dependencies and allow the user to choose the implementation.
+    fn sha512(msg: &[u8]) -> [u8; 64];
 }
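Both injected primitives are easy to satisfy with off-the-shelf crates. A sketch of plausible implementations, assuming the sha2 crate (already a dev-dependency above) and the getrandom crate for OS randomness; neither choice is prescribed by this diff:

```rust
use sha2::{Digest, Sha512};

// Matches the associated-function signature fn sha512(msg: &[u8]) -> [u8; 64].
fn sha512(msg: &[u8]) -> [u8; 64] {
    let mut out = [0u8; 64];
    out.copy_from_slice(Sha512::digest(msg).as_slice());
    out
}

// Matches fn get_secure_random(&self, buf: &mut [u8]); any CSPRNG works here.
fn get_secure_random(buf: &mut [u8]) {
    getrandom::getrandom(buf).expect("OS RNG failure");
}
```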
diff --git a/syncwhole/src/iblt.rs b/syncwhole/src/iblt.rs
index 7db4bb24c..fe0c16cb4 100644
--- a/syncwhole/src/iblt.rs
+++ b/syncwhole/src/iblt.rs
@@ -6,21 +6,20 @@
  * https://www.zerotier.com/
  */
 
-use std::mem::{size_of, zeroed};
+use std::alloc::{alloc_zeroed, dealloc, Layout};
+use std::mem::size_of;
 use std::ptr::write_bytes;
 
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
 use crate::varint;
 
-// max value: 6, 5 was determined to be good via empirical testing
-const KEY_MAPPING_ITERATIONS: usize = 5;
-
 #[inline(always)]
 fn xorshift64(mut x: u64) -> u64 {
+    x = u64::from_le(x);
     x ^= x.wrapping_shl(13);
     x ^= x.wrapping_shr(7);
     x ^= x.wrapping_shl(17);
-    x
+    x.to_le()
 }
 
 #[inline(always)]
@@ -45,22 +44,9 @@ fn splitmix64_inverse(mut x: u64) -> u64 {
     x.to_le()
 }
 
-// https://nullprogram.com/blog/2018/07/31/
-#[inline(always)]
-fn triple32(mut x: u32) -> u32 {
-    x ^= x.wrapping_shr(17);
-    x = x.wrapping_mul(0xed5ad4bb);
-    x ^= x.wrapping_shr(11);
-    x = x.wrapping_mul(0xac4c1b51);
-    x ^= x.wrapping_shr(15);
-    x = x.wrapping_mul(0x31848bab);
-    x ^= x.wrapping_shr(14);
-    x
-}
-
 #[inline(always)]
-fn next_iteration_index(prev_iteration_index: u64, salt: u64) -> u64 {
-    prev_iteration_index.wrapping_add(triple32(prev_iteration_index.wrapping_shr(32) as u32) as u64).wrapping_add(salt)
+fn next_iteration_index(prev_iteration_index: u64) -> u64 {
+    splitmix64(prev_iteration_index.wrapping_add(1))
 }
 
 #[derive(Clone, PartialEq, Eq)]
@@ -84,29 +70,31 @@ impl IBLTEntry {
 /// An Invertible Bloom Lookup Table for set reconciliation with 64-bit hashes.
 #[derive(Clone, PartialEq, Eq)]
 pub struct IBLT<const B: usize> {
-    salt: u64,
-    map: [IBLTEntry; B]
+    map: *mut [IBLTEntry; B]
 }
 
 impl<const B: usize> IBLT<B> {
+    /// Number of buckets (capacity) of this IBLT.
     pub const BUCKETS: usize = B;
 
-    pub fn new(salt: u64) -> Self {
+    /// This was determined to be effective via empirical testing with random keys. This
+    /// is a protocol constant that can't be changed without upgrading all nodes in a domain.
+    const KEY_MAPPING_ITERATIONS: usize = 2;
+
+    pub fn new() -> Self {
+        assert!(B < u32::MAX as usize); // sanity check
         Self {
-            salt,
-            map: unsafe { zeroed() }
+            map: unsafe { alloc_zeroed(Layout::new::<[IBLTEntry; B]>()).cast() }
         }
     }
 
-    pub fn reset(&mut self, salt: u64) {
-        self.salt = salt;
-        unsafe { write_bytes((&mut self.map as *mut IBLTEntry).cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) };
+    pub fn reset(&mut self) {
+        unsafe { write_bytes(self.map.cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) };
     }
 
     pub async fn read<R: AsyncReadExt + Unpin>(&mut self, r: &mut R) -> std::io::Result<()> {
-        r.read_exact(unsafe { &mut *(&mut self.salt as *mut u64).cast::<[u8; 8]>() }).await?;
         let mut prev_c = 0_i64;
-        for b in self.map.iter_mut() {
+        for b in unsafe { (*self.map).iter_mut() } {
             let _ = r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() }).await?;
             let _ = r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() }).await?;
             let mut c = varint::read_async(r).await? as i64;
@@ -122,9 +110,8 @@ impl<const B: usize> IBLT<B> {
     }
 
     pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
-        let _ = w.write_all(unsafe { &*(&self.salt as *const u64).cast::<[u8; 8]>() }).await?;
         let mut prev_c = 0_i64;
-        for b in self.map.iter() {
+        for b in unsafe { (*self.map).iter() } {
             let _ = w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() }).await?;
             let _ = w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() }).await?;
             let mut c = (b.count - prev_c).wrapping_shl(1);
@@ -137,13 +124,25 @@ impl<const B: usize> IBLT<B> {
         Ok(())
     }
 
+    /// Get this IBLT as a byte array.
+    pub fn to_bytes(&self) -> Vec<u8> {
+        let mut out = std::io::Cursor::new(Vec::<u8>::with_capacity(B * 20));
+        let current = tokio::runtime::Handle::try_current();
+        if current.is_ok() {
+            assert!(current.unwrap().block_on(self.write(&mut out)).is_ok());
+        } else {
+            assert!(tokio::runtime::Builder::new_current_thread().build().unwrap().block_on(self.write(&mut out)).is_ok());
+        }
+        out.into_inner()
+    }
+
     fn ins_rem(&mut self, mut key: u64, delta: i64) {
-        key = splitmix64(key ^ self.salt);
+        key = splitmix64(key);
         let check_hash = xorshift64(key);
         let mut iteration_index = u64::from_le(key);
-        for _ in 0..KEY_MAPPING_ITERATIONS {
-            iteration_index = next_iteration_index(iteration_index, self.salt);
-            let b = unsafe { self.map.get_unchecked_mut((iteration_index as usize) % B) };
+        for _ in 0..Self::KEY_MAPPING_ITERATIONS {
+            iteration_index = next_iteration_index(iteration_index);
+            let b = unsafe { (*self.map).get_unchecked_mut((iteration_index as usize) % B) };
             b.key_sum ^= key;
             b.check_hash_sum ^= check_hash;
             b.count += delta;
@@ -166,53 +165,51 @@ impl<const B: usize> IBLT<B> {
         self.ins_rem(unsafe { u64::from_ne_bytes(*(key.as_ptr().cast::<[u8; 8]>())) }, -1);
     }
 
+    /// Subtract another IBLT from this one to get a set difference.
     pub fn subtract(&mut self, other: &Self) {
-        for b in 0..B {
-            let s = &mut self.map[b];
-            let o = &other.map[b];
+        for (s, o) in unsafe { (*self.map).iter_mut().zip((*other.map).iter()) } {
             s.key_sum ^= o.key_sum;
             s.check_hash_sum ^= o.check_hash_sum;
-            s.count += o.count;
+            s.count -= o.count;
         }
     }
 
     pub fn list<F: FnMut(&[u8; 8])>(mut self, mut f: F) -> bool {
-        let mut singular_buckets = [0_usize; B];
-        let mut singular_bucket_count = 0_usize;
+        let mut queue: Vec<u32> = Vec::with_capacity(B);
 
         for b in 0..B {
-            if self.map[b].is_singular() {
-                singular_buckets[singular_bucket_count] = b;
-                singular_bucket_count += 1;
+            if unsafe { (*self.map).get_unchecked(b).is_singular() } {
+                queue.push(b as u32);
             }
         }
 
-        while singular_bucket_count > 0 {
-            singular_bucket_count -= 1;
-            let b = &self.map[singular_buckets[singular_bucket_count]];
-
+        loop {
+            let b = queue.pop();
+            let b = if b.is_some() {
+                unsafe { (*self.map).get_unchecked_mut(b.unwrap() as usize) }
+            } else {
+                break;
+            };
             if b.is_singular() {
                 let key = b.key_sum;
                 let check_hash = xorshift64(key);
                 let mut iteration_index = u64::from_le(key);
-                f(&(splitmix64_inverse(key) ^ self.salt).to_ne_bytes());
+                f(&(splitmix64_inverse(key)).to_ne_bytes());
 
-                for _ in 0..KEY_MAPPING_ITERATIONS {
-                    iteration_index = next_iteration_index(iteration_index, self.salt);
-                    let b_idx = (iteration_index as usize) % B;
-                    let b = unsafe { self.map.get_unchecked_mut(b_idx) };
+                for _ in 0..Self::KEY_MAPPING_ITERATIONS {
+                    iteration_index = next_iteration_index(iteration_index);
+                    let b_idx = iteration_index % (B as u64);
+                    let b = unsafe { (*self.map).get_unchecked_mut(b_idx as usize) };
                     b.key_sum ^= key;
                     b.check_hash_sum ^= check_hash;
                     b.count -= 1;
                     if b.is_singular() {
-                        if singular_bucket_count >= B {
-                            // This would indicate an invalid IBLT.
+                        if queue.len() >= (B * 2) { // sanity check for invalid IBLT
                             return false;
                         }
-                        singular_buckets[singular_bucket_count] = b_idx;
-                        singular_bucket_count += 1;
+                        queue.push(b_idx as u32);
                     }
                 }
             }
@@ -222,6 +219,12 @@ impl<const B: usize> IBLT<B> {
     }
 }
 
+impl<const B: usize> Drop for IBLT<B> {
+    fn drop(&mut self) {
+        unsafe { dealloc(self.map.cast(), Layout::new::<[IBLTEntry; B]>()) };
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashSet;
@@ -230,32 +233,75 @@ mod tests {
 
     #[test]
     fn splitmix_is_invertiblex() {
-        for i in 1..1024_u64 {
+        for i in 1..2000_u64 {
             assert_eq!(i, splitmix64_inverse(splitmix64(i)))
         }
     }
 
     #[test]
-    fn insert_and_list() {
+    fn fill_list_performance() {
         let mut rn = xorshift64(SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos() as u64);
-        for _ in 0..256 {
-            let mut alice: IBLT<1024> = IBLT::new(rn);
-            rn = xorshift64(rn);
-            let mut expected: HashSet<u64> = HashSet::with_capacity(1024);
-            let count = 600;
+        let mut expected: HashSet<u64> = HashSet::with_capacity(4096);
+        let mut count = 64;
+        const CAPACITY: usize = 4096;
+        while count <= CAPACITY {
+            let mut test: IBLT<CAPACITY> = IBLT::new();
+            expected.clear();
             for _ in 0..count {
                 let x = rn;
-                rn = xorshift64(rn);
+                rn = splitmix64(rn);
                 expected.insert(x);
-                alice.insert(&x.to_ne_bytes());
+                test.insert(&x.to_ne_bytes());
             }
-            let mut cnt = 0;
-            alice.list(|x| {
+
+            let mut list_count = 0;
+            test.list(|x| {
                 let x = u64::from_ne_bytes(*x);
-                cnt += 1;
+                list_count += 1;
                 assert!(expected.contains(&x));
             });
-            assert_eq!(cnt, count);
+
+            //println!("inserted: {}\tlisted: {}\tcapacity: {}\tscore: {:.4}\tfill: {:.4}", count, list_count, CAPACITY, (list_count as f64) / (count as f64), (count as f64) / (CAPACITY as f64));
+            count += 64;
         }
     }
 
+    #[test]
+    fn merge_sets() {
+        let mut rn = xorshift64(SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos() as u64);
+        const CAPACITY: usize = 16384;
+        const REMOTE_SIZE: usize = 1024 * 1024;
+        const STEP: usize = 1024;
+        let mut missing_count = 1024;
+        let mut missing: HashSet<u64> = HashSet::with_capacity(CAPACITY);
+        while missing_count <= CAPACITY {
+            missing.clear();
+            let mut local: IBLT<CAPACITY> = IBLT::new();
+            let mut remote: IBLT<CAPACITY> = IBLT::new();
+
+            for k in 0..REMOTE_SIZE {
+                if k >= missing_count {
+                    local.insert(&rn.to_ne_bytes());
+                } else {
+                    missing.insert(rn);
+                }
+                remote.insert(&rn.to_ne_bytes());
+                rn = splitmix64(rn);
+            }
+
+            local.subtract(&mut remote);
+            let bytes = local.to_bytes().len();
+            let mut cnt = 0;
+            local.list(|k| {
+                let k = u64::from_ne_bytes(*k);
+                assert!(missing.contains(&k));
+                cnt += 1;
+            });
+
+            println!("total: {} missing: {:5} recovered: {:5} size: {:5} score: {:.4} bytes/item: {:.2}", REMOTE_SIZE, missing.len(), cnt, bytes, (cnt as f64) / (missing.len() as f64), (bytes as f64) / (cnt as f64));
+
+            missing_count += STEP;
+        }
+    }
 }
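End to end, reconciliation works as exercised by merge_sets above: each side inserts its keys into an IBLT of the same bucket count B, one side subtracts the other's table, and list() then recovers (with high probability) the keys present on only one side, as long as that difference is well under B. A hypothetical usage sketch:

```rust
// Sketch of one reconciliation round using the API above. Returns the keys
// in the symmetric difference of the two sets (what one side is missing).
fn set_difference(local_keys: &[u64], remote_keys: &[u64]) -> Vec<u64> {
    let mut local: IBLT<16384> = IBLT::new();
    let mut remote: IBLT<16384> = IBLT::new();
    for k in local_keys {
        local.insert(&k.to_ne_bytes());
    }
    for k in remote_keys {
        remote.insert(&k.to_ne_bytes());
    }
    // After subtraction the table summarizes only the differing keys, which
    // is what makes the wire cost proportional to the difference, not the set.
    local.subtract(&remote);
    let mut diff = Vec::new();
    local.list(|k| diff.push(u64::from_ne_bytes(*k)));
    diff
}
```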
diff --git a/syncwhole/src/node.rs b/syncwhole/src/node.rs
index 905a42b1a..8971b14dc 100644
--- a/syncwhole/src/node.rs
+++ b/syncwhole/src/node.rs
@@ -6,7 +6,7 @@
  * https://www.zerotier.com/
  */
 
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
 use std::sync::{Arc, Weak};
 use std::sync::atomic::{AtomicI64, Ordering};
@@ -26,20 +26,46 @@ use crate::protocol::*;
 use crate::varint;
 
 const CONNECTION_TIMEOUT: i64 = 60000;
-const CONNECTION_KEEPALIVE_AFTER: i64 = 20000;
+const CONNECTION_KEEPALIVE_AFTER: i64 = CONNECTION_TIMEOUT / 3;
+const HOUSEKEEPING_INTERVAL: i64 = CONNECTION_KEEPALIVE_AFTER / 2;
 const IO_BUFFER_SIZE: usize = 65536;
 
-#[derive(Clone, PartialEq, Eq)]
+/// Information about a remote node to which we are connected.
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct RemoteNodeInfo {
+    /// Optional name advertised by remote node (arbitrary).
     pub node_name: Option<String>,
+
+    /// Optional contact information advertised by remote node (arbitrary).
     pub node_contact: Option<String>,
-    pub endpoint: SocketAddr,
-    pub preferred_endpoints: Vec<SocketAddr>,
+
+    /// Actual remote endpoint address.
+    pub remote_address: SocketAddr,
+
+    /// Explicitly advertised remote addresses supplied by remote node (not necessarily verified).
+    pub explicit_addresses: Vec<SocketAddr>,
+
+    /// Time TCP connection was established.
     pub connect_time: SystemTime,
+
+    /// True if this is an inbound TCP connection.
     pub inbound: bool,
+
+    /// True if this connection has exchanged init messages.
     pub initialized: bool,
 }
 
+fn configure_tcp_socket(socket: &TcpSocket) -> std::io::Result<()> {
+    if socket.set_reuseport(true).is_err() {
+        socket.set_reuseaddr(true)?;
+    }
+    Ok(())
+}
+
+/// An instance of the syncwhole data set synchronization engine.
+///
+/// This holds a number of async tasks that are terminated or aborted if this object
+/// is dropped. In other words this implements structured concurrency.
 pub struct Node<D: DataStore + 'static, H: Host + 'static> {
     internal: Arc<NodeInternal<D, H>>,
     housekeeping_task: JoinHandle<()>,
@@ -49,9 +75,7 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
     pub async fn new(db: Arc<D>, host: Arc<H>, bind_address: SocketAddr) -> std::io::Result<Self> {
         let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
-        if listener.set_reuseport(true).is_err() {
-            listener.set_reuseaddr(true)?;
-        }
+        configure_tcp_socket(&listener)?;
         listener.bind(bind_address.clone())?;
         let listener = listener.listen(1024)?;
@@ -65,6 +89,7 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
             host: host.clone(),
             bind_address,
             connections: Mutex::new(HashMap::with_capacity(64)),
+            attempts: Mutex::new(HashMap::with_capacity(64)),
         });
 
         Ok(Self {
@@ -76,7 +101,7 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
 
     #[inline(always)]
     pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
-        self.internal.connect(endpoint).await
+        self.internal.clone().connect(endpoint).await
     }
 
     pub fn list_connections(&self) -> Vec<RemoteNodeInfo> {
@@ -105,17 +130,20 @@ pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
     host: Arc<H>,
     bind_address: SocketAddr,
     connections: Mutex<HashMap<SocketAddr, (Weak<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
+    attempts: Mutex<HashMap<SocketAddr, JoinHandle<std::io::Result<bool>>>>,
 }
 
 impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
     async fn housekeeping_task_main(self: Arc<Self>) {
         loop {
-            tokio::time::sleep(Duration::from_millis((CONNECTION_KEEPALIVE_AFTER / 2) as u64)).await;
+            tokio::time::sleep(Duration::from_millis(HOUSEKEEPING_INTERVAL as u64)).await;
 
             let mut to_ping: Vec<Arc<Connection>> = Vec::new();
             let mut dead: Vec<(SocketAddr, Option<JoinHandle<std::io::Result<()>>>)> = Vec::new();
+            let mut current_endpoints: HashSet<SocketAddr> = HashSet::new();
 
             let mut connections = self.connections.lock().await;
+            current_endpoints.reserve(connections.len() + 1);
             let now = ms_monotonic();
             connections.retain(|sa, c| {
                 let cc = c.0.upgrade();
@@ -125,6 +153,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
                     if (now - cc.last_send_time.load(Ordering::Relaxed)) >= CONNECTION_KEEPALIVE_AFTER {
                         to_ping.push(cc);
                     }
+                    current_endpoints.insert(sa.clone());
                     true
                 } else {
                     c.1.take().map(|j| j.abort());
@@ -152,15 +181,43 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
             for c in to_ping.iter() {
                 let _ = c.send(&[MESSAGE_TYPE_NOP, 0], now).await;
             }
+
+            let desired = self.host.desired_connection_count();
+            let fixed = self.host.fixed_peers();
+
+            let mut attempts = self.attempts.lock().await;
+
+            for ep in fixed.iter() {
+                if !current_endpoints.contains(ep) {
+                    let self2 = self.clone();
+                    let ep2 = ep.clone();
+                    attempts.insert(ep.clone(), tokio::spawn(async move { self2.connect(&ep2).await }));
+                    current_endpoints.insert(ep.clone());
+                }
+            }
+
+            while current_endpoints.len() < desired {
+                let ep = self.host.another_peer(&current_endpoints);
+                if ep.is_some() {
+                    let ep = ep.unwrap();
+                    current_endpoints.insert(ep.clone());
+                    let self2 = self.clone();
+                    attempts.insert(ep.clone(), tokio::spawn(async move { self2.connect(&ep).await }));
+                } else {
+                    break;
+                }
+            }
         }
     }
 
     async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
         loop {
             let socket = listener.accept().await;
-            if self.connections.lock().await.len() < self.host.max_endpoints() && socket.is_ok() {
+            if self.connections.lock().await.len() < self.host.max_connection_count() && socket.is_ok() {
                 let (stream, endpoint) = socket.unwrap();
-                Self::connection_start(&self, endpoint, stream, true).await;
+                if self.host.allow(&endpoint) {
+                    Self::connection_start(&self, endpoint, stream, true).await;
+                }
             }
         }
     }
@@ -175,8 +232,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
             max_value_size: D::MAX_VALUE_SIZE as u64,
             node_name: None,
             node_contact: None,
-            preferred_ipv4: None,
-            preferred_ipv6: None
+            explicit_ipv4: None,
+            explicit_ipv6: None
         }, ms_monotonic()).await?;
 
         let mut init_received = false;
@@ -194,6 +251,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
             let now = ms_monotonic();
 
             match message_type {
+
                 MESSAGE_TYPE_INIT => {
                     if init_received {
                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init"));
@@ -220,13 +278,14 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
                     let mut info = connection.info.lock().unwrap();
                     info.node_name = msg.node_name.clone();
                     info.node_contact = msg.node_contact.clone();
-                    let _ = msg.preferred_ipv4.map(|pv4| {
-                        info.preferred_endpoints.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(pv4.ip[0], pv4.ip[1], pv4.ip[2], pv4.ip[3]), pv4.port)));
+                    let _ = msg.explicit_ipv4.map(|pv4| {
+                        info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port)));
                     });
-                    let _ = msg.preferred_ipv6.map(|pv6| {
-                        info.preferred_endpoints.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
+                    let _ = msg.explicit_ipv6.map(|pv6| {
+                        info.explicit_addresses.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
                     });
                 },
+
                 MESSAGE_TYPE_INIT_RESPONSE => {
                     if initialized {
                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init response"));
@@ -246,6 +305,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
                     let info = info.clone();
                     self.host.on_connect(&info);
                 },
+
                 _ => {
                     // Skip messages that aren't recognized or don't need to be parsed like NOP.
                    let mut remaining = message_size as usize;
@@ -256,16 +316,16 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
             }
             connection.last_receive_time.store(ms_monotonic(), Ordering::Relaxed);
+
         }
     }
 
-    async fn connection_start(self: &Arc<Self>, endpoint: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
-        let _ = stream.set_nodelay(true);
+    async fn connection_start(self: &Arc<Self>, address: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
         let (reader, writer) = stream.into_split();
         let mut ok = false;
-        let _ = self.connections.lock().await.entry(endpoint.clone()).or_insert_with(|| {
+        let _ = self.connections.lock().await.entry(address.clone()).or_insert_with(|| {
             ok = true;
             let now = ms_monotonic();
             let connection = Arc::new(Connection {
@@ -275,8 +335,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
             info: std::sync::Mutex::new(RemoteNodeInfo {
                 node_name: None,
                 node_contact: None,
-                endpoint: endpoint.clone(),
-                preferred_endpoints: Vec::new(),
+                remote_address: address.clone(),
+                explicit_addresses: Vec::new(),
                 connect_time: SystemTime::now(),
                 inbound,
                 initialized: false
@@ -287,23 +347,26 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
         ok
     }
 
-    async fn connect(self: &Arc<Self>, endpoint: &SocketAddr) -> std::io::Result<bool> {
-        if !self.connections.lock().await.contains_key(endpoint) {
-            let stream = if endpoint.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
-            if stream.set_reuseport(true).is_err() {
-                stream.set_reuseaddr(true)?;
-            }
+    async fn connect(self: Arc<Self>, address: &SocketAddr) -> std::io::Result<bool> {
+        let mut success = false;
+        if !self.connections.lock().await.contains_key(address) {
+            self.host.on_connect_attempt(address);
+            let stream = if address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
+            configure_tcp_socket(&stream)?;
             stream.bind(self.bind_address.clone())?;
-            let stream = stream.connect(endpoint.clone()).await?;
-            Ok(self.connection_start(endpoint.clone(), stream, false).await)
-        } else {
-            Ok(false)
+            let stream = stream.connect(address.clone()).await?;
+            success = self.connection_start(address.clone(), stream, false).await;
         }
+        self.attempts.lock().await.remove(address);
+        Ok(success)
     }
 }
 
 impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
     fn drop(&mut self) {
+        for a in self.attempts.blocking_lock().iter() {
+            a.1.abort();
+        }
         for (_, c) in self.connections.blocking_lock().drain() {
             c.1.map(|c| c.abort());
         }
diff --git a/syncwhole/src/protocol.rs b/syncwhole/src/protocol.rs
index 896925e48..90e9a8be0 100644
--- a/syncwhole/src/protocol.rs
+++ b/syncwhole/src/protocol.rs
@@ -45,10 +45,10 @@ pub mod msg {
         pub node_name: Option<String>,
         #[serde(rename = "nc")]
         pub node_contact: Option<String>,
-        #[serde(rename = "pi4")]
-        pub preferred_ipv4: Option<IPv4>,
-        #[serde(rename = "pi6")]
-        pub preferred_ipv6: Option<IPv6>,
+        #[serde(rename = "ei4")]
+        pub explicit_ipv4: Option<IPv4>,
+        #[serde(rename = "ei6")]
+        pub explicit_ipv6: Option<IPv6>,
     }
 
     #[derive(Serialize, Deserialize)]
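The serde renames above change the msgpack keys on the wire ("pi4"/"pi6" become "ei4"/"ei6"), which makes this a protocol change rather than a pure refactor. A hedged round-trip sketch with rmp-serde; the structs below only mirror the renamed fields and infer the IPv4 payload shape from node.rs (pv4.ip as [u8; 4] plus pv4.port), they are not copied from the full source:

```rust
use serde::{Deserialize, Serialize};

// Hypothetical mirror of the renamed init fields, for illustration only.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct IPv4 {
    ip: [u8; 4],
    port: u16,
}

#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct ExplicitAddresses {
    #[serde(rename = "ei4")]
    explicit_ipv4: Option<IPv4>,
}

fn main() {
    let msg = ExplicitAddresses { explicit_ipv4: Some(IPv4 { ip: [10, 0, 0, 1], port: 9993 }) };
    // to_vec_named writes the serde-renamed keys, so an older peer expecting
    // "pi4" would simply not see this field rather than misparse it.
    let bytes = rmp_serde::to_vec_named(&msg).unwrap();
    let back: ExplicitAddresses = rmp_serde::from_slice(&bytes).unwrap();
    assert_eq!(msg, back);
}
```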