Mirror of https://github.com/zerotier/ZeroTierOne.git (synced 2025-04-26 08:57:26 +02:00)

Commit 1913d956b3 ("sync"), parent 2158675fd2
8 changed files with 735 additions and 486 deletions

@ -6,12 +6,24 @@
 * https://www.zerotier.com/
 */

use std::ops::Bound::Included;
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use crate::ms_since_epoch;

/// Generate a range of SHA512 hashes from a prefix and a number of bits.
/// The range will be inclusive and cover all keys under the prefix.
pub fn range_from_prefix(prefix: &[u8], prefix_bits: usize) -> ([u8; 64], [u8; 64]) {
    let prefix_bits = prefix_bits.min(prefix.len() * 8).min(64);
    let mut start = [0_u8; 64];
    let mut end = [0xff_u8; 64];
    let whole_bytes = prefix_bits / 8;
    let remaining_bits = prefix_bits % 8;
    start[0..whole_bytes].copy_from_slice(&prefix[0..whole_bytes]);
    end[0..whole_bytes].copy_from_slice(&prefix[0..whole_bytes]);
    if remaining_bits != 0 && whole_bytes < prefix.len() {
        start[whole_bytes] |= prefix[whole_bytes];
        end[whole_bytes] &= prefix[whole_bytes] | ((0xff_u8).wrapping_shr(remaining_bits as u32));
    }
    (start, end)
}
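
// Added illustration (not part of this commit): a quick sanity check of
// range_from_prefix(). With prefix [0xab, 0xc0] and 12 prefix bits, the
// inclusive range spans keys starting 0xab 0xc0 .. through 0xab 0xcf ...
#[cfg(test)]
mod range_from_prefix_example {
    use super::range_from_prefix;

    #[test]
    fn twelve_bit_prefix() {
        let (start, end) = range_from_prefix(&[0xab, 0xc0], 12);
        assert_eq!(start[0..2], [0xab, 0xc0]);
        assert_eq!(end[0..2], [0xab, 0xcf]);
        assert!(start[2..].iter().all(|b| *b == 0x00));
        assert!(end[2..].iter().all(|b| *b == 0xff));
    }
}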

/// Result returned by DB::load().
/// Result returned by DataStore::load().
pub enum LoadResult<V: AsRef<[u8]> + Send> {
    /// Object was found.
    Ok(V),

@ -23,10 +35,12 @@ pub enum LoadResult<V: AsRef<[u8]> + Send> {
    TimeNotAvailable
}

/// Result returned by DB::store().
/// Result returned by DataStore::store().
pub enum StoreResult {
    /// Entry was accepted (whether or not an old value was replaced).
    Ok,
    /// Entry was accepted.
    /// The integer included with Ok is the reference time that should be advertised.
    /// If this is not a temporally subjective data set then zero can be used.
    Ok(i64),

    /// Entry was a duplicate of one we already have but was otherwise valid.
    Duplicate,

@ -40,19 +54,16 @@ pub enum StoreResult {

/// API to be implemented by the data set we want to replicate.
///
/// Keys as understood by syncwhole are SHA512 hashes of values. The user can of course
/// have their own concept of a "key" separate from this, but that would not be used
/// for data set replication. Replication is content identity based.
///
/// The API specified here supports temporally subjective data sets. These are data sets
/// where the existence or non-existence of a record may depend on the (real world) time.
/// A parameter for reference time allows a remote querying node to send its own "this is
/// what time I think it is" value to be considered locally so that data can be replicated
/// as of any given time.
///
/// The constants KEY_SIZE, MAX_VALUE_SIZE, and KEY_IS_COMPUTED are protocol constants
/// for your replication domain. They can't be changed once defined unless all nodes
/// are upgraded at once.
///
/// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a function of
/// values. If this is true, key_from_value() must be implemented.
///
/// The implementation must be thread safe and may be called concurrently.
pub trait DataStore: Sync + Send {
    /// Type to be enclosed in the Ok() enum value in LoadResult.

@ -61,32 +72,13 @@ pub trait DataStore: Sync + Send {
|
|||
/// your implementation. Examples include Box<[u8]>, Arc<[u8]>, Vec<u8>, etc.
|
||||
type LoadResultValueType: AsRef<[u8]> + Send;
|
||||
|
||||
/// Size of keys, which must be fixed in length. These are typically hashes.
|
||||
const KEY_SIZE: usize;
|
||||
/// Key hash size, always 64 for SHA512.
|
||||
const KEY_SIZE: usize = 64;
|
||||
|
||||
/// Maximum size of a value in bytes.
|
||||
const MAX_VALUE_SIZE: usize;
|
||||
|
||||
/// This should be true if the key is computed, such as by hashing the value.
|
||||
///
|
||||
/// If this is true then keys do not have to be sent over the wire. Instead they
|
||||
/// are computed by calling get_key(). If this is false keys are assumed not to
|
||||
/// be computable from values and are explicitly sent.
|
||||
const KEY_IS_COMPUTED: bool;
|
||||
|
||||
/// Compute the key corresponding to a value.
|
||||
///
|
||||
/// If KEY_IS_COMPUTED is true this must be implemented. The default implementation
|
||||
/// panics to indicate this. If KEY_IS_COMPUTED is false this is never called.
|
||||
#[allow(unused_variables)]
|
||||
fn key_from_value(&self, value: &[u8], key_buffer: &mut [u8]) {
|
||||
panic!("key_from_value() must be implemented if KEY_IS_COMPUTED is true");
|
||||
}
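// Illustrative note (not part of this commit): when KEY_IS_COMPUTED is true an
// implementation typically overrides this by hashing the value, e.g. using the
// sha2 crate as the test harness later in this diff does:
//
//     fn key_from_value(&self, value: &[u8], key_buffer: &mut [u8]) {
//         key_buffer.copy_from_slice(Sha512::digest(value).as_slice());
//     }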
|
||||
|
||||
/// Get the current wall clock in milliseconds since Unix epoch.
|
||||
///
|
||||
/// This is delegated to the data store to support scenarios where you want to fix
|
||||
/// the clock or snapshot at a given time.
|
||||
fn clock(&self) -> i64;
|
||||
|
||||
/// Get the domain of this data store.
|
||||
|
@ -99,144 +91,56 @@ pub trait DataStore: Sync + Send {
|
|||
/// Get an item if it exists as of a given reference time.
|
||||
fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType>;
|
||||
|
||||
/// Check whether this data store contains a key.
|
||||
///
|
||||
/// The default implementation just uses load(). Override if you can provide a faster
|
||||
/// version.
|
||||
fn contains(&self, reference_time: i64, key: &[u8]) -> bool {
|
||||
match self.load(reference_time, key) {
|
||||
LoadResult::Ok(_) => true,
|
||||
_ => false
|
||||
}
|
||||
}
|
||||
|
||||
/// Store an item in the data store and return its status.
|
||||
///
|
||||
/// Note that no time is supplied here. The data store must determine this in an implementation
|
||||
/// dependent manner if this is a temporally subjective data store. It could be determined by
|
||||
/// the wall clock, from the object itself, etc.
|
||||
///
|
||||
/// The implementation is responsible for validating inputs and returning 'Rejected' if they
|
||||
/// are invalid. A return value of Rejected can be used to do things like drop connections
|
||||
/// to peers that send invalid data, so it should only be returned if the data is malformed
|
||||
/// or something like a signature check fails. The 'Ignored' enum value should be returned
|
||||
/// for inputs that are valid but were not stored for some other reason, such as being
|
||||
/// expired. It's important to return 'Ok' for accepted values to hint to the replicator
|
||||
/// that they should be aggressively advertised to other peers.
|
||||
/// The key supplied here will always be the SHA512 hash of the value. There is no need to
|
||||
/// re-compute and check the key, but the value must be validated.
|
||||
///
|
||||
/// If KEY_IS_COMPUTED is true, the key supplied here can be assumed to be correct. It will
|
||||
/// have been computed via get_key().
|
||||
/// Validation of the value and returning the appropriate StoreResult is important to the
|
||||
/// operation of the synchronization algorithm:
|
||||
///
|
||||
/// StoreResult::Ok - Value was valid and was accepted and saved.
|
||||
///
|
||||
/// StoreResult::Duplicate - Value was valid but is a duplicate of one we already have.
|
||||
///
|
||||
/// StoreResult::Ignored - Value was valid but for some other reason was not saved.
|
||||
///
|
||||
/// StoreResult::Rejected - Value was not valid, causes link to peer to be dropped.
|
||||
///
|
||||
/// Rejected should only be returned if the value actually fails a validity check, signature
|
||||
/// verification, proof of work check, or some other required criteria. Ignored must be
|
||||
/// returned if the value is valid but is too old or was rejected for some other normal reason.
|
||||
fn store(&self, key: &[u8], value: &[u8]) -> StoreResult;
|
||||
|
||||
/// Get the number of items under a prefix as of a given reference time.
|
||||
///
|
||||
/// The default implementation uses for_each_key(). This can be specialized if it can
|
||||
/// be done more efficiently than that.
|
||||
fn count(&self, reference_time: i64, key_prefix: &[u8]) -> u64 {
|
||||
let mut cnt: u64 = 0;
|
||||
self.for_each_key(reference_time, key_prefix, |_| {
|
||||
cnt += 1;
|
||||
true
|
||||
});
|
||||
cnt
|
||||
}
|
||||
/// Get the number of items in a range.
|
||||
fn count(&self, reference_time: i64, key_range_start: &[u8], key_range_end: &[u8]) -> u64;
|
||||
|
||||
/// Get the total number of records in this data store.
|
||||
fn total_count(&self) -> u64;
|
||||
|
||||
/// Iterate through keys beneath a key prefix, stopping if the function returns false.
|
||||
/// Iterate through keys, stopping if the function returns false.
|
||||
///
|
||||
/// The default implementation uses for_each(). This can be specialized if it's faster to
|
||||
/// only load keys.
|
||||
fn for_each_key<F: FnMut(&[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], mut f: F) {
|
||||
self.for_each(reference_time, key_prefix, |k, _| f(k));
|
||||
fn for_each_key<F: FnMut(&[u8]) -> bool>(&self, reference_time: i64, key_range_start: &[u8], key_range_end: &[u8], mut f: F) {
|
||||
self.for_each(reference_time, key_range_start, key_range_end, |k, _| f(k));
|
||||
}
|
||||
|
||||
/// Iterate through keys and values beneath a key prefix, stopping if the function returns false.
|
||||
fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], f: F);
|
||||
}
|
||||
|
||||
/// A simple in-memory data store backed by a BTreeMap.
|
||||
pub struct MemoryDataStore<const KEY_SIZE: usize> {
|
||||
max_age: i64,
|
||||
domain: String,
|
||||
db: Mutex<BTreeMap<[u8; KEY_SIZE], (i64, Arc<[u8]>)>>
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> MemoryDataStore<KEY_SIZE> {
|
||||
pub fn new(max_age: i64, domain: String) -> Self {
|
||||
Self {
|
||||
max_age: if max_age > 0 { max_age } else { i64::MAX },
|
||||
domain,
|
||||
db: Mutex::new(BTreeMap::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> DataStore for MemoryDataStore<KEY_SIZE> {
|
||||
type LoadResultValueType = Arc<[u8]>;
|
||||
const KEY_SIZE: usize = KEY_SIZE;
|
||||
const MAX_VALUE_SIZE: usize = 65536;
|
||||
const KEY_IS_COMPUTED: bool = false;
|
||||
|
||||
fn clock(&self) -> i64 { ms_since_epoch() }
|
||||
|
||||
fn domain(&self) -> &str { self.domain.as_str() }
|
||||
|
||||
fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType> {
|
||||
let db = self.db.lock().unwrap();
|
||||
let e = db.get(key);
|
||||
if e.is_some() {
|
||||
let e = e.unwrap();
|
||||
if (reference_time - e.0) <= self.max_age {
|
||||
LoadResult::Ok(e.1.clone())
|
||||
} else {
|
||||
LoadResult::NotFound
|
||||
}
|
||||
} else {
|
||||
LoadResult::NotFound
|
||||
}
|
||||
}
|
||||
|
||||
fn store(&self, key: &[u8], value: &[u8]) -> StoreResult {
|
||||
let ts = crate::ms_since_epoch();
|
||||
let mut isdup = false;
|
||||
self.db.lock().unwrap().entry(key.try_into().unwrap()).and_modify(|e| {
|
||||
if e.1.as_ref().eq(value) {
|
||||
isdup = true;
|
||||
} else {
|
||||
*e = (ts, Arc::from(value));
|
||||
}
|
||||
}).or_insert_with(|| {
|
||||
(ts, Arc::from(value))
|
||||
});
|
||||
if isdup {
|
||||
StoreResult::Duplicate
|
||||
} else {
|
||||
StoreResult::Ok
|
||||
}
|
||||
}
|
||||
|
||||
fn total_count(&self) -> u64 { self.db.lock().unwrap().len() as u64 }
|
||||
|
||||
fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], mut f: F) {
|
||||
let mut r_start = [0_u8; KEY_SIZE];
|
||||
let mut r_end = [0xff_u8; KEY_SIZE];
|
||||
(&mut r_start[0..key_prefix.len()]).copy_from_slice(key_prefix);
|
||||
(&mut r_end[0..key_prefix.len()]).copy_from_slice(key_prefix);
|
||||
for (k, v) in self.db.lock().unwrap().range((Included(r_start), Included(r_end))) {
|
||||
if (reference_time - v.0) <= self.max_age {
|
||||
if !f(k, &v.1) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> PartialEq for MemoryDataStore<KEY_SIZE> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.max_age == other.max_age && self.domain == other.domain && self.db.lock().unwrap().eq(&*other.db.lock().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> Eq for MemoryDataStore<KEY_SIZE> {}
|
||||
|
||||
impl<const KEY_SIZE: usize> Clone for MemoryDataStore<KEY_SIZE> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
max_age: self.max_age,
|
||||
domain: self.domain.clone(),
|
||||
db: Mutex::new(self.db.lock().unwrap().clone())
|
||||
}
|
||||
}
|
||||
/// Iterate through keys and values, stopping if the function returns false.
|
||||
fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, reference_time: i64, key_range_start: &[u8], key_range_end: &[u8], f: F);
|
||||
}
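// Usage sketch (not part of this commit; `ds` is any DataStore implementation
// and `now` a reference time, both hypothetical here): the new range-based API
// pairs naturally with range_from_prefix() when walking a slice of the key space.
//
//     let (start, end) = range_from_prefix(&[0x5a], 8);
//     let records_under_prefix = ds.count(now, &start, &end);
//     ds.for_each_key(now, &start, &end, |key| {
//         assert_eq!(key.len(), 64); // keys are SHA512 hashes
//         true // return false to stop iterating early
//     });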
|
||||
|
|
|
@ -6,43 +6,12 @@
|
|||
* https://www.zerotier.com/
|
||||
*/
|
||||
|
||||
use std::alloc::{alloc_zeroed, dealloc, Layout};
|
||||
use std::mem::size_of;
|
||||
use std::ptr::write_bytes;
|
||||
use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut, write_bytes};
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use crate::varint;
|
||||
|
||||
#[inline(always)]
|
||||
fn xorshift64(mut x: u64) -> u64 {
|
||||
x = u64::from_le(x);
|
||||
x ^= x.wrapping_shl(13);
|
||||
x ^= x.wrapping_shr(7);
|
||||
x ^= x.wrapping_shl(17);
|
||||
x.to_le()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn splitmix64(mut x: u64) -> u64 {
|
||||
x = u64::from_le(x);
|
||||
x ^= x.wrapping_shr(30);
|
||||
x = x.wrapping_mul(0xbf58476d1ce4e5b9);
|
||||
x ^= x.wrapping_shr(27);
|
||||
x = x.wrapping_mul(0x94d049bb133111eb);
|
||||
x ^= x.wrapping_shr(31);
|
||||
x.to_le()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn splitmix64_inverse(mut x: u64) -> u64 {
|
||||
x = u64::from_le(x);
|
||||
x ^= x.wrapping_shr(31) ^ x.wrapping_shr(62);
|
||||
x = x.wrapping_mul(0x319642b2d24d8ec3);
|
||||
x ^= x.wrapping_shr(27) ^ x.wrapping_shr(54);
|
||||
x = x.wrapping_mul(0x96de1b173f119089);
|
||||
x ^= x.wrapping_shr(30) ^ x.wrapping_shr(60);
|
||||
x.to_le()
|
||||
}
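// Illustration (not part of this commit): splitmix64_inverse() undoes
// splitmix64(), which is easy to sanity check with a round trip:
//
//     for x in [0u64, 1, 0xdead_beef_cafe_f00d, u64::MAX] {
//         assert_eq!(splitmix64_inverse(splitmix64(x)), x);
//     }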
|
||||
use crate::utils::*;
|
||||
|
||||
#[inline(always)]
|
||||
fn next_iteration_index(prev_iteration_index: u64) -> u64 {
|
||||
|
@ -50,16 +19,17 @@ fn next_iteration_index(prev_iteration_index: u64) -> u64 {
|
|||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
#[repr(C, packed)]
|
||||
struct IBLTEntry {
|
||||
key_sum: u64,
|
||||
check_hash_sum: u64,
|
||||
count: i64
|
||||
count: i32
|
||||
}
|
||||
|
||||
impl IBLTEntry {
|
||||
#[inline(always)]
|
||||
fn is_singular(&self) -> bool {
|
||||
if self.count == 1 || self.count == -1 {
|
||||
if i32::from_le(self.count) == 1 || i32::from_le(self.count) == -1 {
|
||||
xorshift64(self.key_sum) == self.check_hash_sum
|
||||
} else {
|
||||
false
|
||||
|
@ -68,84 +38,102 @@ impl IBLTEntry {
|
|||
}

/// An Invertible Bloom Lookup Table for set reconciliation with 64-bit hashes.
///
/// Usage inspired by this paper:
///
/// https://dash.harvard.edu/bitstream/handle/1/14398536/GENTILI-SENIORTHESIS-2015.pdf
#[derive(Clone, PartialEq, Eq)]
pub struct IBLT<const B: usize> {
    map: *mut [IBLTEntry; B]
pub struct IBLT {
    map: Vec<IBLTEntry>
}
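
// Usage sketch (not part of this commit; assumes the insert() wrapper elided
// from this hunk, which calls ins_rem(key, 1)): each peer fills an IBLT with
// its 64-bit key hashes, one side subtracts the other's, and list() then
// yields keys present in only one of the two sets.
//
//     let mut local = IBLT::new(1024);
//     let mut remote = IBLT::new(1024); // must use the same bucket count
//     // ... each side inserts its own key hashes ...
//     assert!(local.subtract(&remote)); // false if bucket counts differ
//     local.list(|key| {
//         // key: &[u8; 8] belonging to exactly one side of the difference
//     });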
|
||||
|
||||
impl<const B: usize> IBLT<B> {
|
||||
/// Number of buckets (capacity) of this IBLT.
|
||||
pub const BUCKETS: usize = B;
|
||||
|
||||
impl IBLT {
|
||||
/// This was determined to be effective via empirical testing with random keys. This
|
||||
/// is a protocol constant that can't be changed without upgrading all nodes in a domain.
|
||||
const KEY_MAPPING_ITERATIONS: usize = 2;
|
||||
|
||||
pub fn new() -> Self {
|
||||
assert!(B < u32::MAX as usize); // sanity check
|
||||
Self {
|
||||
map: unsafe { alloc_zeroed(Layout::new::<[IBLTEntry; B]>()).cast() }
|
||||
pub fn new(buckets: usize) -> Self {
|
||||
assert!(buckets < i32::MAX as usize);
|
||||
assert_eq!(size_of::<IBLTEntry>(), 20);
|
||||
let mut iblt = Self { map: Vec::with_capacity(buckets) };
|
||||
unsafe {
|
||||
iblt.map.set_len(buckets);
|
||||
iblt.reset();
|
||||
}
|
||||
iblt
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
unsafe { write_bytes(self.map.cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) };
|
||||
}
|
||||
/// Compute the size in bytes of an IBLT with the given number of buckets.
|
||||
#[inline(always)]
|
||||
pub fn size_bytes_with_buckets(buckets: usize) -> usize { buckets * size_of::<IBLTEntry>() }
|
||||
|
||||
pub async fn read<R: AsyncReadExt + Unpin>(&mut self, r: &mut R) -> std::io::Result<()> {
|
||||
let mut prev_c = 0_i64;
|
||||
for b in unsafe { (*self.map).iter_mut() } {
|
||||
let _ = r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() }).await?;
|
||||
let _ = r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() }).await?;
|
||||
let mut c = varint::read_async(r).await?.0 as i64;
|
||||
if (c & 1) == 0 {
|
||||
c = c.wrapping_shr(1);
|
||||
} else {
|
||||
c = -c.wrapping_shr(1);
|
||||
}
|
||||
b.count = c + prev_c;
|
||||
prev_c = b.count;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
|
||||
let mut prev_c = 0_i64;
|
||||
for b in unsafe { (*self.map).iter() } {
|
||||
let _ = w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() }).await?;
|
||||
let _ = w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() }).await?;
|
||||
let mut c = (b.count - prev_c).wrapping_shl(1);
|
||||
prev_c = b.count;
|
||||
if c < 0 {
|
||||
c = -c | 1;
|
||||
}
|
||||
let _ = varint::write_async(w, c as u64).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
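// Encoding example (illustrative, matching the delta + zigzag scheme used by
// read()/write() above): bucket counts [3, 3, 0, -2] are delta-coded against
// the previous count to [3, 0, -3, -2], then mapped so the sign lands in the
// low bit (3 -> 6, 0 -> 0, -3 -> 7, -2 -> 5) and written as varints; the
// reader reverses both steps to recover the original counts.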
|
||||
|
||||
/// Get this IBLT as a byte array.
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
let mut out = std::io::Cursor::new(Vec::<u8>::with_capacity(B * 20));
|
||||
let current = tokio::runtime::Handle::try_current();
|
||||
if current.is_ok() {
|
||||
assert!(current.unwrap().block_on(self.write(&mut out)).is_ok());
|
||||
/// Compute the IBLT size in buckets to reconcile a given set difference, or return 0 if no advantage.
|
||||
/// This returns zero if an IBLT would take up as much or more space than just sending local_set_size
|
||||
/// hashes of hash_size_bytes.
|
||||
pub fn calc_iblt_parameters(hash_size_bytes: usize, local_set_size: u64, difference_size: u64) -> usize {
|
||||
let hashes_would_be = (hash_size_bytes as f64) * (local_set_size as f64);
|
||||
let buckets_should_be = (difference_size as f64) * 1.8; // factor determined experimentally for best bytes/item, can be tuned
|
||||
let iblt_would_be = buckets_should_be * (size_of::<IBLTEntry>() as f64);
|
||||
if iblt_would_be < hashes_would_be {
|
||||
buckets_should_be.ceil() as usize
|
||||
} else {
|
||||
assert!(tokio::runtime::Builder::new_current_thread().build().unwrap().block_on(self.write(&mut out)).is_ok());
|
||||
0
|
||||
}
|
||||
out.into_inner()
|
||||
}
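// Worked example (illustrative numbers, not from this commit): with 16-byte
// announced hashes, a local set of 10_000 records and an estimated difference
// of 500, plain hashes would cost 16 * 10_000 = 160_000 bytes while an IBLT of
// 500 * 1.8 = 900 buckets costs 900 * 20 = 18_000 bytes, so
// calc_iblt_parameters(16, 10_000, 500) returns 900; if the IBLT would be as
// large or larger than the hash list it returns 0 instead.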
|
||||
|
||||
fn ins_rem(&mut self, mut key: u64, delta: i64) {
|
||||
/// Get the size of this IBLT in buckets.
|
||||
#[inline(always)]
|
||||
pub fn buckets(&self) -> usize { self.map.len() }
|
||||
|
||||
/// Get the size of this IBLT in bytes.
|
||||
pub fn size_bytes(&self) -> usize { self.map.len() * size_of::<IBLTEntry>() }
|
||||
|
||||
/// Zero this IBLT.
|
||||
#[inline(always)]
|
||||
pub fn reset(&mut self) {
|
||||
unsafe { write_bytes(self.map.as_mut_ptr().cast::<u8>(), 0, self.map.len() * size_of::<IBLTEntry>()); }
|
||||
}
|
||||
|
||||
/// Get this IBLT as a byte slice in place.
|
||||
#[inline(always)]
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
unsafe { &*slice_from_raw_parts(self.map.as_ptr().cast::<u8>(), self.map.len() * size_of::<IBLTEntry>()) }
|
||||
}
|
||||
|
||||
/// Construct an IBLT from an input reader and a size in bytes.
|
||||
pub async fn new_from_reader<R: AsyncReadExt + Unpin>(r: &mut R, bytes: usize) -> std::io::Result<Self> {
|
||||
assert_eq!(size_of::<IBLTEntry>(), 20);
|
||||
if (bytes % size_of::<IBLTEntry>()) != 0 {
|
||||
Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "incomplete or invalid IBLT"))
|
||||
} else {
|
||||
let buckets = bytes / size_of::<IBLTEntry>();
|
||||
let mut iblt = Self { map: Vec::with_capacity(buckets) };
|
||||
unsafe {
|
||||
iblt.map.set_len(buckets);
|
||||
r.read_exact(&mut *slice_from_raw_parts_mut(iblt.map.as_mut_ptr().cast::<u8>(), bytes)).await?;
|
||||
}
|
||||
Ok(iblt)
|
||||
}
|
||||
}
|
||||
|
||||
/// Write this IBLT to a stream.
|
||||
/// Note that the size of the IBLT in bytes must be stored separately. Use size_bytes() to get that.
|
||||
#[inline(always)]
|
||||
pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
|
||||
w.write_all(self.as_bytes()).await
|
||||
}
|
||||
|
||||
fn ins_rem(&mut self, mut key: u64, delta: i32) {
|
||||
key = splitmix64(key);
|
||||
let check_hash = xorshift64(key);
|
||||
let mut iteration_index = u64::from_le(key);
|
||||
let buckets = self.map.len();
|
||||
for _ in 0..Self::KEY_MAPPING_ITERATIONS {
|
||||
iteration_index = next_iteration_index(iteration_index);
|
||||
let b = unsafe { (*self.map).get_unchecked_mut((iteration_index as usize) % B) };
|
||||
let b = unsafe { self.map.get_unchecked_mut((iteration_index as usize) % buckets) };
|
||||
b.key_sum ^= key;
|
||||
b.check_hash_sum ^= check_hash;
|
||||
b.count += delta;
|
||||
b.count = (i32::from_le(b.count) + delta).to_le();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -166,27 +154,35 @@ impl<const B: usize> IBLT<B> {
|
|||
}
|
||||
|
||||
/// Subtract another IBLT from this one to get a set difference.
|
||||
pub fn subtract(&mut self, other: &Self) {
|
||||
for (s, o) in unsafe { (*self.map).iter_mut().zip((*other.map).iter()) } {
|
||||
s.key_sum ^= o.key_sum;
|
||||
s.check_hash_sum ^= o.check_hash_sum;
|
||||
s.count -= o.count;
|
||||
///
|
||||
/// This returns true on success or false on error, which right now can only happen if the
/// other IBLT has a different number of buckets than this one.
|
||||
pub fn subtract(&mut self, other: &Self) -> bool {
|
||||
if other.map.len() == self.map.len() {
|
||||
for (s, o) in self.map.iter_mut().zip(other.map.iter()) {
|
||||
s.key_sum ^= o.key_sum;
|
||||
s.check_hash_sum ^= o.check_hash_sum;
|
||||
s.count = (i32::from_le(s.count) - i32::from_le(o.count)).to_le();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
pub fn list<F: FnMut(&[u8; 8])>(self, mut f: F) -> bool {
|
||||
let mut queue: Vec<u32> = Vec::with_capacity(B);
|
||||
/// List as many entries in this IBLT as can be extracted.
|
||||
pub fn list<F: FnMut(&[u8; 8])>(mut self, mut f: F) {
|
||||
let mut queue: Vec<u32> = Vec::with_capacity(self.map.len());
|
||||
|
||||
for b in 0..B {
|
||||
if unsafe { (*self.map).get_unchecked(b).is_singular() } {
|
||||
queue.push(b as u32);
|
||||
for bi in 0..self.map.len() {
|
||||
if unsafe { self.map.get_unchecked(bi).is_singular() } {
|
||||
queue.push(bi as u32);
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
let b = queue.pop();
|
||||
let b = if b.is_some() {
|
||||
unsafe { (*self.map).get_unchecked_mut(b.unwrap() as usize) }
|
||||
unsafe { self.map.get_unchecked_mut(b.unwrap() as usize) }
|
||||
} else {
|
||||
break;
|
||||
};
|
||||
|
@ -199,29 +195,21 @@ impl<const B: usize> IBLT<B> {
|
|||
|
||||
for _ in 0..Self::KEY_MAPPING_ITERATIONS {
|
||||
iteration_index = next_iteration_index(iteration_index);
|
||||
let b_idx = iteration_index % (B as u64);
|
||||
let b = unsafe { (*self.map).get_unchecked_mut(b_idx as usize) };
|
||||
let b_idx = iteration_index % (self.map.len() as u64);
|
||||
let b = unsafe { self.map.get_unchecked_mut(b_idx as usize) };
|
||||
b.key_sum ^= key;
|
||||
b.check_hash_sum ^= check_hash;
|
||||
b.count -= 1;
|
||||
b.count = (i32::from_le(b.count) - 1).to_le();
|
||||
|
||||
if b.is_singular() {
|
||||
if queue.len() >= (B * 2) { // sanity check for invalid IBLT
|
||||
return false;
|
||||
if queue.len() > self.map.len() { // sanity check for invalid IBLT
|
||||
return;
|
||||
}
|
||||
queue.push(b_idx as u32);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
impl<const B: usize> Drop for IBLT<B> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { dealloc(self.map.cast(), Layout::new::<[IBLTEntry; B]>()) };
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -229,6 +217,7 @@ impl<const B: usize> Drop for IBLT<B> {
|
|||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use crate::iblt::*;
|
||||
|
||||
#[test]
|
||||
|
@ -245,7 +234,7 @@ mod tests {
|
|||
let mut count = 64;
|
||||
const CAPACITY: usize = 4096;
|
||||
while count <= CAPACITY {
|
||||
let mut test: IBLT<CAPACITY> = IBLT::new();
|
||||
let mut test: IBLT = IBLT::new(CAPACITY);
|
||||
expected.clear();
|
||||
|
||||
for _ in 0..count {
|
||||
|
@ -277,8 +266,8 @@ mod tests {
|
|||
let mut missing: HashSet<u64> = HashSet::with_capacity(CAPACITY);
|
||||
while missing_count <= CAPACITY {
|
||||
missing.clear();
|
||||
let mut local: IBLT<CAPACITY> = IBLT::new();
|
||||
let mut remote: IBLT<CAPACITY> = IBLT::new();
|
||||
let mut local: IBLT = IBLT::new(CAPACITY);
|
||||
let mut remote: IBLT = IBLT::new(CAPACITY);
|
||||
|
||||
for k in 0..REMOTE_SIZE {
|
||||
if k >= missing_count {
|
||||
|
@ -291,7 +280,7 @@ mod tests {
|
|||
}
|
||||
|
||||
local.subtract(&mut remote);
|
||||
let bytes = local.to_bytes().len();
|
||||
let bytes = local.as_bytes().len();
|
||||
let mut cnt = 0;
|
||||
local.list(|k| {
|
||||
let k = u64::from_ne_bytes(*k);
|
||||
|
|
|
@ -13,11 +13,4 @@ pub(crate) mod iblt;
|
|||
pub mod datastore;
|
||||
pub mod node;
|
||||
pub mod host;
|
||||
|
||||
pub fn ms_since_epoch() -> i64 {
|
||||
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as i64
|
||||
}
|
||||
|
||||
pub fn ms_monotonic() -> i64 {
|
||||
std::time::Instant::now().elapsed().as_millis() as i64
|
||||
}
|
||||
pub mod utils;
|
||||
|
|
|
@ -13,6 +13,7 @@ use syncwhole::datastore::{DataStore, LoadResult, StoreResult};
|
|||
use syncwhole::host::Host;
|
||||
use syncwhole::ms_since_epoch;
|
||||
use syncwhole::node::{Node, RemoteNodeInfo};
|
||||
use syncwhole::utils::*;
|
||||
|
||||
const TEST_NODE_COUNT: usize = 16;
|
||||
const TEST_PORT_RANGE_START: u16 = 21384;
|
||||
|
@ -55,13 +56,7 @@ impl Host for TestNodeHost {
|
|||
impl DataStore for TestNodeHost {
|
||||
type LoadResultValueType = Arc<[u8]>;
|
||||
|
||||
const KEY_SIZE: usize = 64;
|
||||
const MAX_VALUE_SIZE: usize = 1024;
|
||||
const KEY_IS_COMPUTED: bool = true;
|
||||
|
||||
fn key_from_value(&self, value: &[u8], key_buffer: &mut [u8]) {
|
||||
key_buffer.copy_from_slice(Sha512::digest(value).as_slice());
|
||||
}
|
||||
|
||||
fn clock(&self) -> i64 { ms_since_epoch() }
|
||||
|
||||
|
@ -73,7 +68,7 @@ impl DataStore for TestNodeHost {
|
|||
|
||||
fn store(&self, key: &[u8], value: &[u8]) -> StoreResult {
|
||||
assert_eq!(key.len(), 64);
|
||||
let mut res = StoreResult::Ok;
|
||||
let mut res = StoreResult::Ok(0);
|
||||
self.db.lock().unwrap().entry(key.try_into().unwrap()).and_modify(|e| {
|
||||
if e.as_ref().eq(value) {
|
||||
res = StoreResult::Duplicate;
|
||||
|
@ -86,14 +81,14 @@ impl DataStore for TestNodeHost {
|
|||
res
|
||||
}
|
||||
|
||||
fn count(&self, _: i64, key_range_start: &[u8], key_range_end: &[u8]) -> u64 {
|
||||
self.db.lock().unwrap().range((Included(key_range_start.try_into().unwrap()), Included(key_range_end.try_into().unwrap()))).count() as u64
|
||||
}
|
||||
|
||||
fn total_count(&self) -> u64 { self.db.lock().unwrap().len() as u64 }
|
||||
|
||||
fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, _: i64, key_prefix: &[u8], mut f: F) {
|
||||
let mut r_start = [0_u8; Self::KEY_SIZE];
|
||||
let mut r_end = [0xff_u8; Self::KEY_SIZE];
|
||||
(&mut r_start[0..key_prefix.len()]).copy_from_slice(key_prefix);
|
||||
(&mut r_end[0..key_prefix.len()]).copy_from_slice(key_prefix);
|
||||
for (k, v) in self.db.lock().unwrap().range((Included(r_start), Included(r_end))) {
|
||||
fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, _: i64, key_range_start: &[u8], key_range_end: &[u8], mut f: F) {
|
||||
for (k, v) in self.db.lock().unwrap().range((Included(key_range_start.try_into().unwrap()), Included(key_range_end.try_into().unwrap()))) {
|
||||
if !f(k, v.as_ref()) {
|
||||
break;
|
||||
}
|
||||
|
@ -126,11 +121,13 @@ fn main() {
|
|||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
/*
|
||||
let mut count = 0;
|
||||
for n in nodes.iter() {
|
||||
count += n.connection_count().await;
|
||||
}
|
||||
println!("{}", count);
|
||||
*/
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
@ -13,7 +13,6 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
|||
use std::ops::Add;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering};
|
||||
use std::time::SystemTime;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
|
||||
|
@ -21,12 +20,13 @@ use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
|||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::{Instant, Duration};
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
use crate::datastore::DataStore;
|
||||
use crate::datastore::*;
|
||||
use crate::host::Host;
|
||||
use crate::ms_monotonic;
|
||||
use crate::iblt::IBLT;
|
||||
use crate::protocol::*;
|
||||
use crate::utils::*;
|
||||
use crate::varint;
|
||||
|
||||
/// Inactivity timeout for connections in milliseconds.
|
||||
|
@ -56,8 +56,11 @@ pub struct RemoteNodeInfo {
|
|||
/// Explicitly advertised remote addresses supplied by remote node (not necessarily verified).
|
||||
pub explicit_addresses: Vec<SocketAddr>,
|
||||
|
||||
/// Time TCP connection was established.
|
||||
pub connect_time: SystemTime,
|
||||
/// Time TCP connection was established (ms since epoch).
|
||||
pub connect_time: i64,
|
||||
|
||||
/// Time TCP connection was established (ms, monotonic).
|
||||
pub connect_instant: i64,
|
||||
|
||||
/// True if this is an inbound TCP connection.
|
||||
pub inbound: bool,
|
||||
|
@ -101,7 +104,7 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
|
|||
datastore: db.clone(),
|
||||
host: host.clone(),
|
||||
connections: Mutex::new(HashMap::with_capacity(64)),
|
||||
bind_address
|
||||
bind_address,
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
|
@ -115,7 +118,6 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
|
|||
|
||||
pub fn host(&self) -> &Arc<H> { &self.internal.host }
|
||||
|
||||
#[inline(always)]
|
||||
pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
|
||||
self.internal.clone().connect(endpoint, Instant::now().add(Duration::from_millis(CONNECTION_TIMEOUT as u64))).await
|
||||
}
|
||||
|
@ -142,14 +144,22 @@ impl<D: DataStore + 'static, H: Host + 'static> Drop for Node<D, H> {
|
|||
}
|
||||
|
||||
pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
|
||||
// Secret used to perform HMAC to detect and drop loopback connections to self.
|
||||
anti_loopback_secret: [u8; 64],
|
||||
|
||||
// Outside code implementations of DataStore and Host traits.
|
||||
datastore: Arc<D>,
|
||||
host: Arc<H>,
|
||||
|
||||
// Connections and their task join handles, by remote endpoint address.
|
||||
connections: Mutex<HashMap<SocketAddr, (Arc<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
|
||||
|
||||
// Local address to which this node is bound
|
||||
bind_address: SocketAddr,
|
||||
}
|
||||
|
||||
impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
||||
/// Loop that constantly runs in the background to do cleanup and service things.
|
||||
async fn housekeeping_task_main(self: Arc<Self>) {
|
||||
let mut last_status_sent = ms_monotonic();
|
||||
let mut tasks: Vec<JoinHandle<()>> = Vec::new();
|
||||
|
@ -173,8 +183,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
let status = if (now - last_status_sent) >= STATUS_INTERVAL {
|
||||
last_status_sent = now;
|
||||
Some(msg::Status {
|
||||
record_count: self.datastore.total_count(),
|
||||
clock: self.datastore.clock() as u64
|
||||
total_record_count: self.datastore.total_count(),
|
||||
reference_time: self.datastore.clock()
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
@ -191,6 +201,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
let sa = sa.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
if cc.info.lock().await.initialized {
|
||||
// This almost always completes instantly due to queues, but add a timeout in case connection
|
||||
// is stalled. In this case the result is a closed connection.
|
||||
if !tokio::time::timeout_at(sleep_until, cc.send_obj(MESSAGE_TYPE_STATUS, &status, now)).await.map_or(false, |r| r.is_ok()) {
|
||||
let _ = self2.connections.lock().await.remove(&sa).map(|c| c.1.map(|j| j.abort()));
|
||||
self2.host.on_connection_closed(&*cc.info.lock().await, "write overflow (timeout)".to_string());
|
||||
|
@ -198,7 +210,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
}
|
||||
}));
|
||||
});
|
||||
true
|
||||
true // keep connection
|
||||
} else {
|
||||
let _ = c.1.take().map(|j| j.abort());
|
||||
let host = self.host.clone();
|
||||
|
@ -206,7 +218,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
tasks.push(tokio::spawn(async move {
|
||||
host.on_connection_closed(&*cc.info.lock().await, "timeout".to_string());
|
||||
}));
|
||||
false
|
||||
false // discard connection
|
||||
}
|
||||
} else {
|
||||
let host = self.host.clone();
|
||||
|
@ -225,7 +237,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
host.on_connection_closed(&*cc.info.lock().await, "remote host closed connection".to_string());
|
||||
}
|
||||
}));
|
||||
false
|
||||
false // discard connection
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -243,7 +255,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
}
|
||||
|
||||
// Try to connect to more peers until desired connection count is reached.
|
||||
let desired_connection_count = self.host.desired_connection_count();
|
||||
let desired_connection_count = self.host.desired_connection_count().min(self.host.max_connection_count());
|
||||
while connected_to_addresses.len() < desired_connection_count {
|
||||
let sa = self.host.another_peer(&connected_to_addresses);
|
||||
if sa.is_some() {
|
||||
|
@ -270,133 +282,36 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Incoming TCP acceptor task.
|
||||
async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
|
||||
loop {
|
||||
let socket = listener.accept().await;
|
||||
if socket.is_ok() {
|
||||
let (stream, endpoint) = socket.unwrap();
|
||||
let num_conn = self.connections.lock().await.len();
|
||||
if num_conn < self.host.max_connection_count() || self.host.fixed_peers().contains(&endpoint) {
|
||||
Self::connection_start(&self, endpoint, stream, true).await;
|
||||
let (stream, address) = socket.unwrap();
|
||||
if self.host.allow(&address) {
|
||||
if self.connections.lock().await.len() < self.host.max_connection_count() || self.host.fixed_peers().contains(&address) {
|
||||
Self::connection_start(&self, address, stream, true).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: BufReader<OwnedReadHalf>) -> std::io::Result<()> {
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.resize(4096, 0);
|
||||
|
||||
let mut anti_loopback_challenge_sent = [0_u8; 64];
|
||||
let mut challenge_sent = [0_u8; 64];
|
||||
self.host.get_secure_random(&mut anti_loopback_challenge_sent);
|
||||
self.host.get_secure_random(&mut challenge_sent);
|
||||
connection.send_obj(MESSAGE_TYPE_INIT, &msg::Init {
|
||||
anti_loopback_challenge: &anti_loopback_challenge_sent,
|
||||
challenge: &challenge_sent,
|
||||
domain: self.datastore.domain().to_string(),
|
||||
key_size: D::KEY_SIZE as u16,
|
||||
max_value_size: D::MAX_VALUE_SIZE as u64,
|
||||
node_name: self.host.name().map(|n| n.to_string()),
|
||||
node_contact: self.host.contact().map(|c| c.to_string()),
|
||||
locally_bound_port: self.bind_address.port(),
|
||||
explicit_ipv4: None,
|
||||
explicit_ipv6: None
|
||||
}, ms_monotonic()).await?;
|
||||
|
||||
let mut init_received = false;
|
||||
loop {
|
||||
reader.read_exact(&mut buf.as_mut_slice()[0..1]).await?;
|
||||
let message_type = unsafe { *buf.get_unchecked(0) };
|
||||
let message_size = varint::read_async(&mut reader).await?;
|
||||
let header_size = 1 + message_size.1;
|
||||
let message_size = message_size.0;
|
||||
if message_size > (D::MAX_VALUE_SIZE + ((D::KEY_SIZE + 10) * 256) + 65536) as u64 {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
|
||||
}
|
||||
|
||||
let now = ms_monotonic();
|
||||
connection.last_receive_time.store(now, Ordering::Relaxed);
|
||||
|
||||
match message_type {
|
||||
|
||||
MESSAGE_TYPE_INIT => {
|
||||
if init_received {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init"));
|
||||
}
|
||||
|
||||
let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
|
||||
|
||||
if !msg.domain.as_str().eq(self.datastore.domain()) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("data set domain mismatch: '{}' != '{}'", msg.domain, self.datastore.domain())));
|
||||
}
|
||||
if msg.key_size != D::KEY_SIZE as u16 || msg.max_value_size > D::MAX_VALUE_SIZE as u64 {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set key/value sizing mismatch"));
|
||||
}
|
||||
|
||||
let (anti_loopback_response, challenge_response) = {
|
||||
let mut info = connection.info.lock().await;
|
||||
info.node_name = msg.node_name.clone();
|
||||
info.node_contact = msg.node_contact.clone();
|
||||
let _ = msg.explicit_ipv4.map(|pv4| {
|
||||
info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port)));
|
||||
});
|
||||
let _ = msg.explicit_ipv6.map(|pv6| {
|
||||
info.explicit_addresses.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
|
||||
});
|
||||
|
||||
let challenge_response = self.host.authenticate(&info, msg.challenge);
|
||||
if challenge_response.is_none() {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "authenticate() returned None, connection dropped"));
|
||||
}
|
||||
(H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge), challenge_response.unwrap())
|
||||
};
|
||||
|
||||
connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse {
|
||||
anti_loopback_response: &anti_loopback_response,
|
||||
challenge_response: &challenge_response
|
||||
}, now).await?;
|
||||
|
||||
init_received = true;
|
||||
},
|
||||
|
||||
MESSAGE_TYPE_INIT_RESPONSE => {
|
||||
let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
|
||||
|
||||
let mut info = connection.info.lock().await;
|
||||
if info.initialized {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init response"));
|
||||
}
|
||||
info.initialized = true;
|
||||
let info = info.clone();
|
||||
|
||||
if msg.anti_loopback_response.eq(&H::hmac_sha512(&self.anti_loopback_secret, &anti_loopback_challenge_sent)) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "rejected connection to self"));
|
||||
}
|
||||
if !self.host.authenticate(&info, &challenge_sent).map_or(false, |cr| msg.challenge_response.eq(&cr)) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "challenge/response authentication failed"));
|
||||
}
|
||||
|
||||
self.host.on_connect(&info);
|
||||
},
|
||||
|
||||
_ => {
|
||||
// Skip messages that aren't recognized or don't need to be parsed.
|
||||
let mut remaining = message_size as usize;
|
||||
while remaining > 0 {
|
||||
let s = remaining.min(buf.len());
|
||||
reader.read_exact(&mut buf.as_mut_slice()[0..s]).await?;
|
||||
remaining -= s;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
connection.bytes_received.fetch_add((header_size as u64) + message_size, Ordering::Relaxed);
|
||||
/// Initiate an outgoing connection with a deadline based timeout.
|
||||
async fn connect(self: Arc<Self>, address: &SocketAddr, deadline: Instant) -> std::io::Result<bool> {
|
||||
self.host.on_connect_attempt(address);
|
||||
let stream = if address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
|
||||
configure_tcp_socket(&stream)?;
|
||||
stream.bind(self.bind_address.clone())?;
|
||||
let stream = tokio::time::timeout_at(deadline, stream.connect(address.clone())).await;
|
||||
if stream.is_ok() {
|
||||
Ok(self.connection_start(address.clone(), stream.unwrap()?, false).await)
|
||||
} else {
|
||||
Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "connect timed out"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets up and spawns the task for a new TCP connection whether inbound or outbound.
|
||||
async fn connection_start(self: &Arc<Self>, address: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
|
||||
let mut ok = false;
|
||||
let _ = self.connections.lock().await.entry(address.clone()).or_insert_with(|| {
|
||||
|
@ -415,7 +330,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
node_contact: None,
|
||||
remote_address: address.clone(),
|
||||
explicit_addresses: Vec::new(),
|
||||
connect_time: SystemTime::now(),
|
||||
connect_time: ms_since_epoch(),
|
||||
connect_instant: ms_monotonic(),
|
||||
inbound,
|
||||
initialized: false
|
||||
}),
|
||||
|
@ -431,25 +347,322 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
ok
|
||||
}
|
||||
|
||||
async fn connect(self: Arc<Self>, address: &SocketAddr, deadline: Instant) -> std::io::Result<bool> {
|
||||
self.host.on_connect_attempt(address);
|
||||
let stream = if address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
|
||||
configure_tcp_socket(&stream)?;
|
||||
stream.bind(self.bind_address.clone())?;
|
||||
let stream = tokio::time::timeout_at(deadline, stream.connect(address.clone())).await;
|
||||
if stream.is_ok() {
|
||||
Ok(self.connection_start(address.clone(), stream.unwrap()?, false).await)
|
||||
} else {
|
||||
Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "connect timed out"))
|
||||
/// Main I/O task launched for each connection.
|
||||
///
|
||||
/// This handles reading from the connection and reacting to what it sends. Killing this
|
||||
/// task is done when the connection is closed.
|
||||
async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: BufReader<OwnedReadHalf>) -> std::io::Result<()> {
|
||||
let mut anti_loopback_challenge_sent = [0_u8; 64];
|
||||
let mut domain_challenge_sent = [0_u8; 64];
|
||||
let mut auth_challenge_sent = [0_u8; 64];
|
||||
self.host.get_secure_random(&mut anti_loopback_challenge_sent);
|
||||
self.host.get_secure_random(&mut domain_challenge_sent);
|
||||
self.host.get_secure_random(&mut auth_challenge_sent);
|
||||
connection.send_obj(MESSAGE_TYPE_INIT, &msg::Init {
|
||||
anti_loopback_challenge: &anti_loopback_challenge_sent,
|
||||
domain_challenge: &domain_challenge_sent,
|
||||
auth_challenge: &auth_challenge_sent,
|
||||
node_name: self.host.name().map(|n| n.to_string()),
|
||||
node_contact: self.host.contact().map(|c| c.to_string()),
|
||||
locally_bound_port: self.bind_address.port(),
|
||||
explicit_ipv4: None,
|
||||
explicit_ipv6: None
|
||||
}, ms_monotonic()).await?;
|
||||
|
||||
let mut initialized = false;
|
||||
let background_tasks = AsyncTaskReaper::new();
|
||||
let mut init_received = false;
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.resize(4096, 0);
|
||||
loop {
|
||||
let message_type = reader.read_u8().await?;
|
||||
let message_size = varint::read_async(&mut reader).await?;
|
||||
let header_size = 1 + message_size.1;
|
||||
let message_size = message_size.0;
|
||||
if message_size > (D::MAX_VALUE_SIZE + ((D::KEY_SIZE + 10) * 256) + 65536) as u64 {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
|
||||
}
|
||||
|
||||
let now = ms_monotonic();
|
||||
connection.last_receive_time.store(now, Ordering::Relaxed);
|
||||
|
||||
match message_type {
|
||||
|
||||
MESSAGE_TYPE_INIT => {
|
||||
if init_received {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init"));
|
||||
}
|
||||
|
||||
let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
|
||||
let (anti_loopback_response, domain_challenge_response, auth_challenge_response) = {
|
||||
let mut info = connection.info.lock().await;
|
||||
info.node_name = msg.node_name.clone();
|
||||
info.node_contact = msg.node_contact.clone();
|
||||
let _ = msg.explicit_ipv4.map(|pv4| {
|
||||
info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port)));
|
||||
});
|
||||
let _ = msg.explicit_ipv6.map(|pv6| {
|
||||
info.explicit_addresses.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
|
||||
});
|
||||
|
||||
let auth_challenge_response = self.host.authenticate(&info, msg.auth_challenge);
|
||||
if auth_challenge_response.is_none() {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "authenticate() returned None, connection dropped"));
|
||||
}
|
||||
(
|
||||
H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge),
|
||||
H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), msg.domain_challenge),
|
||||
auth_challenge_response.unwrap()
|
||||
)
|
||||
};
|
||||
|
||||
connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse {
|
||||
anti_loopback_response: &anti_loopback_response,
|
||||
domain_response: &domain_challenge_response,
|
||||
auth_response: &auth_challenge_response
|
||||
}, now).await?;
|
||||
|
||||
init_received = true;
|
||||
},
|
||||
|
||||
MESSAGE_TYPE_INIT_RESPONSE => {
|
||||
let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
|
||||
|
||||
let mut info = connection.info.lock().await;
|
||||
if info.initialized {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init response"));
|
||||
}
|
||||
info.initialized = true;
|
||||
let info = info.clone();
|
||||
|
||||
if msg.anti_loopback_response.eq(&H::hmac_sha512(&self.anti_loopback_secret, &anti_loopback_challenge_sent)) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "rejected connection to self"));
|
||||
}
|
||||
if msg.domain_response.eq(&H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), &domain_challenge_sent)) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "domain mismatch"));
|
||||
}
|
||||
if !self.host.authenticate(&info, &auth_challenge_sent).map_or(false, |cr| msg.auth_response.eq(&cr)) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "challenge/response authentication failed"));
|
||||
}
|
||||
|
||||
self.host.on_connect(&info);
|
||||
initialized = true;
|
||||
},
|
||||
|
||||
_ => {
|
||||
if !initialized {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "init exchange must be completed before other messages are sent"));
|
||||
}
|
||||
|
||||
match message_type {
|
||||
|
||||
MESSAGE_TYPE_STATUS => {
|
||||
let msg: msg::Status = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
|
||||
self.connection_request_summary(connection, msg.total_record_count, now, msg.reference_time).await?;
|
||||
},
|
||||
|
||||
MESSAGE_TYPE_GET_SUMMARY => {
|
||||
//let msg: msg::GetSummary = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
|
||||
},
|
||||
|
||||
MESSAGE_TYPE_SUMMARY => {
|
||||
let mut remaining = message_size as isize;
|
||||
|
||||
// Read summary header.
|
||||
let summary_header_size = varint::read_async(&mut reader).await?;
|
||||
remaining -= summary_header_size.1 as isize;
|
||||
let summary_header_size = summary_header_size.0;
|
||||
if (summary_header_size as i64) > (remaining as i64) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid summary header"));
|
||||
}
|
||||
let summary_header: msg::SummaryHeader = connection.read_obj(&mut reader, &mut buf, summary_header_size as usize).await?;
|
||||
remaining -= summary_header_size as isize;
|
||||
|
||||
// Read and evaluate summary that we were sent.
|
||||
match summary_header.summary_type {
|
||||
SUMMARY_TYPE_KEYS => {
|
||||
self.connection_receive_and_process_remote_hash_list(
|
||||
connection,
|
||||
remaining,
|
||||
&mut reader,
|
||||
now,
|
||||
summary_header.reference_time,
|
||||
&summary_header.prefix[0..summary_header.prefix.len().min((summary_header.prefix_bits / 8) as usize)]).await?;
|
||||
},
|
||||
SUMMARY_TYPE_IBLT => {
|
||||
//let summary = IBLT::new_from_reader(&mut reader, remaining as usize).await?;
|
||||
},
|
||||
_ => {} // ignore unknown summary types
|
||||
}
|
||||
|
||||
// Request another summary if needed, keeping a ping-pong game going in a tight loop until synced.
|
||||
self.connection_request_summary(connection, summary_header.total_record_count, now, summary_header.reference_time).await?;
|
||||
},
|
||||
|
||||
MESSAGE_TYPE_HAVE_RECORDS => {
|
||||
let mut remaining = message_size as isize;
|
||||
let reference_time = varint::read_async(&mut reader).await?;
|
||||
remaining -= reference_time.1 as isize;
|
||||
if remaining <= 0 {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid message"));
|
||||
}
|
||||
self.connection_receive_and_process_remote_hash_list(connection, remaining, &mut reader, now, reference_time.0 as i64, &[]).await?
|
||||
},
|
||||
|
||||
MESSAGE_TYPE_GET_RECORDS => {
|
||||
},
|
||||
|
||||
MESSAGE_TYPE_RECORD => {
|
||||
let value = connection.read_msg(&mut reader, &mut buf, message_size as usize).await?;
|
||||
if value.len() > D::MAX_VALUE_SIZE {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "value larger than MAX_VALUE_SIZE"));
|
||||
}
|
||||
let key = H::sha512(&[value]);
|
||||
match self.datastore.store(&key, value) {
|
||||
StoreResult::Ok(reference_time) => {
|
||||
let mut have_records_msg = [0_u8; 2 + 10 + ANNOUNCE_HASH_BYTES];
|
||||
let mut msg_len = varint::encode(&mut have_records_msg, reference_time as u64);
|
||||
have_records_msg[msg_len] = ANNOUNCE_HASH_BYTES as u8;
|
||||
msg_len += 1;
|
||||
have_records_msg[msg_len..(msg_len + ANNOUNCE_HASH_BYTES)].copy_from_slice(&key[..ANNOUNCE_HASH_BYTES]);
|
||||
msg_len += ANNOUNCE_HASH_BYTES;
|
||||
|
||||
let self2 = self.clone();
|
||||
let connection2 = connection.clone();
|
||||
background_tasks.spawn(async move {
|
||||
let connections = self2.connections.lock().await;
|
||||
let mut recipients = Vec::with_capacity(connections.len());
|
||||
for (_, c) in connections.iter() {
|
||||
if !Arc::ptr_eq(&(c.0), &connection2) {
|
||||
recipients.push(c.0.clone());
|
||||
}
|
||||
}
|
||||
drop(connections); // release lock
|
||||
|
||||
for c in recipients.iter() {
|
||||
// This typically completes instantly due to buffering, as this message is small.
|
||||
// Add a small timeout in the case that some connections are stalled. Misses will
|
||||
// not impact the overall network much.
|
||||
let _ = tokio::time::timeout(Duration::from_millis(250), c.send_msg(MESSAGE_TYPE_HAVE_RECORDS, &have_records_msg[..msg_len], now)).await;
|
||||
}
|
||||
});
|
||||
},
|
||||
StoreResult::Rejected => {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid datum received"));
|
||||
},
|
||||
_ => {} // duplicate or ignored values are just... ignored
|
||||
}
|
||||
},
|
||||
|
||||
_ => {
|
||||
// Skip messages that aren't recognized or don't need to be parsed.
|
||||
let mut remaining = message_size as usize;
|
||||
while remaining > 0 {
|
||||
let s = remaining.min(buf.len());
|
||||
reader.read_exact(&mut buf.as_mut_slice()[0..s]).await?;
|
||||
remaining -= s;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
connection.bytes_received.fetch_add((header_size as u64) + message_size, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Request a summary if needed, or do nothing if not.
|
||||
///
|
||||
/// This is where all the logic lives that determines whether to request summaries, choosing a
|
||||
/// prefix, etc. It's called when the remote node tells us its total record count with an
|
||||
/// associated reference time, which happens in status announcements and in summaries.
|
||||
async fn connection_request_summary(&self, connection: &Arc<Connection>, total_record_count: u64, now: i64, reference_time: i64) -> std::io::Result<()> {
|
||||
let my_total_record_count = self.datastore.total_count();
|
||||
if my_total_record_count < total_record_count {
|
||||
// Figure out how many bits need to be in a randomly chosen prefix to choose a slice of
|
||||
// the data set such that the set difference should be around 4096 records. This assumes
|
||||
// random distribution, which should be mostly maintained by probing prefixes at random.
|
||||
let prefix_bits = ((total_record_count - my_total_record_count) as f64) / 4096.0;
|
||||
let prefix_bits = if prefix_bits > 1.0 {
|
||||
(prefix_bits.log2().ceil() as usize).min(64)
|
||||
} else {
|
||||
0 as usize
|
||||
};
|
||||
let prefix_bytes = (prefix_bits / 8) + (((prefix_bits % 8) != 0) as usize);
|
||||
|
||||
// Generate a random prefix of this many bits (to the nearest byte).
|
||||
let mut prefix = [0_u8; 64];
|
||||
self.host.get_secure_random(&mut prefix[..prefix_bytes]);
|
||||
|
||||
// Request a set summary for this prefix, providing our own count for this prefix so
|
||||
// the remote can decide whether to send something like an IBLT or just hashes.
|
||||
let (local_range_start, local_range_end) = range_from_prefix(&prefix, prefix_bits);
|
||||
connection.send_obj(MESSAGE_TYPE_GET_SUMMARY, &msg::GetSummary {
|
||||
reference_time,
|
||||
prefix: &prefix[..prefix_bytes],
|
||||
prefix_bits: prefix_bits as u8,
|
||||
record_count: self.datastore.count(reference_time, &local_range_start, &local_range_end)
|
||||
}, now).await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
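// Worked example (illustrative numbers, not from this commit): if the remote
// reports 1_000_000 records and we hold 100_000, the difference estimate is
// 900_000; 900_000 / 4096.0 is about 219.7, so prefix_bits =
// ceil(log2(219.7)) = 8 and prefix_bytes = 1. A random 8-bit prefix then
// selects roughly 1/256 of the key space, keeping the expected per-probe set
// difference near the 4096-record target.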

    /// Read a stream of record hashes (or hash prefixes) from a connection and request records we don't have.
    async fn connection_receive_and_process_remote_hash_list(&self, connection: &Arc<Connection>, mut remaining: isize, reader: &mut BufReader<OwnedReadHalf>, now: i64, reference_time: i64, common_prefix: &[u8]) -> std::io::Result<()> {
        if remaining > 0 {
            // Hash list is prefaced by the number of bytes in each hash, since whole 64 byte hashes do not have to be sent.
            let prefix_entry_size = reader.read_u8().await? as usize;
            let total_prefix_size = common_prefix.len() + prefix_entry_size;

            if prefix_entry_size > 0 && total_prefix_size <= 64 {
                remaining -= 1;
                if remaining >= (prefix_entry_size as isize) {
                    let mut get_records_msg: Vec<u8> = Vec::with_capacity(((remaining as usize) / prefix_entry_size) * total_prefix_size);
                    varint::write(&mut get_records_msg, reference_time as u64)?;
                    get_records_msg.push(total_prefix_size as u8);

                    let mut key_prefix_buf = [0_u8; 64];
                    key_prefix_buf[..common_prefix.len()].copy_from_slice(common_prefix);

                    while remaining >= (prefix_entry_size as isize) {
                        remaining -= prefix_entry_size as isize;
                        reader.read_exact(&mut key_prefix_buf[common_prefix.len()..total_prefix_size]).await?;

                        if if total_prefix_size < 64 {
                            let (s, e) = range_from_prefix(&key_prefix_buf[..total_prefix_size], total_prefix_size * 8);
                            self.datastore.count(reference_time, &s, &e) == 0
                        } else {
                            !self.datastore.contains(reference_time, &key_prefix_buf)
                        } {
                            let _ = get_records_msg.write_all(&key_prefix_buf[..total_prefix_size]);
                        }
                    }

                    if remaining == 0 {
                        return connection.send_msg(MESSAGE_TYPE_GET_RECORDS, get_records_msg.as_slice(), now).await;
                    }
                }
            }
        }
        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hash list"));
    }
}
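
For orientation, the hash list parsed above is simply a per-entry size byte followed by packed key prefixes, and the GET_RECORDS reply prepends a varint reference time plus the full prefix size. A sender-side sketch of the hash-list layout (an illustration, not code from this commit; the key list and truncation length are assumptions):

// Hypothetical sketch: pack a list of truncated SHA512 keys into the hash-list
// wire form read above (one size byte, then `prefix_bytes` bytes per entry).
fn pack_hash_list(keys: &[[u8; 64]], prefix_bytes: usize) -> Vec<u8> {
    let mut out = Vec::with_capacity(1 + keys.len() * prefix_bytes);
    out.push(prefix_bytes as u8);
    for k in keys {
        out.extend_from_slice(&k[..prefix_bytes]);
    }
    out
}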

impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
    fn drop(&mut self) {
        for (_, c) in self.connections.blocking_lock().drain() {
            c.1.map(|c| c.abort());
        }
        let _ = tokio::runtime::Handle::try_current().map_or_else(|_| {
            for (_, c) in self.connections.blocking_lock().drain() {
                c.1.map(|c| c.abort());
            }
        }, |h| {
            let _ = h.block_on(async {
                for (_, c) in self.connections.lock().await.drain() {
                    c.1.map(|c| c.abort());
                }
            });
        });
    }
}

@ -464,11 +677,18 @@ struct Connection {
}

impl Connection {
    async fn send(&self, data: &[u8], now: i64) -> std::io::Result<()> {
        self.writer.lock().await.write_all(data).await?;
        self.last_send_time.store(now, Ordering::Relaxed);
        self.bytes_sent.fetch_add(data.len() as u64, Ordering::Relaxed);
        Ok(())
    async fn send_msg(&self, message_type: u8, data: &[u8], now: i64) -> std::io::Result<()> {
        let mut type_and_size = [0_u8; 16];
        type_and_size[0] = message_type;
        let tslen = 1 + varint::encode(&mut type_and_size[1..], data.len() as u64) as usize;
        let total_size = tslen + data.len();
        if self.writer.lock().await.write_vectored(&[IoSlice::new(&type_and_size[..tslen]), IoSlice::new(data)]).await? == total_size {
            self.last_send_time.store(now, Ordering::Relaxed);
            self.bytes_sent.fetch_add(total_size as u64, Ordering::Relaxed);
            Ok(())
        } else {
            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "write error"))
        }
    }
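
On the wire, send_msg() above writes one type byte, a varint payload length, and then the payload. A matching receive-side sketch (an assumption for illustration; the crate's real dispatch loop is elsewhere):

// Hypothetical frame reader for the [type][varint length][payload] framing above.
async fn read_frame<R: tokio::io::AsyncRead + Unpin>(r: &mut R) -> std::io::Result<(u8, Vec<u8>)> {
    use tokio::io::AsyncReadExt;
    let message_type = r.read_u8().await?;
    let (len, _len_bytes) = varint::read_async(r).await?; // varint helper from this crate
    let mut payload = vec![0_u8; len as usize];
    r.read_exact(&mut payload).await?;
    Ok((message_type, payload))
}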

    async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> {

@ -18,63 +18,63 @@ pub const MESSAGE_TYPE_INIT_RESPONSE: u8 = 2;
/// Sent every few seconds to notify peers of number of records, clock, etc.
pub const MESSAGE_TYPE_STATUS: u8 = 3;

/// Get a set summary of a prefix in the data set.
pub const MESSAGE_TYPE_GET_SUMMARY: u8 = 4;

/// Set summary of a prefix.
pub const MESSAGE_TYPE_SUMMARY: u8 = 5;

/// Payload is a list of keys of records. Usually sent to advertise recently received new records.
pub const MESSAGE_TYPE_HAVE_RECORDS: u8 = 4;
pub const MESSAGE_TYPE_HAVE_RECORDS: u8 = 6;

/// Payload is a list of keys of records the sending node wants.
pub const MESSAGE_TYPE_GET_RECORDS: u8 = 5;
pub const MESSAGE_TYPE_GET_RECORDS: u8 = 7;

/// Payload is a record, with key being omitted if the data store's KEY_IS_COMPUTED constant is true.
pub const MESSAGE_TYPE_RECORD: u8 = 6;
pub const MESSAGE_TYPE_RECORD: u8 = 8;

/// Summary type: simple array of keys under the given prefix.
pub const SUMMARY_TYPE_KEYS: u8 = 0;

/// An IBLT set summary.
pub const SUMMARY_TYPE_IBLT: u8 = 1;

/// Number of bytes of each SHA512 hash to announce, request, etc. This is okay to change but 16 is plenty.
pub const ANNOUNCE_HASH_BYTES: usize = 16;

pub mod msg {
    use serde::{Serialize, Deserialize};

    #[derive(Serialize, Deserialize)]
    pub struct IPv4 {
        #[serde(rename = "i")]
        pub ip: [u8; 4],
        #[serde(rename = "p")]
        pub port: u16
    }

    #[derive(Serialize, Deserialize)]
    pub struct IPv6 {
        #[serde(rename = "i")]
        pub ip: [u8; 16],
        #[serde(rename = "p")]
        pub port: u16
    }

    #[derive(Serialize, Deserialize)]
    pub struct Init<'a> {
        /// A random challenge to be hashed with a secret to detect and drop connections to self.
        #[serde(rename = "alc")]
        #[serde(with = "serde_bytes")]
        pub anti_loopback_challenge: &'a [u8],

        /// A random challenge for checking the data set domain.
        #[serde(with = "serde_bytes")]
        pub domain_challenge: &'a [u8],

        /// A random challenge for login/authentication.
        #[serde(with = "serde_bytes")]
        pub challenge: &'a [u8],

        /// An arbitrary name for this data set to avoid connecting to peers not replicating it.
        #[serde(rename = "d")]
        pub domain: String,

        /// Size of keys in this data set in bytes.
        #[serde(rename = "ks")]
        pub key_size: u16,

        /// Maximum allowed size of values in this data set in bytes.
        #[serde(rename = "mvs")]
        pub max_value_size: u64,
        pub auth_challenge: &'a [u8],

        /// Optional name to advertise for this node.
        #[serde(rename = "nn")]
        pub node_name: Option<String>,

        /// Optional contact information for this node, such as a URL or an e-mail address.
        #[serde(rename = "nc")]
        pub node_contact: Option<String>,

        /// Port to which this node has locally bound.
@ -83,35 +83,85 @@ pub mod msg {

        /// An IPv4 address where this node can be reached.
        /// If both explicit_ipv4 and explicit_ipv6 are omitted the physical source IP:port may be used.
        #[serde(rename = "ei4")]
        pub explicit_ipv4: Option<IPv4>,

        /// An IPv6 address where this node can be reached.
        /// If both explicit_ipv4 and explicit_ipv6 are omitted the physical source IP:port may be used.
        #[serde(rename = "ei6")]
        pub explicit_ipv6: Option<IPv6>,
    }

    #[derive(Serialize, Deserialize)]
    pub struct InitResponse<'a> {
        /// HMAC-SHA512(local secret, anti_loopback_challenge) to detect and drop loops.
        #[serde(rename = "alr")]
        #[serde(with = "serde_bytes")]
        pub anti_loopback_response: &'a [u8],

        /// HMAC-SHA512(SHA512(domain), domain_challenge) to check that the data set domain matches.
        #[serde(with = "serde_bytes")]
        pub domain_response: &'a [u8],

        /// HMAC-SHA512(secret, challenge) for authentication. (If auth is not enabled, an all-zero secret is used.)
        #[serde(with = "serde_bytes")]
        pub challenge_response: &'a [u8],
        pub auth_response: &'a [u8],
    }

    #[derive(Serialize, Deserialize, Clone)]
    pub struct Status {
        /// Total number of records in data set.
        #[serde(rename = "rc")]
        pub record_count: u64,

        /// Local wall clock time in milliseconds since Unix epoch.
        #[serde(rename = "c")]
        pub clock: u64,
        pub total_record_count: u64,

        /// Reference wall clock time in milliseconds since Unix epoch.
        #[serde(rename = "t")]
        pub reference_time: i64,
    }

    #[derive(Serialize, Deserialize, Clone)]
    pub struct GetSummary<'a> {
        /// Reference wall clock time in milliseconds since Unix epoch.
        #[serde(rename = "t")]
        pub reference_time: i64,

        /// Prefix within key space.
        #[serde(rename = "p")]
        #[serde(with = "serde_bytes")]
        pub prefix: &'a [u8],

        /// Length of prefix in bits (trailing bits in byte array are ignored).
        #[serde(rename = "b")]
        pub prefix_bits: u8,

        /// Number of records in this range the requesting node already has, used to choose summary type.
        #[serde(rename = "r")]
        pub record_count: u64,
    }

    #[derive(Serialize, Deserialize, Clone)]
    pub struct SummaryHeader<'a> {
        /// Total number of records in data set, for easy rapid generation of next query.
        #[serde(rename = "c")]
        pub total_record_count: u64,

        /// Reference wall clock time in milliseconds since Unix epoch.
        #[serde(rename = "t")]
        pub reference_time: i64,

        /// Random salt value used by some summary types.
        #[serde(rename = "x")]
        #[serde(with = "serde_bytes")]
        pub salt: &'a [u8],

        /// Prefix within key space.
        #[serde(rename = "p")]
        #[serde(with = "serde_bytes")]
        pub prefix: &'a [u8],

        /// Length of prefix in bits (trailing bits in byte array are ignored).
        #[serde(rename = "b")]
        pub prefix_bits: u8,

        /// Type of summary that follows this header.
        #[serde(rename = "s")]
        pub summary_type: u8,
    }
}
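
As a usage sketch of the message types above (not part of this commit; the helper and field values are assumptions), a node probing a 12-bit slice of the key space would fill in GetSummary roughly like this, rounding the prefix up to whole bytes on the wire:

// Hypothetical helper: build a GetSummary query for a key-space prefix.
// `local_count` would come from DataStore::count() over the matching range.
fn make_summary_query<'a>(prefix: &'a [u8], prefix_bits: u8, local_count: u64) -> msg::GetSummary<'a> {
    msg::GetSummary {
        reference_time: ms_since_epoch(),
        prefix,                    // e.g. &prefix_buf[..2] for a 12-bit prefix
        prefix_bits,               // trailing bits of the last byte are ignored
        record_count: local_count, // lets the peer choose plain keys vs. an IBLT summary
    }
}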
102 syncwhole/src/utils.rs Normal file

@ -0,0 +1,102 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/.
 *
 * (c)2022 ZeroTier, Inc.
 * https://www.zerotier.com/
 */

use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::task::JoinHandle;

/// Get the real time clock in milliseconds since Unix epoch.
pub fn ms_since_epoch() -> i64 {
    std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as i64
}

/// Get the current monotonic clock in milliseconds.
pub fn ms_monotonic() -> i64 {
    std::time::Instant::now().elapsed().as_millis() as i64
}

/// Encode a byte slice to a hexadecimal string.
pub fn to_hex_string(b: &[u8]) -> String {
    const HEX_CHARS: [u8; 16] = [b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'a', b'b', b'c', b'd', b'e', b'f'];
    let mut s = String::new();
    s.reserve(b.len() * 2);
    for c in b {
        let x = *c as usize;
        s.push(HEX_CHARS[x >> 4] as char);
        s.push(HEX_CHARS[x & 0xf] as char);
    }
    s
}

#[inline(always)]
pub fn xorshift64(mut x: u64) -> u64 {
    x = u64::from_le(x);
    x ^= x.wrapping_shl(13);
    x ^= x.wrapping_shr(7);
    x ^= x.wrapping_shl(17);
    x.to_le()
}

#[inline(always)]
pub fn splitmix64(mut x: u64) -> u64 {
    x = u64::from_le(x);
    x ^= x.wrapping_shr(30);
    x = x.wrapping_mul(0xbf58476d1ce4e5b9);
    x ^= x.wrapping_shr(27);
    x = x.wrapping_mul(0x94d049bb133111eb);
    x ^= x.wrapping_shr(31);
    x.to_le()
}

#[inline(always)]
pub fn splitmix64_inverse(mut x: u64) -> u64 {
    x = u64::from_le(x);
    x ^= x.wrapping_shr(31) ^ x.wrapping_shr(62);
    x = x.wrapping_mul(0x319642b2d24d8ec3);
    x ^= x.wrapping_shr(27) ^ x.wrapping_shr(54);
    x = x.wrapping_mul(0x96de1b173f119089);
    x ^= x.wrapping_shr(30) ^ x.wrapping_shr(60);
    x.to_le()
}
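
splitmix64_inverse() undoes splitmix64(), since each multiplier is paired with its modular inverse and each xor-shift is unwound. A quick round-trip check (a sketch for illustration, not part of this commit):

fn main() {
    let mut x = 0x0123_4567_89ab_cdef_u64;
    for _ in 0..1000 {
        // The inverse must recover the original input exactly.
        assert_eq!(splitmix64_inverse(splitmix64(x)), x);
        x = x.wrapping_add(0x9e37_79b9_7f4a_7c15); // arbitrary stride over test inputs
    }
}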

/// Wrapper for tokio::spawn() that aborts tasks not yet completed when it is dropped.
pub struct AsyncTaskReaper {
    ctr: AtomicUsize,
    handles: Arc<std::sync::Mutex<HashMap<usize, JoinHandle<()>>>>,
}

impl AsyncTaskReaper {
    pub fn new() -> Self {
        Self {
            ctr: AtomicUsize::new(0),
            handles: Arc::new(std::sync::Mutex::new(HashMap::new()))
        }
    }

    /// Spawn a new task.
    /// Note that currently any output is ignored. This is primarily for background tasks
    /// that are used similarly to goroutines in Go.
    pub fn spawn<F: Future + Send + 'static>(&self, future: F) {
        let id = self.ctr.fetch_add(1, Ordering::Relaxed);
        let handles = self.handles.clone();
        self.handles.lock().unwrap().insert(id, tokio::spawn(async move {
            let _ = future.await;
            let _ = handles.lock().unwrap().remove(&id);
        }));
    }
}

impl Drop for AsyncTaskReaper {
    fn drop(&mut self) {
        for (_, h) in self.handles.lock().unwrap().iter() {
            h.abort();
        }
    }
}
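
A minimal usage sketch for AsyncTaskReaper (assumes a running Tokio runtime; the sleep-based work items are placeholders, not part of this commit): tasks run in the background and anything still pending is aborted when the reaper is dropped.

async fn run_background_work() {
    let reaper = AsyncTaskReaper::new();
    for i in 0..4_u64 {
        reaper.spawn(async move {
            // Placeholder work item.
            tokio::time::sleep(std::time::Duration::from_millis(10 * i)).await;
        });
    }
    // When `reaper` goes out of scope here, unfinished tasks are aborted.
}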

@ -36,17 +36,11 @@ pub async fn write_async<W: AsyncWrite + Unpin>(w: &mut W, v: u64) -> std::io::R

pub async fn read_async<R: AsyncRead + Unpin>(r: &mut R) -> std::io::Result<(u64, usize)> {
    let mut v = 0_u64;
    let mut buf = [0_u8; 1];
    let mut pos = 0;
    let mut count = 0;
    loop {
        loop {
            if r.read(&mut buf).await? == 1 {
                break;
            }
        }
        count += 1;
        let b = buf[0];
        let b = r.read_u8().await?;
        if b <= 0x7f {
            v |= (b as u64).wrapping_shl(pos);
            pos += 7;