From e55d3e4d4b7c76a5f9a8ec0b861ce5d781058d55 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Tue, 21 Dec 2021 21:43:09 -0500 Subject: [PATCH] Bunch of sync stuff including a neat set reconiciliation thing. --- allthethings/Cargo.toml | 3 +- allthethings/src/iblt.rs | 198 ++++++++++++++++++++++++++ allthethings/src/lib.rs | 21 ++- allthethings/src/memorystore.rs | 79 +++++++++++ allthethings/src/protocol.rs | 30 ++-- allthethings/src/replicator.rs | 215 +++++++++++++++-------------- allthethings/src/store.rs | 29 ++-- zerotier-core-crypto/src/random.rs | 15 ++ 8 files changed, 461 insertions(+), 129 deletions(-) create mode 100644 allthethings/src/iblt.rs create mode 100644 allthethings/src/memorystore.rs diff --git a/allthethings/Cargo.toml b/allthethings/Cargo.toml index 348aaaffb..ef0539ed5 100644 --- a/allthethings/Cargo.toml +++ b/allthethings/Cargo.toml @@ -4,6 +4,5 @@ version = "0.1.0" edition = "2021" [dependencies] -sha2 = { version = "^0", features = ["asm"] } smol = { version = "^1", features = [] } -getrandom = "^0" +zerotier-core-crypto = { path = "../zerotier-core-crypto" } diff --git a/allthethings/src/iblt.rs b/allthethings/src/iblt.rs new file mode 100644 index 000000000..eed1d53e2 --- /dev/null +++ b/allthethings/src/iblt.rs @@ -0,0 +1,198 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + * + * (c)2021 ZeroTier, Inc. + * https://www.zerotier.com/ + */ + +use std::mem::zeroed; + +use crate::IDENTITY_HASH_SIZE; + +// The number of indexing sub-hashes to use, must be <= IDENTITY_HASH_SIZE / 8 +const KEY_MAPPING_ITERATIONS: usize = IDENTITY_HASH_SIZE / 8; + +#[inline(always)] +fn xorshift64(mut x: u64) -> u64 { + x ^= x.wrapping_shl(13); + x ^= x.wrapping_shr(7); + x ^= x.wrapping_shl(17); + x +} + +#[repr(packed)] +struct IBLTEntry { + key_sum: [u64; IDENTITY_HASH_SIZE / 8], + check_hash_sum: u64, + count: i64, +} + +impl Default for IBLTEntry { + fn default() -> Self { unsafe { zeroed() } } +} + +/// An IBLT (invertible bloom lookup table) specialized for reconciling sets of identity hashes. +/// This skips some extra hashing that would be necessary in a universal implementation since identity +/// hashes are already randomly distributed strong hashes. +pub struct IBLT { + map: Box<[IBLTEntry]>, +} + +impl IBLTEntry { + #[inline(always)] + fn is_singular(&self) -> bool { + if self.count == 1 || self.count == -1 { + u64::from_le(self.key_sum[0]).wrapping_add(xorshift64(u64::from_le(self.key_sum[1]))) == u64::from_le(self.check_hash_sum) + } else { + false + } + } +} + +impl IBLT { + /// Construct a new IBLT with a given capacity. + pub fn new(buckets: usize) -> Self { + assert!(KEY_MAPPING_ITERATIONS <= (IDENTITY_HASH_SIZE / 8) && (IDENTITY_HASH_SIZE % 8) == 0); + assert!(buckets > 0); + Self { + map: { + let mut tmp = Vec::new(); + tmp.resize_with(buckets, IBLTEntry::default); + tmp.into_boxed_slice() + } + } + } + + fn ins_rem(&mut self, key: &[u64; IDENTITY_HASH_SIZE / 8], delta: i64) { + let check_hash = u64::from_le(key[0]).wrapping_add(xorshift64(u64::from_le(key[1]))).to_le(); + for mapping_sub_hash in 0..KEY_MAPPING_ITERATIONS { + let b = unsafe { self.map.get_unchecked_mut((u64::from_le(key[mapping_sub_hash]) as usize) % self.map.len()) }; + for j in 0..(IDENTITY_HASH_SIZE / 8) { + b.key_sum[j] ^= key[j]; + } + b.check_hash_sum ^= check_hash; + b.count = i64::from_le(b.count).wrapping_add(delta).to_le(); + } + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64", target_arch = "powerpc64"))] + #[inline(always)] + pub fn insert(&mut self, key: &[u8; IDENTITY_HASH_SIZE]) { + self.ins_rem(unsafe { &*key.as_ptr().cast::<[u64; IDENTITY_HASH_SIZE / 8]>() }, 1); + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64", target_arch = "powerpc64")))] + #[inline(always)] + pub fn insert(&mut self, key: &[u8; IDENTITY_HASH_SIZE]) { + let mut tmp = [0_u64; IDENTITY_HASH_SIZE / 8]; + unsafe { copy_nonoverlapping(key.as_ptr(), tmp.as_mut_ptr().cast(), IDENTITY_HASH_SIZE) }; + self.ins_rem(&tmp, 1); + } + + #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64", target_arch = "powerpc64"))] + #[inline(always)] + pub fn remove(&mut self, key: &[u8; IDENTITY_HASH_SIZE]) { + self.ins_rem(unsafe { &*key.as_ptr().cast::<[u64; IDENTITY_HASH_SIZE / 8]>() }, -1); + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64", target_arch = "powerpc64")))] + #[inline(always)] + pub fn remove(&mut self, key: &[u8; IDENTITY_HASH_SIZE]) { + let mut tmp = [0_u64; IDENTITY_HASH_SIZE / 8]; + unsafe { copy_nonoverlapping(key.as_ptr(), tmp.as_mut_ptr().cast(), IDENTITY_HASH_SIZE) }; + self.ins_rem(&tmp, -1); + } + + /// Subtract another IBLT from this one to compute set difference. + pub fn subtract(&mut self, other: &IBLT) { + if other.map.len() == self.map.len() { + for i in 0..self.map.len() { + let self_b = unsafe { self.map.get_unchecked_mut(i) }; + let other_b = unsafe { other.map.get_unchecked(i) }; + for j in 0..(IDENTITY_HASH_SIZE / 8) { + self_b.key_sum[j] ^= other_b.key_sum[j]; + } + self_b.check_hash_sum ^= other_b.check_hash_sum; + self_b.count = i64::from_le(self_b.count).wrapping_sub(i64::from_le(other_b.count)).to_le(); + } + } + } + + /// Call a function for every value that can be extracted from this IBLT. + /// + /// The function is called with the key and a boolean. The boolean is meaningful + /// if this IBLT is the result of subtract(). In that case the boolean is true + /// if the "local" IBLT contained the item and false if the "remote" side contained + /// the item. + /// + /// The starting_singular_bucket parameter must be the internal index of a + /// bucket with only one entry (1 or -1). It can be obtained from the return + /// values of either subtract() or singular_bucket(). + pub fn list bool>(&mut self, mut f: F) { + let mut singular_buckets: Vec = Vec::with_capacity(1024); + let buckets = self.map.len(); + + for i in 0..buckets { + if unsafe { self.map.get_unchecked(i) }.is_singular() { + singular_buckets.push(i); + }; + } + + let mut key = [0_u64; IDENTITY_HASH_SIZE / 8]; + while !singular_buckets.is_empty() { + let b = unsafe { self.map.get_unchecked_mut(singular_buckets.pop().unwrap()) }; + if b.is_singular() { + for j in 0..(IDENTITY_HASH_SIZE / 8) { + key[j] = b.key_sum[j]; + } + if f(unsafe { &*key.as_ptr().cast::<[u8; IDENTITY_HASH_SIZE]>() }, b.count == 1) { + let check_hash = u64::from_le(key[0]).wrapping_add(xorshift64(u64::from_le(key[1]))).to_le(); + for mapping_sub_hash in 0..KEY_MAPPING_ITERATIONS { + let bi = (u64::from_le(key[mapping_sub_hash]) as usize) % buckets; + let b = unsafe { self.map.get_unchecked_mut(bi) }; + for j in 0..(IDENTITY_HASH_SIZE / 8) { + b.key_sum[j] ^= key[j]; + } + b.check_hash_sum ^= check_hash; + b.count = i64::from_le(b.count).wrapping_sub(1).to_le(); + if b.is_singular() { + singular_buckets.push(bi); + } + } + } else { + break; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use zerotier_core_crypto::hash::SHA384; + use crate::iblt::IBLT; + + #[allow(unused_variables)] + #[test] + fn insert_and_list() { + let mut t = IBLT::new(1024); + let expected_cnt = 512; + for i in 0..expected_cnt { + let k = SHA384::hash(&(i as u64).to_le_bytes()); + t.insert(&k); + } + let mut cnt = 0; + t.list(|k, d| { + cnt += 1; + //println!("{} {}", zerotier_core_crypto::hex::to_string(k), d); + true + }); + println!("retrieved {} keys", cnt); + assert_eq!(cnt, expected_cnt); + } + + #[test] + fn benchmark() { + } +} diff --git a/allthethings/src/lib.rs b/allthethings/src/lib.rs index 95a3651a0..18e657688 100644 --- a/allthethings/src/lib.rs +++ b/allthethings/src/lib.rs @@ -10,6 +10,22 @@ mod store; mod replicator; mod protocol; mod varint; +mod memorystore; +mod iblt; + +pub struct Config { + /// Number of P2P connections desired. + pub target_link_count: usize, + + /// Maximum allowed size of an object. + pub max_object_size: usize, + + /// TCP port to which this should bind. + pub tcp_port: u16, + + /// A name for this replicated data set. This is just used to prevent linking to peers replicating different data. + pub domain: String, +} pub(crate) fn ms_since_epoch() -> u64 { std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64 @@ -19,7 +35,8 @@ pub(crate) fn ms_monotonic() -> u64 { std::time::Instant::now().elapsed().as_millis() as u64 } +/// SHA384 is the hash currently used. Others could be supported in the future. pub const IDENTITY_HASH_SIZE: usize = 48; -pub use store::{Store, StoreObjectResult}; -pub use replicator::{Replicator, Config}; +pub use store::{Store, StorePutResult}; +pub use replicator::Replicator; diff --git a/allthethings/src/memorystore.rs b/allthethings/src/memorystore.rs new file mode 100644 index 000000000..95466a964 --- /dev/null +++ b/allthethings/src/memorystore.rs @@ -0,0 +1,79 @@ +use std::collections::Bound::Included; +use std::collections::BTreeMap; +use std::io::Write; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Mutex; + +use smol::net::SocketAddr; + +use zerotier_core_crypto::random::xorshift64_random; + +use crate::{IDENTITY_HASH_SIZE, ms_since_epoch, Store, StorePutResult}; + +/// A Store that stores all objects in memory, mostly for testing. +pub struct MemoryStore(Mutex>>, Mutex>, AtomicU64); + +impl MemoryStore { + pub fn new() -> Self { Self(Mutex::new(BTreeMap::new()), Mutex::new(Vec::new()), AtomicU64::new(u64::MAX)) } +} + +impl Store for MemoryStore { + fn get(&self, _reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE], buffer: &mut Vec) -> bool { + buffer.clear(); + self.0.lock().unwrap().get(identity_hash).map_or(false, |value| { + let _ = buffer.write_all(value.as_slice()); + true + }) + } + + fn put(&self, _reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE], object: &[u8]) -> StorePutResult { + let mut result = StorePutResult::Duplicate; + let _ = self.0.lock().unwrap().entry(identity_hash.clone()).or_insert_with(|| { + self.2.store(ms_since_epoch(), Ordering::Relaxed); + result = StorePutResult::Ok; + object.to_vec() + }); + result + } + + fn have(&self, _reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE]) -> bool { + self.0.lock().unwrap().contains_key(identity_hash) + } + + fn total_count(&self, _reference_time: u64) -> Option { + Some(self.0.lock().unwrap().len() as u64) + } + + fn last_object_receive_time(&self) -> Option { + let rt = self.2.load(Ordering::Relaxed); + if rt == u64::MAX { + None + } else { + Some(rt) + } + } + + fn count(&self, _reference_time: u64, start: &[u8; IDENTITY_HASH_SIZE], end: &[u8; IDENTITY_HASH_SIZE]) -> Option { + if start.le(end) { + Some(self.0.lock().unwrap().range((Included(*start), Included(*end))).count() as u64) + } else { + None + } + } + + fn save_remote_endpoint(&self, to_address: &SocketAddr) { + let mut sv = self.1.lock().unwrap(); + if !sv.contains(to_address) { + sv.push(to_address.clone()); + } + } + + fn get_remote_endpoint(&self) -> Option { + let sv = self.1.lock().unwrap(); + if sv.is_empty() { + None + } else { + sv.get((xorshift64_random() as usize) % sv.len()).cloned() + } + } +} diff --git a/allthethings/src/protocol.rs b/allthethings/src/protocol.rs index 077be4621..357c10a21 100644 --- a/allthethings/src/protocol.rs +++ b/allthethings/src/protocol.rs @@ -6,24 +6,38 @@ * https://www.zerotier.com/ */ -pub(crate) const PROTOCOL_VERSION: u8 = 1; +pub const PROTOCOL_VERSION: u8 = 1; +pub const HASH_ALGORITHM_SHA384: u8 = 1; -pub(crate) const MESSAGE_TYPE_NOP: u8 = 0; -pub(crate) const MESSAGE_TYPE_HAVE_NEW_OBJECT: u8 = 1; -pub(crate) const MESSAGE_TYPE_OBJECT: u8 = 2; -pub(crate) const MESSAGE_TYPE_GET_OBJECTS: u8 = 3; +pub const MESSAGE_TYPE_NOP: u8 = 0; +pub const MESSAGE_TYPE_HAVE_NEW_OBJECT: u8 = 1; +pub const MESSAGE_TYPE_OBJECT: u8 = 2; +pub const MESSAGE_TYPE_GET_OBJECTS: u8 = 3; /// HELLO message, which is all u8's and is packed and so can be parsed directly in place. /// This message is sent at the start of any connection by both sides. #[repr(packed)] -pub(crate) struct Hello { +pub struct Hello { pub hello_size: u8, // technically a varint but below 0x80 pub protocol_version: u8, + pub hash_algorithm: u8, pub flags: [u8; 4], // u32, little endian pub clock: [u8; 8], // u64, little endian - pub data_set_size: [u8; 8], // u64, little endian + pub last_object_receive_time: [u8; 8], // u64, little endian, u64::MAX if unspecified pub domain_hash: [u8; 48], pub instance_id: [u8; 16], - pub loopback_check_code_salt: [u8; 8], + pub loopback_check_code_salt: [u8; 16], pub loopback_check_code: [u8; 16], } + +#[cfg(test)] +mod tests { + use std::mem::size_of; + use crate::protocol::*; + + #[test] + fn check_sizing() { + // Make sure packed structures are really packed. + assert_eq!(size_of::(), 1 + 1 + 1 + 4 + 8 + 8 + 48 + 16 + 16 + 16); + } +} diff --git a/allthethings/src/replicator.rs b/allthethings/src/replicator.rs index e72b15eab..b27b15bb3 100644 --- a/allthethings/src/replicator.rs +++ b/allthethings/src/replicator.rs @@ -10,53 +10,27 @@ use std::collections::HashMap; use std::convert::TryInto; use std::error::Error; use std::hash::{Hash, Hasher}; -use std::marker::PhantomData; use std::mem::{size_of, transmute}; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; -use getrandom::getrandom; -use sha2::{Digest, Sha384}; use smol::{Executor, Task, Timer}; use smol::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use smol::lock::Mutex; -use smol::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, SocketAddr}; +use smol::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream}; use smol::stream::StreamExt; -use crate::{IDENTITY_HASH_SIZE, ms_monotonic, ms_since_epoch, protocol}; -use crate::store::{StoreObjectResult, Store}; +use zerotier_core_crypto::hash::SHA384; +use zerotier_core_crypto::random; + +use crate::{IDENTITY_HASH_SIZE, ms_monotonic, ms_since_epoch, protocol, Config}; +use crate::store::{Store, StorePutResult}; use crate::varint; const CONNECTION_TIMEOUT_SECONDS: u64 = 60; const CONNECTION_SYNC_RESTART_TIMEOUT_SECONDS: u64 = 5; -static mut XORSHIFT64_STATE: u64 = 0; - -/// Get a non-cryptographic random number. -fn xorshift64_random() -> u64 { - let mut x = unsafe { XORSHIFT64_STATE }; - x ^= x.wrapping_shl(13); - x ^= x.wrapping_shr(7); - x ^= x.wrapping_shl(17); - unsafe { XORSHIFT64_STATE = x }; - x -} - -pub struct Config { - /// Number of P2P connections desired. - pub target_link_count: usize, - - /// Maximum allowed size of an object. - pub max_object_size: usize, - - /// TCP port to which this should bind. - pub tcp_port: u16, - - /// A name for this replicated data set. This is just used to prevent linking to peers replicating different data. - pub domain: String, -} - #[derive(PartialEq, Eq, Clone)] struct ConnectionKey { instance_id: [u8; 16], @@ -74,33 +48,32 @@ impl Hash for ConnectionKey { struct Connection { remote_address: SocketAddr, last_receive: Arc, - task: Task<()> + task: Task<()>, } struct ReplicatorImpl<'ex> { executor: Arc>, instance_id: [u8; 16], - loopback_check_code_secret: [u8; 16], + loopback_check_code_secret: [u8; 48], domain_hash: [u8; 48], store: Arc, config: Config, connections: Mutex>, + connections_in_progress: Mutex>>, announced_objects_requested: Mutex>, } pub struct Replicator<'ex> { v4_listener_task: Option>, v6_listener_task: Option>, - service_task: Task<()>, - _marker: PhantomData>, + background_cleanup_task: Task<()>, + _impl: Arc> } impl<'ex> Replicator<'ex> { /// Create a new replicator to replicate the contents of the provided store. /// All async tasks, sockets, and connections will be dropped if the replicator is dropped. pub async fn start(executor: &Arc>, store: Arc, config: Config) -> Result, Box> { - let _ = unsafe { getrandom(&mut *(&mut XORSHIFT64_STATE as *mut u64).cast::<[u8; 8]>()) }; - let listener_v4 = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, config.tcp_port)).await; let listener_v6 = TcpListener::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, config.tcp_port, 0, 0)).await; if listener_v4.is_err() && listener_v6.is_err() { @@ -111,30 +84,27 @@ impl<'ex> Replicator<'ex> { executor: executor.clone(), instance_id: { let mut tmp = [0_u8; 16]; - getrandom(&mut tmp).expect("getrandom failed"); + random::fill_bytes_secure(&mut tmp); tmp }, loopback_check_code_secret: { - let mut tmp = [0_u8; 16]; - getrandom(&mut tmp).expect("getrandom failed"); + let mut tmp = [0_u8; 48]; + random::fill_bytes_secure(&mut tmp); tmp }, - domain_hash: { - let mut h = Sha384::new(); - h.update(config.domain.as_bytes()); - h.finalize().as_ref().try_into().unwrap() - }, + domain_hash: SHA384::hash(config.domain.as_bytes()), config, store, connections: Mutex::new(HashMap::new()), - announced_objects_requested: Mutex::new(HashMap::new()) + connections_in_progress: Mutex::new(HashMap::new()), + announced_objects_requested: Mutex::new(HashMap::new()), }); Ok(Self { - v4_listener_task: listener_v4.map_or(None, |listener_v4| Some(executor.spawn(r.clone().listener_task_main(listener_v4)))), - v6_listener_task: listener_v6.map_or(None, |listener_v6| Some(executor.spawn(r.clone().listener_task_main(listener_v6)))), - service_task: executor.spawn(r.service_main()), - _marker: PhantomData::default(), + v4_listener_task: listener_v4.map_or(None, |listener_v4| Some(executor.spawn(r.clone().tcp_listener_task(listener_v4)))), + v6_listener_task: listener_v6.map_or(None, |listener_v6| Some(executor.spawn(r.clone().tcp_listener_task(listener_v6)))), + background_cleanup_task: executor.spawn(r.clone().background_cleanup_task()), + _impl: r }) } } @@ -144,47 +114,74 @@ unsafe impl<'ex> Send for Replicator<'ex> {} unsafe impl<'ex> Sync for Replicator<'ex> {} impl<'ex> ReplicatorImpl<'ex> { - async fn service_main(self: Arc>) { - let mut timer = smol::Timer::interval(Duration::from_secs(5)); + async fn background_cleanup_task(self: Arc>) { + let mut timer = smol::Timer::interval(Duration::from_secs(CONNECTION_TIMEOUT_SECONDS / 10)); loop { timer.next().await; let now_mt = ms_monotonic(); - self.announced_objects_requested.lock().await.retain(|_, ts| now_mt.saturating_sub(*ts) < (CONNECTION_TIMEOUT_SECONDS * 1000)); - self.connections.lock().await.retain(|_, c| (now_mt.saturating_sub(c.last_receive.load(Ordering::Relaxed))) < (CONNECTION_TIMEOUT_SECONDS * 1000)); - } - } - async fn listener_task_main(self: Arc>, listener: TcpListener) { - loop { - let stream = listener.accept().await; - if stream.is_ok() { - let (stream, remote_address) = stream.unwrap(); - self.handle_new_connection(stream, remote_address, false).await; + // Garbage collect the map used to track objects we've requested. + self.announced_objects_requested.lock().await.retain(|_, ts| now_mt.saturating_sub(*ts) < (CONNECTION_TIMEOUT_SECONDS * 1000)); + + let mut connections = self.connections.lock().await; + + // Close connections that haven't spoken in too long. + connections.retain(|_, c| (now_mt.saturating_sub(c.last_receive.load(Ordering::Relaxed))) < (CONNECTION_TIMEOUT_SECONDS * 1000)); + let num_connections = connections.len(); + drop(connections); // release lock + + // Try to connect to more nodes if the count is below the target count. + if num_connections < self.config.target_link_count { + let new_link_seed = self.store.get_remote_endpoint(); + if new_link_seed.is_some() { + let new_link_seed = new_link_seed.unwrap(); + let mut connections_in_progress = self.connections_in_progress.lock().await; + if !connections_in_progress.contains_key(&new_link_seed) { + let s2 = self.clone(); + let _ = connections_in_progress.insert(new_link_seed.clone(), self.executor.spawn(async move { + let new_link = TcpStream::connect(&new_link_seed).await; + if new_link.is_ok() { + s2.handle_new_connection(new_link.unwrap(), new_link_seed, true).await; + } else { + let _task = s2.connections_in_progress.lock().await.remove(&new_link_seed); + } + })); + } + } } } } - async fn handle_new_connection(self: &Arc>, mut stream: TcpStream, remote_address: SocketAddr, outgoing: bool) { - stream.set_nodelay(true); + async fn tcp_listener_task(self: Arc>, listener: TcpListener) { + loop { + let stream = listener.accept().await; + if stream.is_ok() { + let (stream, remote_address) = stream.unwrap(); + let mut connections_in_progress = self.connections_in_progress.lock().await; + if !connections_in_progress.contains_key(&remote_address) { + let s2 = self.clone(); + let _ = connections_in_progress.insert(remote_address, self.executor.spawn(s2.handle_new_connection(stream, remote_address.clone(), false))); + } + } + } + } - let mut loopback_check_code_salt = [0_u8; 8]; - getrandom(&mut loopback_check_code_salt).expect("getrandom failed"); - - let mut h = Sha384::new(); - h.update(&loopback_check_code_salt); - h.update(&self.loopback_check_code_secret); - let loopback_check_code: [u8; 48] = h.finalize().as_ref().try_into().unwrap(); + async fn handle_new_connection(self: Arc>, mut stream: TcpStream, remote_address: SocketAddr, outgoing: bool) { + let _ = stream.set_nodelay(true); + let mut loopback_check_code_salt = [0_u8; 16]; + random::fill_bytes_secure(&mut loopback_check_code_salt); let hello = protocol::Hello { hello_size: size_of::() as u8, protocol_version: protocol::PROTOCOL_VERSION, + hash_algorithm: protocol::HASH_ALGORITHM_SHA384, flags: [0_u8; 4], clock: ms_since_epoch().to_le_bytes(), - data_set_size: self.store.total_size().to_le_bytes(), + last_object_receive_time: self.store.last_object_receive_time().unwrap_or(u64::MAX).to_le_bytes(), domain_hash: self.domain_hash.clone(), instance_id: self.instance_id.clone(), loopback_check_code_salt, - loopback_check_code: (&loopback_check_code[0..16]).try_into().unwrap() + loopback_check_code: (&SHA384::hmac(&self.loopback_check_code_secret, &loopback_check_code_salt)[0..16]).try_into().unwrap(), }; let hello: [u8; size_of::()] = unsafe { transmute(hello) }; @@ -194,48 +191,50 @@ impl<'ex> ReplicatorImpl<'ex> { let hello: protocol::Hello = unsafe { transmute(hello_buf) }; // Sanity check HELLO packet. In the future we may support different versions and sizes. - if hello.hello_size == size_of::() as u8 && hello.protocol_version == protocol::PROTOCOL_VERSION { - // If this hash's first 16 bytes are equal to the one in the HELLO, this connection is - // from this node and should be dropped. - let mut h = Sha384::new(); - h.update(&hello.loopback_check_code_salt); - h.update(&self.loopback_check_code_secret); - let loopback_if_equal: [u8; 48] = h.finalize().as_ref().try_into().unwrap(); - - if !loopback_if_equal[0..16].eq(&hello.loopback_check_code) { + if hello.hello_size == size_of::() as u8 && hello.protocol_version == protocol::PROTOCOL_VERSION && hello.hash_algorithm == protocol::HASH_ALGORITHM_SHA384 { + if !SHA384::hmac(&self.loopback_check_code_secret, &hello.loopback_check_code_salt)[0..16].eq(&hello.loopback_check_code) { let k = ConnectionKey { instance_id: hello.instance_id.clone(), - ip: remote_address.ip() + ip: remote_address.ip(), }; let mut connections = self.connections.lock().await; + let s2 = self.clone(); let _ = connections.entry(k).or_insert_with(move || { - stream.set_nodelay(false); + let _ = stream.set_nodelay(false); + + if outgoing { + s2.store.save_remote_endpoint(&remote_address); + } + let last_receive = Arc::new(AtomicU64::new(ms_monotonic())); Connection { remote_address, last_receive: last_receive.clone(), - task: self.executor.spawn(self.clone().connection_io_task_main(stream, hello.instance_id, last_receive)) + task: s2.executor.spawn(s2.clone().connection_io_task(stream, hello.instance_id, last_receive)), } }); } } } } + + let _task = self.connections_in_progress.lock().await.remove(&remote_address); } - async fn connection_io_task_main(self: Arc>, stream: TcpStream, remote_instance_id: [u8; 16], last_receive: Arc) { + async fn connection_io_task(self: Arc>, stream: TcpStream, remote_instance_id: [u8; 16], last_receive: Arc) { let mut reader = BufReader::with_capacity(65536, stream.clone()); let writer = Arc::new(Mutex::new(stream)); - let writer2 = writer.clone(); - let _sync_search_init_task = self.executor.spawn(async { - let writer = writer2; + //let writer2 = writer.clone(); + let _sync_search_init_task = self.executor.spawn(async move { + //let writer = writer2; let mut periodic_timer = Timer::interval(Duration::from_secs(1)); loop { let _ = periodic_timer.next().await; } }); + let mut get_buffer = Vec::new(); let mut tmp_mem = Vec::new(); tmp_mem.resize(self.config.max_object_size, 0); let tmp = tmp_mem.as_mut_slice(); @@ -248,7 +247,7 @@ impl<'ex> ReplicatorImpl<'ex> { last_receive.store(ms_monotonic(), Ordering::Relaxed); match message_type { - protocol::MESSAGE_TYPE_NOP => {}, + protocol::MESSAGE_TYPE_NOP => {} protocol::MESSAGE_TYPE_HAVE_NEW_OBJECT => { if reader.read_exact(&mut tmp[0..IDENTITY_HASH_SIZE]).await.is_err() { @@ -256,7 +255,7 @@ impl<'ex> ReplicatorImpl<'ex> { } let identity_hash: [u8; 48] = (&tmp[0..IDENTITY_HASH_SIZE]).try_into().unwrap(); let mut announced_objects_requested = self.announced_objects_requested.lock().await; - if !announced_objects_requested.contains_key(&identity_hash) && !self.store.have(&identity_hash) { + if !announced_objects_requested.contains_key(&identity_hash) && !self.store.have(ms_since_epoch(), &identity_hash) { announced_objects_requested.insert(identity_hash.clone(), ms_monotonic()); drop(announced_objects_requested); // release mutex @@ -285,23 +284,30 @@ impl<'ex> ReplicatorImpl<'ex> { break 'main_io_loop; } - let identity_hash: [u8; 48] = Sha384::digest(object.as_ref()).as_ref().try_into().unwrap(); - match self.store.put(&identity_hash, object) { - StoreObjectResult::Invalid => { + let identity_hash: [u8; 48] = SHA384::hash(object); + match self.store.put(ms_since_epoch(), &identity_hash, object) { + StorePutResult::Invalid => { break 'main_io_loop; - }, - StoreObjectResult::Ok | StoreObjectResult::Duplicate => { + } + StorePutResult::Ok | StorePutResult::Duplicate => { if self.announced_objects_requested.lock().await.remove(&identity_hash).is_some() { // TODO: propagate rumor if we requested this object in response to a HAVE message. } - }, + } _ => { let _ = self.announced_objects_requested.lock().await.remove(&identity_hash); } } - }, + } protocol::MESSAGE_TYPE_GET_OBJECTS => { + // Get the reference time for this query. + let reference_time = varint::async_read(&mut reader).await; + if reference_time.is_err() { + break 'main_io_loop; + } + let reference_time = reference_time.unwrap(); + // Read common prefix if the requester is requesting a set of hashes with the same beginning. // A common prefix length of zero means they're requesting by full hash. if reader.read_exact(&mut tmp[0..1]).await.is_err() { @@ -328,20 +334,17 @@ impl<'ex> ReplicatorImpl<'ex> { break 'main_io_loop; } let identity_hash: [u8; IDENTITY_HASH_SIZE] = (&tmp[0..IDENTITY_HASH_SIZE]).try_into().unwrap(); - let object = self.store.get(&identity_hash); - if object.is_some() { - let object2 = object.unwrap(); - let object = object2.as_slice(); + if self.store.get(reference_time, &identity_hash, &mut get_buffer) { let mut w = writer.lock().await; - if varint::async_write(&mut *w, object.len() as u64).await.is_err() { + if varint::async_write(&mut *w, get_buffer.len() as u64).await.is_err() { break 'main_io_loop; } - if w.write_all(object).await.is_err() { + if w.write_all(get_buffer.as_slice()).await.is_err() { break 'main_io_loop; } } } - }, + } _ => { break 'main_io_loop; diff --git a/allthethings/src/store.rs b/allthethings/src/store.rs index bc487e60d..b65a79ddf 100644 --- a/allthethings/src/store.rs +++ b/allthethings/src/store.rs @@ -10,8 +10,10 @@ use smol::net::SocketAddr; use crate::IDENTITY_HASH_SIZE; -/// Result code from the put() method in Database. -pub enum StoreObjectResult { +pub const MIN_IDENTITY_HASH: [u8; 48] = [0_u8; 48]; +pub const MAX_IDENTITY_HASH: [u8; 48] = [0xff_u8; 48]; + +pub enum StorePutResult { /// Datum stored successfully. Ok, /// Datum is one we already have. @@ -22,23 +24,28 @@ pub enum StoreObjectResult { Ignored, } -/// Trait that must be implemented for the data store that is to be replicated. +/// Trait that must be implemented by the data store that is to be replicated. pub trait Store: Sync + Send { - /// Get the total size of this data set in objects. - fn total_size(&self) -> u64; - - /// Get an object from the database. - fn get(&self, identity_hash: &[u8; IDENTITY_HASH_SIZE]) -> Option>; + /// Get an object from the database, storing it in the supplied buffer. + /// A return of 'false' leaves the buffer state undefined. If the return is true any previous + /// data in the supplied buffer will have been cleared and replaced with the retrieved object. + fn get(&self, reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE], buffer: &mut Vec) -> bool; /// Store an entry in the database. - fn put(&self, identity_hash: &[u8; IDENTITY_HASH_SIZE], object: &[u8]) -> StoreObjectResult; + fn put(&self, reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE], object: &[u8]) -> StorePutResult; /// Check if we have an object by its identity hash. - fn have(&self, identity_hash: &[u8; IDENTITY_HASH_SIZE]) -> bool; + fn have(&self, reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE]) -> bool; + + /// Get the total count of objects. + fn total_count(&self, reference_time: u64) -> Option; + + /// Get the time the last object was received in milliseconds since epoch. + fn last_object_receive_time(&self) -> Option; /// Count the number of identity hash keys in this range (inclusive) of identity hashes. /// This may return None if an error occurs, but should return 0 if the set is empty. - fn count(&self, start: &[u8; IDENTITY_HASH_SIZE], end: &[u8; IDENTITY_HASH_SIZE]) -> Option; + fn count(&self, reference_time: u64, start: &[u8; IDENTITY_HASH_SIZE], end: &[u8; IDENTITY_HASH_SIZE]) -> Option; /// Called when a connection to a remote node was successful. /// This is always called on successful outbound connect. diff --git a/zerotier-core-crypto/src/random.rs b/zerotier-core-crypto/src/random.rs index 372fa10f4..87ceac4dc 100644 --- a/zerotier-core-crypto/src/random.rs +++ b/zerotier-core-crypto/src/random.rs @@ -59,3 +59,18 @@ pub fn next_u64_secure() -> u64 { #[inline(always)] pub fn fill_bytes_secure(dest: &mut [u8]) { randomize(Level::Strong, dest); } + +static mut XORSHIFT64_STATE: u64 = 0; + +/// Get a non-cryptographic random number. +pub fn xorshift64_random() -> u64 { + let mut x = unsafe { XORSHIFT64_STATE }; + while x == 0 { + x = next_u64_secure(); + } + x ^= x.wrapping_shl(13); + x ^= x.wrapping_shr(7); + x ^= x.wrapping_shl(17); + unsafe { XORSHIFT64_STATE = x }; + x +}