Mirror of https://github.com/zerotier/ZeroTierOne.git (synced 2025-06-03 19:13:43 +02:00)
Bunch of sync stuff including a neat set reconciliation thing.
This commit is contained in:
parent 0d67fcee92
commit e55d3e4d4b
8 changed files with 461 additions and 129 deletions
allthethings/Cargo.toml

@@ -4,6 +4,5 @@ version = "0.1.0"
edition = "2021"

[dependencies]
sha2 = { version = "^0", features = ["asm"] }
smol = { version = "^1", features = [] }
getrandom = "^0"
zerotier-core-crypto = { path = "../zerotier-core-crypto" }

198  allthethings/src/iblt.rs  Normal file
@@ -0,0 +1,198 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/.
 *
 * (c)2021 ZeroTier, Inc.
 * https://www.zerotier.com/
 */

use std::mem::zeroed;
use std::ptr::copy_nonoverlapping; // needed by the fallback insert()/remove() paths below

use crate::IDENTITY_HASH_SIZE;

// The number of indexing sub-hashes to use, must be <= IDENTITY_HASH_SIZE / 8
const KEY_MAPPING_ITERATIONS: usize = IDENTITY_HASH_SIZE / 8;

#[inline(always)]
fn xorshift64(mut x: u64) -> u64 {
    x ^= x.wrapping_shl(13);
    x ^= x.wrapping_shr(7);
    x ^= x.wrapping_shl(17);
    x
}

#[repr(packed)]
struct IBLTEntry {
    key_sum: [u64; IDENTITY_HASH_SIZE / 8],
    check_hash_sum: u64,
    count: i64,
}

impl Default for IBLTEntry {
    fn default() -> Self { unsafe { zeroed() } }
}

/// An IBLT (invertible bloom lookup table) specialized for reconciling sets of identity hashes.
/// This skips some extra hashing that would be necessary in a universal implementation since identity
/// hashes are already randomly distributed strong hashes.
pub struct IBLT {
    map: Box<[IBLTEntry]>,
}

impl IBLTEntry {
    #[inline(always)]
    fn is_singular(&self) -> bool {
        if self.count == 1 || self.count == -1 {
            u64::from_le(self.key_sum[0]).wrapping_add(xorshift64(u64::from_le(self.key_sum[1]))) == u64::from_le(self.check_hash_sum)
        } else {
            false
        }
    }
}

impl IBLT {
    /// Construct a new IBLT with a given capacity.
    pub fn new(buckets: usize) -> Self {
        assert!(KEY_MAPPING_ITERATIONS <= (IDENTITY_HASH_SIZE / 8) && (IDENTITY_HASH_SIZE % 8) == 0);
        assert!(buckets > 0);
        Self {
            map: {
                let mut tmp = Vec::new();
                tmp.resize_with(buckets, IBLTEntry::default);
                tmp.into_boxed_slice()
            }
        }
    }

    fn ins_rem(&mut self, key: &[u64; IDENTITY_HASH_SIZE / 8], delta: i64) {
        let check_hash = u64::from_le(key[0]).wrapping_add(xorshift64(u64::from_le(key[1]))).to_le();
        for mapping_sub_hash in 0..KEY_MAPPING_ITERATIONS {
            let b = unsafe { self.map.get_unchecked_mut((u64::from_le(key[mapping_sub_hash]) as usize) % self.map.len()) };
            for j in 0..(IDENTITY_HASH_SIZE / 8) {
                b.key_sum[j] ^= key[j];
            }
            b.check_hash_sum ^= check_hash;
            b.count = i64::from_le(b.count).wrapping_add(delta).to_le();
        }
    }

    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64", target_arch = "powerpc64"))]
    #[inline(always)]
    pub fn insert(&mut self, key: &[u8; IDENTITY_HASH_SIZE]) {
        self.ins_rem(unsafe { &*key.as_ptr().cast::<[u64; IDENTITY_HASH_SIZE / 8]>() }, 1);
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64", target_arch = "powerpc64")))]
    #[inline(always)]
    pub fn insert(&mut self, key: &[u8; IDENTITY_HASH_SIZE]) {
        let mut tmp = [0_u64; IDENTITY_HASH_SIZE / 8];
        unsafe { copy_nonoverlapping(key.as_ptr(), tmp.as_mut_ptr().cast(), IDENTITY_HASH_SIZE) };
        self.ins_rem(&tmp, 1);
    }

    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64", target_arch = "powerpc64"))]
    #[inline(always)]
    pub fn remove(&mut self, key: &[u8; IDENTITY_HASH_SIZE]) {
        self.ins_rem(unsafe { &*key.as_ptr().cast::<[u64; IDENTITY_HASH_SIZE / 8]>() }, -1);
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64", target_arch = "powerpc64")))]
    #[inline(always)]
    pub fn remove(&mut self, key: &[u8; IDENTITY_HASH_SIZE]) {
        let mut tmp = [0_u64; IDENTITY_HASH_SIZE / 8];
        unsafe { copy_nonoverlapping(key.as_ptr(), tmp.as_mut_ptr().cast(), IDENTITY_HASH_SIZE) };
        self.ins_rem(&tmp, -1);
    }

    /// Subtract another IBLT from this one to compute set difference.
    pub fn subtract(&mut self, other: &IBLT) {
        if other.map.len() == self.map.len() {
            for i in 0..self.map.len() {
                let self_b = unsafe { self.map.get_unchecked_mut(i) };
                let other_b = unsafe { other.map.get_unchecked(i) };
                for j in 0..(IDENTITY_HASH_SIZE / 8) {
                    self_b.key_sum[j] ^= other_b.key_sum[j];
                }
                self_b.check_hash_sum ^= other_b.check_hash_sum;
                self_b.count = i64::from_le(self_b.count).wrapping_sub(i64::from_le(other_b.count)).to_le();
            }
        }
    }

    /// Call a function for every value that can be extracted from this IBLT.
    ///
    /// The function is called with the key and a boolean. The boolean is meaningful
    /// if this IBLT is the result of subtract(). In that case the boolean is true
    /// if the "local" IBLT contained the item and false if the "remote" side contained
    /// the item. Returning false from the function terminates listing early.
    pub fn list<F: FnMut(&[u8; IDENTITY_HASH_SIZE], bool) -> bool>(&mut self, mut f: F) {
        let mut singular_buckets: Vec<usize> = Vec::with_capacity(1024);
        let buckets = self.map.len();

        for i in 0..buckets {
            if unsafe { self.map.get_unchecked(i) }.is_singular() {
                singular_buckets.push(i);
            };
        }

        let mut key = [0_u64; IDENTITY_HASH_SIZE / 8];
        while !singular_buckets.is_empty() {
            let b = unsafe { self.map.get_unchecked_mut(singular_buckets.pop().unwrap()) };
            if b.is_singular() {
                for j in 0..(IDENTITY_HASH_SIZE / 8) {
                    key[j] = b.key_sum[j];
                }
                if f(unsafe { &*key.as_ptr().cast::<[u8; IDENTITY_HASH_SIZE]>() }, b.count == 1) {
                    let check_hash = u64::from_le(key[0]).wrapping_add(xorshift64(u64::from_le(key[1]))).to_le();
                    for mapping_sub_hash in 0..KEY_MAPPING_ITERATIONS {
                        let bi = (u64::from_le(key[mapping_sub_hash]) as usize) % buckets;
                        let b = unsafe { self.map.get_unchecked_mut(bi) };
                        for j in 0..(IDENTITY_HASH_SIZE / 8) {
                            b.key_sum[j] ^= key[j];
                        }
                        b.check_hash_sum ^= check_hash;
                        b.count = i64::from_le(b.count).wrapping_sub(1).to_le();
                        if b.is_singular() {
                            singular_buckets.push(bi);
                        }
                    }
                } else {
                    break;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use zerotier_core_crypto::hash::SHA384;
    use crate::iblt::IBLT;

    #[allow(unused_variables)]
    #[test]
    fn insert_and_list() {
        let mut t = IBLT::new(1024);
        let expected_cnt = 512;
        for i in 0..expected_cnt {
            let k = SHA384::hash(&(i as u64).to_le_bytes());
            t.insert(&k);
        }
        let mut cnt = 0;
        t.list(|k, d| {
            cnt += 1;
            //println!("{} {}", zerotier_core_crypto::hex::to_string(k), d);
            true
        });
        println!("retrieved {} keys", cnt);
        assert_eq!(cnt, expected_cnt);
    }

    #[test]
    fn benchmark() {
    }
}
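
The insert_and_list test above exercises a single table; the reconciliation flow that the subtract() and list() doc comments describe looks roughly like the sketch below. It is illustrative only and not part of this commit: the reconcile function and key slices are made up, the 2048-bucket size is arbitrary, and the exchange of a serialized table over the wire is left out.

// Sketch: set reconciliation with the IBLT above (illustrative only).
// Both sides build equal-size tables over their 48-byte identity hashes.
fn reconcile(alice_keys: &[[u8; 48]], bob_keys: &[[u8; 48]]) {
    let mut alice = IBLT::new(2048); // size chosen arbitrarily for the sketch
    let mut bob = IBLT::new(2048);
    for k in alice_keys { alice.insert(k); }
    for k in bob_keys { bob.insert(k); }

    // In a real exchange Bob would serialize his table and send it to Alice;
    // the wire format for that is not part of this commit.
    alice.subtract(&bob);

    // Peeling the subtracted table yields the symmetric difference of the sets.
    // The flag reports which side had the key: true = Alice only, false = Bob only.
    alice.list(|key, alice_has_it| {
        if alice_has_it {
            // Alice has an object Bob lacks: announce/send it.
        } else {
            // Bob has an object Alice lacks: request it.
        }
        let _ = key;
        true // keep listing
    });
}
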
allthethings/src/lib.rs

@@ -10,6 +10,22 @@ mod store;
mod replicator;
mod protocol;
mod varint;
mod memorystore;
mod iblt;

pub struct Config {
    /// Number of P2P connections desired.
    pub target_link_count: usize,

    /// Maximum allowed size of an object.
    pub max_object_size: usize,

    /// TCP port to which this should bind.
    pub tcp_port: u16,

    /// A name for this replicated data set. This is just used to prevent linking to peers replicating different data.
    pub domain: String,
}

pub(crate) fn ms_since_epoch() -> u64 {
    std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64

@@ -19,7 +35,8 @@ pub(crate) fn ms_monotonic() -> u64 {
    std::time::Instant::now().elapsed().as_millis() as u64
}

/// SHA384 is the hash currently used. Others could be supported in the future.
pub const IDENTITY_HASH_SIZE: usize = 48;

pub use store::{Store, StoreObjectResult};
pub use replicator::{Replicator, Config};
pub use store::{Store, StorePutResult};
pub use replicator::Replicator;
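
With these exports, the crate's public surface is Config, Store/StorePutResult, and Replicator. A rough end-to-end usage sketch follows; it is illustrative only (it pairs the exports with the MemoryStore added in this commit, glosses over module visibility, assumes the Store impl is complete, and uses arbitrary port and domain values).

// Sketch: wiring Config, MemoryStore, and Replicator together (illustrative only).
use std::sync::Arc;
use smol::Executor;

fn main() {
    let executor = Arc::new(Executor::new());
    let store = Arc::new(MemoryStore::new()); // in-memory Store from this commit
    let config = Config {
        target_link_count: 8,          // arbitrary example values
        max_object_size: 1024 * 1024,
        tcp_port: 19993,
        domain: "example".to_string(), // peers must replicate the same domain
    };
    smol::block_on(executor.run(async {
        let _replicator = Replicator::start(&executor, store, config)
            .await
            .expect("failed to bind a listener on the configured port");
        // Runs until dropped; dropping the Replicator tears down its tasks.
        smol::future::pending::<()>().await;
    }));
}
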
79  allthethings/src/memorystore.rs  Normal file
@@ -0,0 +1,79 @@
use std::collections::Bound::Included;
use std::collections::BTreeMap;
use std::io::Write;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;

use smol::net::SocketAddr;

use zerotier_core_crypto::random::xorshift64_random;

use crate::{IDENTITY_HASH_SIZE, ms_since_epoch, Store, StorePutResult};

/// A Store that stores all objects in memory, mostly for testing.
pub struct MemoryStore(Mutex<BTreeMap<[u8; IDENTITY_HASH_SIZE], Vec<u8>>>, Mutex<Vec<SocketAddr>>, AtomicU64);

impl MemoryStore {
    pub fn new() -> Self { Self(Mutex::new(BTreeMap::new()), Mutex::new(Vec::new()), AtomicU64::new(u64::MAX)) }
}

impl Store for MemoryStore {
    fn get(&self, _reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE], buffer: &mut Vec<u8>) -> bool {
        buffer.clear();
        self.0.lock().unwrap().get(identity_hash).map_or(false, |value| {
            let _ = buffer.write_all(value.as_slice());
            true
        })
    }

    fn put(&self, _reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE], object: &[u8]) -> StorePutResult {
        let mut result = StorePutResult::Duplicate;
        let _ = self.0.lock().unwrap().entry(identity_hash.clone()).or_insert_with(|| {
            self.2.store(ms_since_epoch(), Ordering::Relaxed);
            result = StorePutResult::Ok;
            object.to_vec()
        });
        result
    }

    fn have(&self, _reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE]) -> bool {
        self.0.lock().unwrap().contains_key(identity_hash)
    }

    fn total_count(&self, _reference_time: u64) -> Option<u64> {
        Some(self.0.lock().unwrap().len() as u64)
    }

    fn last_object_receive_time(&self) -> Option<u64> {
        let rt = self.2.load(Ordering::Relaxed);
        if rt == u64::MAX {
            None
        } else {
            Some(rt)
        }
    }

    fn count(&self, _reference_time: u64, start: &[u8; IDENTITY_HASH_SIZE], end: &[u8; IDENTITY_HASH_SIZE]) -> Option<u64> {
        if start.le(end) {
            Some(self.0.lock().unwrap().range((Included(*start), Included(*end))).count() as u64)
        } else {
            None
        }
    }

    fn save_remote_endpoint(&self, to_address: &SocketAddr) {
        let mut sv = self.1.lock().unwrap();
        if !sv.contains(to_address) {
            sv.push(to_address.clone());
        }
    }

    fn get_remote_endpoint(&self) -> Option<SocketAddr> {
        let sv = self.1.lock().unwrap();
        if sv.is_empty() {
            None
        } else {
            sv.get((xorshift64_random() as usize) % sv.len()).cloned()
        }
    }
}
allthethings/src/protocol.rs

@@ -6,24 +6,38 @@
 * https://www.zerotier.com/
 */

pub(crate) const PROTOCOL_VERSION: u8 = 1;
pub const PROTOCOL_VERSION: u8 = 1;
pub const HASH_ALGORITHM_SHA384: u8 = 1;

pub(crate) const MESSAGE_TYPE_NOP: u8 = 0;
pub(crate) const MESSAGE_TYPE_HAVE_NEW_OBJECT: u8 = 1;
pub(crate) const MESSAGE_TYPE_OBJECT: u8 = 2;
pub(crate) const MESSAGE_TYPE_GET_OBJECTS: u8 = 3;
pub const MESSAGE_TYPE_NOP: u8 = 0;
pub const MESSAGE_TYPE_HAVE_NEW_OBJECT: u8 = 1;
pub const MESSAGE_TYPE_OBJECT: u8 = 2;
pub const MESSAGE_TYPE_GET_OBJECTS: u8 = 3;

/// HELLO message, which is all u8's and is packed and so can be parsed directly in place.
/// This message is sent at the start of any connection by both sides.
#[repr(packed)]
pub(crate) struct Hello {
pub struct Hello {
    pub hello_size: u8, // technically a varint but below 0x80
    pub protocol_version: u8,
    pub hash_algorithm: u8,
    pub flags: [u8; 4], // u32, little endian
    pub clock: [u8; 8], // u64, little endian
    pub data_set_size: [u8; 8], // u64, little endian
    pub last_object_receive_time: [u8; 8], // u64, little endian, u64::MAX if unspecified
    pub domain_hash: [u8; 48],
    pub instance_id: [u8; 16],
    pub loopback_check_code_salt: [u8; 8],
    pub loopback_check_code_salt: [u8; 16],
    pub loopback_check_code: [u8; 16],
}

#[cfg(test)]
mod tests {
    use std::mem::size_of;
    use crate::protocol::*;

    #[test]
    fn check_sizing() {
        // Make sure packed structures are really packed.
        assert_eq!(size_of::<Hello>(), 1 + 1 + 1 + 4 + 8 + 8 + 48 + 16 + 16 + 16);
    }
}
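
Because Hello is #[repr(packed)] and built entirely from u8s and byte arrays, "parsed directly in place" means a received buffer can simply be reinterpreted as the struct, which is what replicator.rs below does with transmute. A minimal sketch of the idea, assuming the new struct layout above; parse_hello is illustrative and not part of this commit:

// Sketch: zero-copy view of a received HELLO (illustrative only).
// Sound-ish because Hello is #[repr(packed)] with only u8s and u8 arrays,
// so it has no alignment requirements and no invalid bit patterns.
use std::mem::size_of;

fn parse_hello(buf: &[u8]) -> Option<&Hello> {
    if buf.len() < size_of::<Hello>() {
        return None;
    }
    let h = unsafe { &*buf.as_ptr().cast::<Hello>() };
    // Reject anything we don't understand before trusting other fields.
    if h.hello_size as usize != size_of::<Hello>() || h.protocol_version != PROTOCOL_VERSION {
        return None;
    }
    Some(h)
}
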
allthethings/src/replicator.rs

@@ -10,53 +10,27 @@ use std::collections::HashMap;
use std::convert::TryInto;
use std::error::Error;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::mem::{size_of, transmute};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

use getrandom::getrandom;
use sha2::{Digest, Sha384};
use smol::{Executor, Task, Timer};
use smol::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use smol::lock::Mutex;
use smol::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, SocketAddr};
use smol::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream};
use smol::stream::StreamExt;

use crate::{IDENTITY_HASH_SIZE, ms_monotonic, ms_since_epoch, protocol};
use crate::store::{StoreObjectResult, Store};
use zerotier_core_crypto::hash::SHA384;
use zerotier_core_crypto::random;

use crate::{IDENTITY_HASH_SIZE, ms_monotonic, ms_since_epoch, protocol, Config};
use crate::store::{Store, StorePutResult};
use crate::varint;

const CONNECTION_TIMEOUT_SECONDS: u64 = 60;
const CONNECTION_SYNC_RESTART_TIMEOUT_SECONDS: u64 = 5;

static mut XORSHIFT64_STATE: u64 = 0;

/// Get a non-cryptographic random number.
fn xorshift64_random() -> u64 {
    let mut x = unsafe { XORSHIFT64_STATE };
    x ^= x.wrapping_shl(13);
    x ^= x.wrapping_shr(7);
    x ^= x.wrapping_shl(17);
    unsafe { XORSHIFT64_STATE = x };
    x
}

pub struct Config {
    /// Number of P2P connections desired.
    pub target_link_count: usize,

    /// Maximum allowed size of an object.
    pub max_object_size: usize,

    /// TCP port to which this should bind.
    pub tcp_port: u16,

    /// A name for this replicated data set. This is just used to prevent linking to peers replicating different data.
    pub domain: String,
}

#[derive(PartialEq, Eq, Clone)]
struct ConnectionKey {
    instance_id: [u8; 16],

@@ -74,33 +48,32 @@ impl Hash for ConnectionKey {
struct Connection {
    remote_address: SocketAddr,
    last_receive: Arc<AtomicU64>,
    task: Task<()>
    task: Task<()>,
}

struct ReplicatorImpl<'ex> {
    executor: Arc<Executor<'ex>>,
    instance_id: [u8; 16],
    loopback_check_code_secret: [u8; 16],
    loopback_check_code_secret: [u8; 48],
    domain_hash: [u8; 48],
    store: Arc<dyn Store>,
    config: Config,
    connections: Mutex<HashMap<ConnectionKey, Connection>>,
    connections_in_progress: Mutex<HashMap<SocketAddr, Task<()>>>,
    announced_objects_requested: Mutex<HashMap<[u8; IDENTITY_HASH_SIZE], u64>>,
}

pub struct Replicator<'ex> {
    v4_listener_task: Option<Task<()>>,
    v6_listener_task: Option<Task<()>>,
    service_task: Task<()>,
    _marker: PhantomData<std::cell::UnsafeCell<&'ex ()>>,
    background_cleanup_task: Task<()>,
    _impl: Arc<ReplicatorImpl<'ex>>
}

impl<'ex> Replicator<'ex> {
    /// Create a new replicator to replicate the contents of the provided store.
    /// All async tasks, sockets, and connections will be dropped if the replicator is dropped.
    pub async fn start(executor: &Arc<Executor<'ex>>, store: Arc<dyn Store>, config: Config) -> Result<Replicator<'ex>, Box<dyn Error>> {
        let _ = unsafe { getrandom(&mut *(&mut XORSHIFT64_STATE as *mut u64).cast::<[u8; 8]>()) };

        let listener_v4 = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, config.tcp_port)).await;
        let listener_v6 = TcpListener::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, config.tcp_port, 0, 0)).await;
        if listener_v4.is_err() && listener_v6.is_err() {

@@ -111,30 +84,27 @@ impl<'ex> Replicator<'ex> {
            executor: executor.clone(),
            instance_id: {
                let mut tmp = [0_u8; 16];
                getrandom(&mut tmp).expect("getrandom failed");
                random::fill_bytes_secure(&mut tmp);
                tmp
            },
            loopback_check_code_secret: {
                let mut tmp = [0_u8; 16];
                getrandom(&mut tmp).expect("getrandom failed");
                let mut tmp = [0_u8; 48];
                random::fill_bytes_secure(&mut tmp);
                tmp
            },
            domain_hash: {
                let mut h = Sha384::new();
                h.update(config.domain.as_bytes());
                h.finalize().as_ref().try_into().unwrap()
            },
            domain_hash: SHA384::hash(config.domain.as_bytes()),
            config,
            store,
            connections: Mutex::new(HashMap::new()),
            announced_objects_requested: Mutex::new(HashMap::new())
            connections_in_progress: Mutex::new(HashMap::new()),
            announced_objects_requested: Mutex::new(HashMap::new()),
        });

        Ok(Self {
            v4_listener_task: listener_v4.map_or(None, |listener_v4| Some(executor.spawn(r.clone().listener_task_main(listener_v4)))),
            v6_listener_task: listener_v6.map_or(None, |listener_v6| Some(executor.spawn(r.clone().listener_task_main(listener_v6)))),
            service_task: executor.spawn(r.service_main()),
            _marker: PhantomData::default(),
            v4_listener_task: listener_v4.map_or(None, |listener_v4| Some(executor.spawn(r.clone().tcp_listener_task(listener_v4)))),
            v6_listener_task: listener_v6.map_or(None, |listener_v6| Some(executor.spawn(r.clone().tcp_listener_task(listener_v6)))),
            background_cleanup_task: executor.spawn(r.clone().background_cleanup_task()),
            _impl: r
        })
    }
}

@@ -144,47 +114,74 @@ unsafe impl<'ex> Send for Replicator<'ex> {}
unsafe impl<'ex> Sync for Replicator<'ex> {}

impl<'ex> ReplicatorImpl<'ex> {
    async fn service_main(self: Arc<ReplicatorImpl<'ex>>) {
        let mut timer = smol::Timer::interval(Duration::from_secs(5));
    async fn background_cleanup_task(self: Arc<ReplicatorImpl<'ex>>) {
        let mut timer = smol::Timer::interval(Duration::from_secs(CONNECTION_TIMEOUT_SECONDS / 10));
        loop {
            timer.next().await;
            let now_mt = ms_monotonic();
            self.announced_objects_requested.lock().await.retain(|_, ts| now_mt.saturating_sub(*ts) < (CONNECTION_TIMEOUT_SECONDS * 1000));
            self.connections.lock().await.retain(|_, c| (now_mt.saturating_sub(c.last_receive.load(Ordering::Relaxed))) < (CONNECTION_TIMEOUT_SECONDS * 1000));
        }
    }

    async fn listener_task_main(self: Arc<ReplicatorImpl<'ex>>, listener: TcpListener) {
        loop {
            let stream = listener.accept().await;
            if stream.is_ok() {
                let (stream, remote_address) = stream.unwrap();
                self.handle_new_connection(stream, remote_address, false).await;
            // Garbage collect the map used to track objects we've requested.
            self.announced_objects_requested.lock().await.retain(|_, ts| now_mt.saturating_sub(*ts) < (CONNECTION_TIMEOUT_SECONDS * 1000));

            let mut connections = self.connections.lock().await;

            // Close connections that haven't spoken in too long.
            connections.retain(|_, c| (now_mt.saturating_sub(c.last_receive.load(Ordering::Relaxed))) < (CONNECTION_TIMEOUT_SECONDS * 1000));
            let num_connections = connections.len();
            drop(connections); // release lock

            // Try to connect to more nodes if the count is below the target count.
            if num_connections < self.config.target_link_count {
                let new_link_seed = self.store.get_remote_endpoint();
                if new_link_seed.is_some() {
                    let new_link_seed = new_link_seed.unwrap();
                    let mut connections_in_progress = self.connections_in_progress.lock().await;
                    if !connections_in_progress.contains_key(&new_link_seed) {
                        let s2 = self.clone();
                        let _ = connections_in_progress.insert(new_link_seed.clone(), self.executor.spawn(async move {
                            let new_link = TcpStream::connect(&new_link_seed).await;
                            if new_link.is_ok() {
                                s2.handle_new_connection(new_link.unwrap(), new_link_seed, true).await;
                            } else {
                                let _task = s2.connections_in_progress.lock().await.remove(&new_link_seed);
                            }
                        }));
                    }
                }
            }
        }
    }

    async fn handle_new_connection(self: &Arc<ReplicatorImpl<'ex>>, mut stream: TcpStream, remote_address: SocketAddr, outgoing: bool) {
        stream.set_nodelay(true);
    async fn tcp_listener_task(self: Arc<ReplicatorImpl<'ex>>, listener: TcpListener) {
        loop {
            let stream = listener.accept().await;
            if stream.is_ok() {
                let (stream, remote_address) = stream.unwrap();
                let mut connections_in_progress = self.connections_in_progress.lock().await;
                if !connections_in_progress.contains_key(&remote_address) {
                    let s2 = self.clone();
                    let _ = connections_in_progress.insert(remote_address, self.executor.spawn(s2.handle_new_connection(stream, remote_address.clone(), false)));
                }
            }
        }
    }

        let mut loopback_check_code_salt = [0_u8; 8];
        getrandom(&mut loopback_check_code_salt).expect("getrandom failed");

        let mut h = Sha384::new();
        h.update(&loopback_check_code_salt);
        h.update(&self.loopback_check_code_secret);
        let loopback_check_code: [u8; 48] = h.finalize().as_ref().try_into().unwrap();
    async fn handle_new_connection(self: Arc<ReplicatorImpl<'ex>>, mut stream: TcpStream, remote_address: SocketAddr, outgoing: bool) {
        let _ = stream.set_nodelay(true);

        let mut loopback_check_code_salt = [0_u8; 16];
        random::fill_bytes_secure(&mut loopback_check_code_salt);
        let hello = protocol::Hello {
            hello_size: size_of::<protocol::Hello>() as u8,
            protocol_version: protocol::PROTOCOL_VERSION,
            hash_algorithm: protocol::HASH_ALGORITHM_SHA384,
            flags: [0_u8; 4],
            clock: ms_since_epoch().to_le_bytes(),
            data_set_size: self.store.total_size().to_le_bytes(),
            last_object_receive_time: self.store.last_object_receive_time().unwrap_or(u64::MAX).to_le_bytes(),
            domain_hash: self.domain_hash.clone(),
            instance_id: self.instance_id.clone(),
            loopback_check_code_salt,
            loopback_check_code: (&loopback_check_code[0..16]).try_into().unwrap()
            loopback_check_code: (&SHA384::hmac(&self.loopback_check_code_secret, &loopback_check_code_salt)[0..16]).try_into().unwrap(),
        };
        let hello: [u8; size_of::<protocol::Hello>()] = unsafe { transmute(hello) };

@@ -194,48 +191,50 @@ impl<'ex> ReplicatorImpl<'ex> {
        let hello: protocol::Hello = unsafe { transmute(hello_buf) };

        // Sanity check HELLO packet. In the future we may support different versions and sizes.
        if hello.hello_size == size_of::<protocol::Hello>() as u8 && hello.protocol_version == protocol::PROTOCOL_VERSION {
            // If this hash's first 16 bytes are equal to the one in the HELLO, this connection is
            // from this node and should be dropped.
            let mut h = Sha384::new();
            h.update(&hello.loopback_check_code_salt);
            h.update(&self.loopback_check_code_secret);
            let loopback_if_equal: [u8; 48] = h.finalize().as_ref().try_into().unwrap();

            if !loopback_if_equal[0..16].eq(&hello.loopback_check_code) {
        if hello.hello_size == size_of::<protocol::Hello>() as u8 && hello.protocol_version == protocol::PROTOCOL_VERSION && hello.hash_algorithm == protocol::HASH_ALGORITHM_SHA384 {
            if !SHA384::hmac(&self.loopback_check_code_secret, &hello.loopback_check_code_salt)[0..16].eq(&hello.loopback_check_code) {
                let k = ConnectionKey {
                    instance_id: hello.instance_id.clone(),
                    ip: remote_address.ip()
                    ip: remote_address.ip(),
                };
                let mut connections = self.connections.lock().await;
                let s2 = self.clone();
                let _ = connections.entry(k).or_insert_with(move || {
                    stream.set_nodelay(false);
                    let _ = stream.set_nodelay(false);

                    if outgoing {
                        s2.store.save_remote_endpoint(&remote_address);
                    }

                    let last_receive = Arc::new(AtomicU64::new(ms_monotonic()));
                    Connection {
                        remote_address,
                        last_receive: last_receive.clone(),
                        task: self.executor.spawn(self.clone().connection_io_task_main(stream, hello.instance_id, last_receive))
                        task: s2.executor.spawn(s2.clone().connection_io_task(stream, hello.instance_id, last_receive)),
                    }
                });
            }
        }
    }

        let _task = self.connections_in_progress.lock().await.remove(&remote_address);
    }

    async fn connection_io_task_main(self: Arc<ReplicatorImpl<'ex>>, stream: TcpStream, remote_instance_id: [u8; 16], last_receive: Arc<AtomicU64>) {
    async fn connection_io_task(self: Arc<ReplicatorImpl<'ex>>, stream: TcpStream, remote_instance_id: [u8; 16], last_receive: Arc<AtomicU64>) {
        let mut reader = BufReader::with_capacity(65536, stream.clone());
        let writer = Arc::new(Mutex::new(stream));

        let writer2 = writer.clone();
        let _sync_search_init_task = self.executor.spawn(async {
            let writer = writer2;
        //let writer2 = writer.clone();
        let _sync_search_init_task = self.executor.spawn(async move {
            //let writer = writer2;
            let mut periodic_timer = Timer::interval(Duration::from_secs(1));
            loop {
                let _ = periodic_timer.next().await;
            }
        });

        let mut get_buffer = Vec::new();
        let mut tmp_mem = Vec::new();
        tmp_mem.resize(self.config.max_object_size, 0);
        let tmp = tmp_mem.as_mut_slice();

@@ -248,7 +247,7 @@ impl<'ex> ReplicatorImpl<'ex> {
            last_receive.store(ms_monotonic(), Ordering::Relaxed);

            match message_type {
                protocol::MESSAGE_TYPE_NOP => {},
                protocol::MESSAGE_TYPE_NOP => {}

                protocol::MESSAGE_TYPE_HAVE_NEW_OBJECT => {
                    if reader.read_exact(&mut tmp[0..IDENTITY_HASH_SIZE]).await.is_err() {

@@ -256,7 +255,7 @@ impl<'ex> ReplicatorImpl<'ex> {
                    }
                    let identity_hash: [u8; 48] = (&tmp[0..IDENTITY_HASH_SIZE]).try_into().unwrap();
                    let mut announced_objects_requested = self.announced_objects_requested.lock().await;
                    if !announced_objects_requested.contains_key(&identity_hash) && !self.store.have(&identity_hash) {
                    if !announced_objects_requested.contains_key(&identity_hash) && !self.store.have(ms_since_epoch(), &identity_hash) {
                        announced_objects_requested.insert(identity_hash.clone(), ms_monotonic());
                        drop(announced_objects_requested); // release mutex

@@ -285,23 +284,30 @@ impl<'ex> ReplicatorImpl<'ex> {
                        break 'main_io_loop;
                    }

                    let identity_hash: [u8; 48] = Sha384::digest(object.as_ref()).as_ref().try_into().unwrap();
                    match self.store.put(&identity_hash, object) {
                        StoreObjectResult::Invalid => {
                    let identity_hash: [u8; 48] = SHA384::hash(object);
                    match self.store.put(ms_since_epoch(), &identity_hash, object) {
                        StorePutResult::Invalid => {
                            break 'main_io_loop;
                        },
                        StoreObjectResult::Ok | StoreObjectResult::Duplicate => {
                        }
                        StorePutResult::Ok | StorePutResult::Duplicate => {
                            if self.announced_objects_requested.lock().await.remove(&identity_hash).is_some() {
                                // TODO: propagate rumor if we requested this object in response to a HAVE message.
                            }
                        },
                        }
                        _ => {
                            let _ = self.announced_objects_requested.lock().await.remove(&identity_hash);
                        }
                    }
                },
                }

                protocol::MESSAGE_TYPE_GET_OBJECTS => {
                    // Get the reference time for this query.
                    let reference_time = varint::async_read(&mut reader).await;
                    if reference_time.is_err() {
                        break 'main_io_loop;
                    }
                    let reference_time = reference_time.unwrap();

                    // Read common prefix if the requester is requesting a set of hashes with the same beginning.
                    // A common prefix length of zero means they're requesting by full hash.
                    if reader.read_exact(&mut tmp[0..1]).await.is_err() {

@@ -328,20 +334,17 @@ impl<'ex> ReplicatorImpl<'ex> {
                        break 'main_io_loop;
                    }
                    let identity_hash: [u8; IDENTITY_HASH_SIZE] = (&tmp[0..IDENTITY_HASH_SIZE]).try_into().unwrap();
                    let object = self.store.get(&identity_hash);
                    if object.is_some() {
                        let object2 = object.unwrap();
                        let object = object2.as_slice();
                    if self.store.get(reference_time, &identity_hash, &mut get_buffer) {
                        let mut w = writer.lock().await;
                        if varint::async_write(&mut *w, object.len() as u64).await.is_err() {
                        if varint::async_write(&mut *w, get_buffer.len() as u64).await.is_err() {
                            break 'main_io_loop;
                        }
                        if w.write_all(object).await.is_err() {
                        if w.write_all(get_buffer.as_slice()).await.is_err() {
                            break 'main_io_loop;
                        }
                    }
                },
                }

                _ => {
                    break 'main_io_loop;
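
The GET_OBJECTS handler above leans on the crate's varint module, which this diff references but does not include. Assuming the common LEB128-style encoding (7 bits per byte, high bit as a continuation flag, consistent with the Hello comment that sizes below 0x80 fit in one byte), a synchronous sketch would look like the following; the function names are illustrative, not the module's actual API:

// Sketch of a LEB128-style unsigned varint (an assumption about crate::varint's
// wire format, not taken from this diff). Values below 0x80 encode as one byte.
fn varint_write(mut v: u64, out: &mut Vec<u8>) {
    loop {
        let b = (v & 0x7f) as u8;
        v >>= 7;
        if v == 0 {
            out.push(b); // final byte: high bit clear
            break;
        }
        out.push(b | 0x80); // continuation: high bit set
    }
}

fn varint_read(data: &[u8]) -> Option<(u64, usize)> {
    let mut v = 0_u64;
    let mut shift = 0_u32;
    for (i, &b) in data.iter().enumerate() {
        v |= ((b & 0x7f) as u64) << shift;
        if b & 0x80 == 0 {
            return Some((v, i + 1)); // value plus bytes consumed
        }
        shift += 7;
        if shift >= 64 {
            return None; // malformed: too long for a u64
        }
    }
    None // ran out of input mid-varint
}
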
allthethings/src/store.rs

@@ -10,8 +10,10 @@ use smol::net::SocketAddr;

use crate::IDENTITY_HASH_SIZE;

/// Result code from the put() method in Database.
pub enum StoreObjectResult {
pub const MIN_IDENTITY_HASH: [u8; 48] = [0_u8; 48];
pub const MAX_IDENTITY_HASH: [u8; 48] = [0xff_u8; 48];

pub enum StorePutResult {
    /// Datum stored successfully.
    Ok,
    /// Datum is one we already have.

@@ -22,23 +24,28 @@ pub enum StoreObjectResult {
    Ignored,
}

/// Trait that must be implemented for the data store that is to be replicated.
/// Trait that must be implemented by the data store that is to be replicated.
pub trait Store: Sync + Send {
    /// Get the total size of this data set in objects.
    fn total_size(&self) -> u64;

    /// Get an object from the database.
    fn get(&self, identity_hash: &[u8; IDENTITY_HASH_SIZE]) -> Option<Vec<u8>>;
    /// Get an object from the database, storing it in the supplied buffer.
    /// A return of 'false' leaves the buffer state undefined. If the return is true any previous
    /// data in the supplied buffer will have been cleared and replaced with the retrieved object.
    fn get(&self, reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE], buffer: &mut Vec<u8>) -> bool;

    /// Store an entry in the database.
    fn put(&self, identity_hash: &[u8; IDENTITY_HASH_SIZE], object: &[u8]) -> StoreObjectResult;
    fn put(&self, reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE], object: &[u8]) -> StorePutResult;

    /// Check if we have an object by its identity hash.
    fn have(&self, identity_hash: &[u8; IDENTITY_HASH_SIZE]) -> bool;
    fn have(&self, reference_time: u64, identity_hash: &[u8; IDENTITY_HASH_SIZE]) -> bool;

    /// Get the total count of objects.
    fn total_count(&self, reference_time: u64) -> Option<u64>;

    /// Get the time the last object was received in milliseconds since epoch.
    fn last_object_receive_time(&self) -> Option<u64>;

    /// Count the number of identity hash keys in this range (inclusive) of identity hashes.
    /// This may return None if an error occurs, but should return 0 if the set is empty.
    fn count(&self, start: &[u8; IDENTITY_HASH_SIZE], end: &[u8; IDENTITY_HASH_SIZE]) -> Option<u64>;
    fn count(&self, reference_time: u64, start: &[u8; IDENTITY_HASH_SIZE], end: &[u8; IDENTITY_HASH_SIZE]) -> Option<u64>;

    /// Called when a connection to a remote node was successful.
    /// This is always called on successful outbound connect.

zerotier-core-crypto/src/random.rs

@@ -59,3 +59,18 @@ pub fn next_u64_secure() -> u64 {

#[inline(always)]
pub fn fill_bytes_secure(dest: &mut [u8]) { randomize(Level::Strong, dest); }

static mut XORSHIFT64_STATE: u64 = 0;

/// Get a non-cryptographic random number.
pub fn xorshift64_random() -> u64 {
    let mut x = unsafe { XORSHIFT64_STATE };
    while x == 0 {
        x = next_u64_secure();
    }
    x ^= x.wrapping_shl(13);
    x ^= x.wrapping_shr(7);
    x ^= x.wrapping_shl(17);
    unsafe { XORSHIFT64_STATE = x };
    x
}