Loads of syncwhole work.

This commit is contained in:
Adam Ierymenko 2022-03-30 15:46:17 -04:00
parent 1913d956b3
commit aba212fd87
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
12 changed files with 647 additions and 714 deletions

3
.nova/Configuration.json Normal file
View file

@ -0,0 +1,3 @@
{
"workspace.name" : "ZeroTier"
}

8
rustfmt.toml Normal file
View file

@ -0,0 +1,8 @@
max_width = 300
use_small_heuristics = "Max"
tab_spaces = 4
newline_style = "Unix"
control_brace_style = "AlwaysSameLine"
edition = "2021"
imports_granularity = "Crate"
group_imports = "StdExternalCrate"

1
syncwhole/rustfmt.toml Symbolic link
View file

@ -0,0 +1 @@
../rustfmt.toml

View file

@ -6,21 +6,35 @@
* https://www.zerotier.com/
*/
/// Size of keys, which is the size of a 512-bit hash. This is a protocol constant.
pub const KEY_SIZE: usize = 64;
/// Minimum possible key (all zero).
pub const MIN_KEY: [u8; KEY_SIZE] = [0; KEY_SIZE];
/// Maximum possible key (all 0xff).
pub const MAX_KEY: [u8; KEY_SIZE] = [0xff; KEY_SIZE];
/// Generate a range of SHA512 hashes from a prefix and a number of bits.
/// The range will be inclusive and cover all keys under the prefix.
pub fn range_from_prefix(prefix: &[u8], prefix_bits: usize) -> ([u8; 64], [u8; 64]) {
let prefix_bits = prefix_bits.min(prefix.len() * 8).min(64);
let mut start = [0_u8; 64];
let mut end = [0xff_u8; 64];
pub fn range_from_prefix(prefix: &[u8], prefix_bits: usize) -> Option<([u8; KEY_SIZE], [u8; KEY_SIZE])> {
let mut start = [0_u8; KEY_SIZE];
let mut end = [0xff_u8; KEY_SIZE];
if prefix_bits > (KEY_SIZE * 8) {
return None;
}
let whole_bytes = prefix_bits / 8;
let remaining_bits = prefix_bits % 8;
if prefix.len() < (whole_bytes + ((remaining_bits != 0) as usize)) {
return None;
}
start[0..whole_bytes].copy_from_slice(&prefix[0..whole_bytes]);
end[0..whole_bytes].copy_from_slice(&prefix[0..whole_bytes]);
if remaining_bits != 0 && whole_bytes < prefix.len() {
if remaining_bits != 0 {
start[whole_bytes] |= prefix[whole_bytes];
end[whole_bytes] &= prefix[whole_bytes] | ((0xff_u8).wrapping_shr(remaining_bits as u32));
}
(start, end)
return Some((start, end));
}
/// Result returned by DataStore::load().
@ -32,15 +46,13 @@ pub enum LoadResult<V: AsRef<[u8]> + Send> {
NotFound,
/// Supplied reference_time is outside what is available (usually too old).
TimeNotAvailable
TimeNotAvailable,
}
/// Result returned by DataStore::store().
pub enum StoreResult {
/// Entry was accepted.
/// The integer included with Ok is the reference time that should be advertised.
/// If this is not a temporally subjective data set then zero can be used.
Ok(i64),
Ok,
/// Entry was a duplicate of one we already have but was otherwise valid.
Duplicate,
@ -49,7 +61,7 @@ pub enum StoreResult {
Ignored,
/// Entry was rejected as malformed or otherwise invalid (e.g. failed signature check).
Rejected
Rejected,
}
/// API to be implemented by the data set we want to replicate.
@ -73,7 +85,7 @@ pub trait DataStore: Sync + Send {
type LoadResultValueType: AsRef<[u8]> + Send;
/// Key hash size, always 64 for SHA512.
const KEY_SIZE: usize = 64;
const KEY_SIZE: usize = KEY_SIZE;
/// Maximum size of a value in bytes.
const MAX_VALUE_SIZE: usize;
@ -98,7 +110,7 @@ pub trait DataStore: Sync + Send {
fn contains(&self, reference_time: i64, key: &[u8]) -> bool {
match self.load(reference_time, key) {
LoadResult::Ok(_) => true,
_ => false
_ => false,
}
}

View file

@ -6,70 +6,50 @@
* https://www.zerotier.com/
*/
use std::collections::HashSet;
use std::net::SocketAddr;
#[cfg(feature = "include_sha2_lib")]
use sha2::digest::{Digest, FixedOutput};
use serde::{Deserialize, Serialize};
use crate::node::RemoteNodeInfo;
/// Configuration settings for a syncwhole node.
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq)]
pub struct Config {
/// A list of peer addresses to which we always want to stay connected.
/// The library will try to maintain connectivity to these regardless of connection limits.
pub anchors: Vec<SocketAddr>,
/// A list of peer addresses that we can try in order to achieve desired_connection_count.
pub seeds: Vec<SocketAddr>,
/// The maximum number of TCP connections we should allow.
pub max_connection_count: usize,
/// The desired number of peering links.
pub desired_connection_count: usize,
/// An optional name for this node to advertise to other nodes.
pub name: String,
/// An optional contact string for this node to advertise to other nodes.
/// Example: bighead@stanford.edu or https://www.piedpiper.com/
pub contact: String,
}
/// A trait that users of syncwhole implement to provide configuration information and listen for events.
pub trait Host: Sync + Send {
/// Get a list of peer addresses to which we always want to try to stay connected.
///
/// These are always contacted until a link is established regardless of anything else.
fn fixed_peers(&self) -> &[SocketAddr];
/// Get a random peer address not in the supplied set.
///
/// The default implementation just returns None.
fn another_peer(&self, exclude: &HashSet<SocketAddr>) -> Option<SocketAddr> {
None
}
/// Get the maximum number of endpoints allowed.
///
/// This is checked on incoming connect and incoming links are refused if the total is
/// over this count. Fixed endpoints will be contacted even if the total is over this limit.
///
/// The default implementation returns 1024.
fn max_connection_count(&self) -> usize {
1024
}
/// Get the number of connections we ideally want.
///
/// Attempts will be made to lazily contact remote endpoints if the total number of links
/// is under this amount. Note that fixed endpoints will still be contacted even if the
/// total is over the desired count.
///
/// This should always be less than max_connection_count().
///
/// The default implementation returns 128.
fn desired_connection_count(&self) -> usize {
128
}
/// Get an optional name that this node should advertise.
///
/// The default implementation returns None.
fn name(&self) -> Option<&str> {
None
}
/// Get an optional contact info string that this node should advertise.
///
/// The default implementation returns None.
fn contact(&self) -> Option<&str> {
None
}
/// Get a copy of the current configuration for this syncwhole node.
fn node_config(&self) -> Config;
/// Test whether an inbound connection should be allowed from an address.
///
/// This is called on first incoming connection before any init is received. The authenticate()
/// method is called once init has been received and is another decision point. The default
/// implementation of this always returns true.
#[allow(unused_variables)]
fn allow(&self, remote_address: &SocketAddr) -> bool {
true
}
@ -83,6 +63,10 @@ pub trait Host: Sync + Send {
///
/// This actually gets called twice per link: once when Init is received to compute the
/// response, and once when InitResponse is received to verify the response to our challenge.
///
/// The default implementation authenticates with an all-zero key. Leave it this way if
/// you don't want authentication.
#[allow(unused_variables)]
fn authenticate(&self, info: &RemoteNodeInfo, challenge: &[u8]) -> Option<[u8; 64]> {
Some(Self::hmac_sha512(&[0_u8; 64], challenge))
}
@ -125,8 +109,8 @@ pub trait Host: Sync + Send {
///
/// Supplied key will always be 64 bytes in length.
///
/// The default implementation is a basic HMAC implemented in terms of sha512() above. This
/// can be specialized if the user wishes to provide their own implementation.
/// The default implementation is HMAC implemented in terms of sha512() above. Specialize
/// to provide your own implementation.
fn hmac_sha512(key: &[u8], msg: &[u8]) -> [u8; 64] {
let mut opad = [0x5c_u8; 128];
let mut ipad = [0x36_u8; 128];
@ -137,7 +121,6 @@ pub trait Host: Sync + Send {
for i in 0..64 {
ipad[i] ^= key[i];
}
let s1 = Self::sha512(&[&ipad, msg]);
Self::sha512(&[&opad, &s1])
Self::sha512(&[&opad, &Self::sha512(&[&ipad, msg])])
}
}

View file

@ -6,35 +6,19 @@
* https://www.zerotier.com/
*/
use std::mem::size_of;
use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut, write_bytes};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::utils::*;
/// Called to get the next iteration index for each KEY_MAPPING_ITERATIONS table lookup.
/// (See IBLT papers, etc.)
#[inline(always)]
fn next_iteration_index(prev_iteration_index: u64) -> u64 {
splitmix64(prev_iteration_index.wrapping_add(1))
}
#[derive(Clone, PartialEq, Eq)]
#[repr(C, packed)]
struct IBLTEntry {
key_sum: u64,
check_hash_sum: u64,
count: i32
}
impl IBLTEntry {
#[inline(always)]
fn is_singular(&self) -> bool {
if i32::from_le(self.count) == 1 || i32::from_le(self.count) == -1 {
xorshift64(self.key_sum) == self.check_hash_sum
} else {
false
}
}
fn next_iteration_index(mut x: u64) -> u64 {
x = x.wrapping_add(1);
x ^= x.wrapping_shr(30);
x = x.wrapping_mul(0xbf58476d1ce4e5b9);
x ^= x.wrapping_shr(27);
x = x.wrapping_mul(0x94d049bb133111eb);
x ^= x.wrapping_shr(31);
x
}
/// An Invertible Bloom Lookup Table for set reconciliation with 64-bit hashes.
@ -42,170 +26,151 @@ impl IBLTEntry {
/// Usage inspired by this paper:
///
/// https://dash.harvard.edu/bitstream/handle/1/14398536/GENTILI-SENIORTHESIS-2015.pdf
#[derive(Clone, PartialEq, Eq)]
pub struct IBLT {
map: Vec<IBLTEntry>
#[repr(C, packed)]
pub struct IBLT<const BUCKETS: usize> {
key: [u64; BUCKETS],
check_hash: [u64; BUCKETS],
count: [i8; BUCKETS],
}
impl IBLT {
impl<const BUCKETS: usize> IBLT<BUCKETS> {
/// This was determined to be effective via empirical testing with random keys. This
/// is a protocol constant that can't be changed without upgrading all nodes in a domain.
const KEY_MAPPING_ITERATIONS: usize = 2;
pub fn new(buckets: usize) -> Self {
assert!(buckets < i32::MAX as usize);
assert_eq!(size_of::<IBLTEntry>(), 20);
let mut iblt = Self { map: Vec::with_capacity(buckets) };
unsafe {
iblt.map.set_len(buckets);
iblt.reset();
/// Number of buckets in this IBLT.
pub const BUCKETS: usize = BUCKETS;
/// Size of this IBLT in bytes.
pub const SIZE_BYTES: usize = BUCKETS * (8 + 8 + 1);
#[inline(always)]
fn is_singular(&self, i: usize) -> bool {
let c = self.count[i];
if c == 1 || c == -1 {
xorshift64(self.key[i]) == self.check_hash[i]
} else {
false
}
iblt
}
/// Compute the size in bytes of an IBLT with the given number of buckets.
#[inline(always)]
pub fn size_bytes_with_buckets(buckets: usize) -> usize { buckets * size_of::<IBLTEntry>() }
/// Create a new zeroed IBLT.
pub fn new() -> Self {
assert_eq!(Self::SIZE_BYTES, std::mem::size_of::<Self>());
assert!(BUCKETS < (i32::MAX as usize));
unsafe { std::mem::zeroed() }
}
/// Cast a byte array to an IBLT if it is of the correct size.
pub fn ref_from_bytes(b: &[u8]) -> Option<&Self> {
if b.len() == Self::SIZE_BYTES {
Some(unsafe { &*b.as_ptr().cast() })
} else {
None
}
}
/// Compute the IBLT size in buckets to reconcile a given set difference, or return 0 if no advantage.
/// This returns zero if an IBLT would take up as much or more space than just sending local_set_size
/// hashes of hash_size_bytes.
#[inline(always)]
pub fn calc_iblt_parameters(hash_size_bytes: usize, local_set_size: u64, difference_size: u64) -> usize {
let hashes_would_be = (hash_size_bytes as f64) * (local_set_size as f64);
let buckets_should_be = (difference_size as f64) * 1.8; // factor determined experimentally for best bytes/item, can be tuned
let iblt_would_be = buckets_should_be * (size_of::<IBLTEntry>() as f64);
if iblt_would_be < hashes_would_be {
buckets_should_be.ceil() as usize
let b = (difference_size as f64) * 1.8; // factor determined experimentally for best bytes/item, can be tuned
if b > 64.0 && (b * (8.0 + 8.0 + 1.0)) < ((hash_size_bytes as f64) * (local_set_size as f64)) {
b.round() as usize
} else {
0
}
}
/// Get the size of this IBLT in buckets.
#[inline(always)]
pub fn buckets(&self) -> usize { self.map.len() }
/// Get the size of this IBLT in bytes.
pub fn size_bytes(&self) -> usize { self.map.len() * size_of::<IBLTEntry>() }
/// Zero this IBLT.
#[inline(always)]
pub fn reset(&mut self) {
unsafe { write_bytes(self.map.as_mut_ptr().cast::<u8>(), 0, self.map.len() * size_of::<IBLTEntry>()); }
unsafe {
std::ptr::write_bytes((self as *mut Self).cast::<u8>(), 0, std::mem::size_of::<Self>());
}
}
/// Get this IBLT as a byte slice in place.
#[inline(always)]
pub fn as_bytes(&self) -> &[u8] {
unsafe { &*slice_from_raw_parts(self.map.as_ptr().cast::<u8>(), self.map.len() * size_of::<IBLTEntry>()) }
unsafe { &*std::ptr::slice_from_raw_parts((self as *const Self).cast::<u8>(), std::mem::size_of::<Self>()) }
}
/// Construct an IBLT from an input reader and a size in bytes.
pub async fn new_from_reader<R: AsyncReadExt + Unpin>(r: &mut R, bytes: usize) -> std::io::Result<Self> {
assert_eq!(size_of::<IBLTEntry>(), 20);
if (bytes % size_of::<IBLTEntry>()) != 0 {
Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "incomplete or invalid IBLT"))
} else {
let buckets = bytes / size_of::<IBLTEntry>();
let mut iblt = Self { map: Vec::with_capacity(buckets) };
unsafe {
iblt.map.set_len(buckets);
r.read_exact(&mut *slice_from_raw_parts_mut(iblt.map.as_mut_ptr().cast::<u8>(), bytes)).await?;
}
Ok(iblt)
}
}
/// Write this IBLT to a stream.
/// Note that the size of the IBLT in bytes must be stored separately. Use size_bytes() to get that.
#[inline(always)]
pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
w.write_all(self.as_bytes()).await
}
fn ins_rem(&mut self, mut key: u64, delta: i32) {
key = splitmix64(key);
fn ins_rem(&mut self, key: u64, delta: i8) {
let check_hash = xorshift64(key);
let mut iteration_index = u64::from_le(key);
let buckets = self.map.len();
for _ in 0..Self::KEY_MAPPING_ITERATIONS {
iteration_index = next_iteration_index(iteration_index);
let b = unsafe { self.map.get_unchecked_mut((iteration_index as usize) % buckets) };
b.key_sum ^= key;
b.check_hash_sum ^= check_hash;
b.count = (i32::from_le(b.count) + delta).to_le();
let i = (iteration_index as usize) % BUCKETS;
self.key[i] ^= key;
self.check_hash[i] ^= check_hash;
self.count[i] = self.count[i].wrapping_add(delta);
}
}
/// Insert a 64-bit key.
/// Panics if the key is shorter than 64 bits. If longer, bits beyond 64 are ignored.
#[inline(always)]
pub fn insert(&mut self, key: &[u8]) {
assert!(key.len() >= 8);
self.ins_rem(unsafe { u64::from_ne_bytes(*(key.as_ptr().cast::<[u8; 8]>())) }, 1);
pub fn insert(&mut self, key: u64) {
self.ins_rem(key, 1);
}
/// Remove a 64-bit key.
/// Panics if the key is shorter than 64 bits. If longer, bits beyond 64 are ignored.
#[inline(always)]
pub fn remove(&mut self, key: &[u8]) {
assert!(key.len() >= 8);
self.ins_rem(unsafe { u64::from_ne_bytes(*(key.as_ptr().cast::<[u8; 8]>())) }, -1);
pub fn remove(&mut self, key: u64) {
self.ins_rem(key, -1);
}
/// Subtract another IBLT from this one to get a set difference.
///
/// This returns true on success or false on error, which right now can only happen if the
/// other IBLT has a different number of buckets or if it contains so many entries
pub fn subtract(&mut self, other: &Self) -> bool {
if other.map.len() == self.map.len() {
for (s, o) in self.map.iter_mut().zip(other.map.iter()) {
s.key_sum ^= o.key_sum;
s.check_hash_sum ^= o.check_hash_sum;
s.count = (i32::from_le(s.count) - i32::from_le(o.count)).to_le();
}
return true;
pub fn subtract(&mut self, other: &Self) {
for i in 0..BUCKETS {
self.key[i] ^= other.key[i];
}
for i in 0..BUCKETS {
self.check_hash[i] ^= other.check_hash[i];
}
for i in 0..BUCKETS {
self.count[i] = self.count[i].wrapping_sub(other.count[i]);
}
return false;
}
/// List as many entries in this IBLT as can be extracted.
pub fn list<F: FnMut(&[u8; 8])>(mut self, mut f: F) {
let mut queue: Vec<u32> = Vec::with_capacity(self.map.len());
pub fn list<F: FnMut(u64)>(mut self, mut f: F) {
let mut queue: Vec<u32> = Vec::with_capacity(BUCKETS);
for bi in 0..self.map.len() {
if unsafe { self.map.get_unchecked(bi).is_singular() } {
queue.push(bi as u32);
for i in 0..BUCKETS {
if self.is_singular(i) {
queue.push(i as u32);
}
}
loop {
let b = queue.pop();
let b = if b.is_some() {
unsafe { self.map.get_unchecked_mut(b.unwrap() as usize) }
let i = queue.pop();
let i = if i.is_some() {
i.unwrap() as usize
} else {
break;
};
if b.is_singular() {
let key = b.key_sum;
if self.is_singular(i) {
let key = self.key[i];
f(key);
let check_hash = xorshift64(key);
let mut iteration_index = u64::from_le(key);
f(&(splitmix64_inverse(key)).to_ne_bytes());
for _ in 0..Self::KEY_MAPPING_ITERATIONS {
iteration_index = next_iteration_index(iteration_index);
let b_idx = iteration_index % (self.map.len() as u64);
let b = unsafe { self.map.get_unchecked_mut(b_idx as usize) };
b.key_sum ^= key;
b.check_hash_sum ^= check_hash;
b.count = (i32::from_le(b.count) - 1).to_le();
if b.is_singular() {
if queue.len() > self.map.len() { // sanity check for invalid IBLT
let i = (iteration_index as usize) % BUCKETS;
self.key[i] ^= key;
self.check_hash[i] ^= check_hash;
self.count[i] = self.count[i].wrapping_sub(1);
if self.is_singular(i) {
if queue.len() > BUCKETS {
// sanity check, should be impossible
return;
}
queue.push(b_idx as u32);
queue.push(i as u32);
}
}
}
@ -234,19 +199,18 @@ mod tests {
let mut count = 64;
const CAPACITY: usize = 4096;
while count <= CAPACITY {
let mut test: IBLT = IBLT::new(CAPACITY);
let mut test = IBLT::<CAPACITY>::new();
expected.clear();
for _ in 0..count {
let x = rn;
rn = splitmix64(rn);
expected.insert(x);
test.insert(&x.to_ne_bytes());
test.insert(x);
}
let mut list_count = 0;
test.list(|x| {
let x = u64::from_ne_bytes(*x);
list_count += 1;
assert!(expected.contains(&x));
});
@ -266,16 +230,16 @@ mod tests {
let mut missing: HashSet<u64> = HashSet::with_capacity(CAPACITY);
while missing_count <= CAPACITY {
missing.clear();
let mut local: IBLT = IBLT::new(CAPACITY);
let mut remote: IBLT = IBLT::new(CAPACITY);
let mut local = IBLT::<CAPACITY>::new();
let mut remote = IBLT::<CAPACITY>::new();
for k in 0..REMOTE_SIZE {
if k >= missing_count {
local.insert(&rn.to_ne_bytes());
local.insert(rn);
} else {
missing.insert(rn);
}
remote.insert(&rn.to_ne_bytes());
remote.insert(rn);
rn = splitmix64(rn);
}
@ -283,7 +247,6 @@ mod tests {
let bytes = local.as_bytes().len();
let mut cnt = 0;
local.list(|k| {
let k = u64::from_ne_bytes(*k);
assert!(missing.contains(&k));
cnt += 1;
});

View file

@ -1,22 +1,40 @@
extern crate core;
use std::collections::BTreeMap;
use std::io::{stdout, Write};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::ops::Bound::Included;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use std::time::{Duration, Instant, SystemTime};
use sha2::digest::Digest;
use sha2::Sha512;
use syncwhole::datastore::{DataStore, LoadResult, StoreResult};
use syncwhole::host::Host;
use syncwhole::ms_since_epoch;
use syncwhole::node::{Node, RemoteNodeInfo};
use syncwhole::utils::*;
const TEST_NODE_COUNT: usize = 16;
const TEST_NODE_COUNT: usize = 8;
const TEST_PORT_RANGE_START: u16 = 21384;
const TEST_STARTING_RECORDS_PER_NODE: usize = 16;
static mut RANDOM_CTR: u128 = 0;
fn get_random_bytes(mut buf: &mut [u8]) {
// This is only for testing and is not really secure.
let mut ctr = unsafe { RANDOM_CTR };
if ctr == 0 {
ctr = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos() * (1 + Instant::now().elapsed().as_nanos());
}
while !buf.is_empty() {
let l = buf.len().min(64);
ctr = ctr.wrapping_add(1);
buf[..l].copy_from_slice(&Sha512::digest(&ctr.to_ne_bytes()).as_slice()[..l]);
buf = &mut buf[l..];
}
unsafe { RANDOM_CTR = ctr };
}
struct TestNodeHost {
name: String,
@ -34,22 +52,16 @@ impl Host for TestNodeHost {
}
fn on_connect(&self, info: &RemoteNodeInfo) {
println!("{:5}: connected to {} ({}, {})", self.name, info.remote_address.to_string(), info.node_name.as_ref().map_or("null", |s| s.as_str()), if info.inbound { "inbound" } else { "outbound" });
//println!("{:5}: connected to {} ({}, {})", self.name, info.remote_address.to_string(), info.node_name.as_ref().map_or("null", |s| s.as_str()), if info.inbound { "inbound" } else { "outbound" });
}
fn on_connection_closed(&self, info: &RemoteNodeInfo, reason: String) {
println!("{:5}: closed connection to {}: {} ({}, {})", self.name, info.remote_address.to_string(), reason, if info.inbound { "inbound" } else { "outbound" }, if info.initialized { "initialized" } else { "not initialized" });
}
fn get_secure_random(&self, mut buf: &mut [u8]) {
fn get_secure_random(&self, buf: &mut [u8]) {
// This is only for testing and is not really secure.
let mut ctr = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos();
while !buf.is_empty() {
let l = buf.len().min(64);
ctr = ctr.wrapping_add(1);
buf[0..l].copy_from_slice(&Self::sha512(&[&ctr.to_ne_bytes()])[0..l]);
buf = &mut buf[l..];
}
get_random_bytes(buf);
}
}
@ -82,13 +94,17 @@ impl DataStore for TestNodeHost {
}
fn count(&self, _: i64, key_range_start: &[u8], key_range_end: &[u8]) -> u64 {
self.db.lock().unwrap().range((Included(key_range_start.try_into().unwrap()), Included(key_range_end.try_into().unwrap()))).count() as u64
let s: [u8; 64] = key_range_start.try_into().unwrap();
let e: [u8; 64] = key_range_end.try_into().unwrap();
self.db.lock().unwrap().range((Included(s), Included(e))).count() as u64
}
fn total_count(&self) -> u64 { self.db.lock().unwrap().len() as u64 }
fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, _: i64, key_range_start: &[u8], key_range_end: &[u8], mut f: F) {
for (k, v) in self.db.lock().unwrap().range((Included(key_range_start.try_into().unwrap()), Included(key_range_end.try_into().unwrap()))) {
let s: [u8; 64] = key_range_start.try_into().unwrap();
let e: [u8; 64] = key_range_end.try_into().unwrap();
for (k, v) in self.db.lock().unwrap().range((Included(s), Included(e))) {
if !f(k, v.as_ref()) {
break;
}
@ -101,6 +117,7 @@ fn main() {
println!("Running syncwhole local self-test network with {} nodes starting at 127.0.0.1:{}", TEST_NODE_COUNT, TEST_PORT_RANGE_START);
println!();
println!("Starting nodes on 127.0.0.1...");
let mut nodes: Vec<Node<TestNodeHost, TestNodeHost>> = Vec::with_capacity(TEST_NODE_COUNT);
for port in TEST_PORT_RANGE_START..(TEST_PORT_RANGE_START + (TEST_NODE_COUNT as u16)) {
let mut peers: Vec<SocketAddr> = Vec::with_capacity(TEST_NODE_COUNT);
@ -114,20 +131,43 @@ fn main() {
peers,
db: Mutex::new(BTreeMap::new())
});
println!("Starting node on 127.0.0.1:{} with {} records in data store...", port, nh.db.lock().unwrap().len());
//println!("Starting node on 127.0.0.1:{}...", port, nh.db.lock().unwrap().len());
nodes.push(Node::new(nh.clone(), nh.clone(), SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))).await.unwrap());
}
println!();
print!("Waiting for all connections to be established...");
let _ = stdout().flush();
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
/*
let mut count = 0;
for n in nodes.iter() {
count += n.connection_count().await;
}
println!("{}", count);
*/
if count == (TEST_NODE_COUNT * (TEST_NODE_COUNT - 1)) {
println!(" {} connections up.", count);
break;
} else {
print!(".");
let _ = stdout().flush();
}
}
println!("Populating maps with data to be synchronized between nodes...");
let mut all_records = BTreeMap::new();
for n in nodes.iter_mut() {
for _ in 0..TEST_STARTING_RECORDS_PER_NODE {
let mut k = [0_u8; 64];
let mut v = [0_u8; 32];
get_random_bytes(&mut k);
get_random_bytes(&mut v);
let v: Arc<[u8]> = Arc::from(v);
all_records.insert(k.clone(), v.clone());
n.datastore().db.lock().unwrap().insert(k, v);
}
}
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
}

View file

@ -11,13 +11,13 @@ use std::io::IoSlice;
use std::mem::MaybeUninit;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::ops::Add;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio::time::{Duration, Instant};
@ -30,25 +30,19 @@ use crate::utils::*;
use crate::varint;
/// Inactivity timeout for connections in milliseconds.
const CONNECTION_TIMEOUT: i64 = 120000;
/// How often to send STATUS messages in milliseconds.
const STATUS_INTERVAL: i64 = 10000;
const CONNECTION_TIMEOUT: i64 = SYNC_STATUS_PERIOD * 4;
/// How often to run the housekeeping task's loop in milliseconds.
const HOUSEKEEPING_INTERVAL: i64 = STATUS_INTERVAL / 2;
/// Size of read buffer, which is used to reduce the number of syscalls.
const READ_BUFFER_SIZE: usize = 16384;
const HOUSEKEEPING_INTERVAL: i64 = SYNC_STATUS_PERIOD;
/// Information about a remote node to which we are connected.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RemoteNodeInfo {
/// Optional name advertised by remote node (arbitrary).
pub node_name: Option<String>,
pub name: String,
/// Optional contact information advertised by remote node (arbitrary).
pub node_contact: Option<String>,
pub contact: String,
/// Actual remote endpoint address.
pub remote_address: SocketAddr,
@ -78,6 +72,10 @@ fn configure_tcp_socket(socket: &TcpSocket) -> std::io::Result<()> {
}
}
fn decode_msgpack<'a, T: Deserialize<'a>>(b: &'a [u8]) -> std::io::Result<T> {
rmp_serde::from_slice(b).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("invalid msgpack object: {}", e.to_string())))
}
/// An instance of the syncwhole data set synchronization engine.
///
/// This holds a number of async tasks that are terminated or aborted if this object
@ -85,7 +83,7 @@ fn configure_tcp_socket(socket: &TcpSocket) -> std::io::Result<()> {
pub struct Node<D: DataStore + 'static, H: Host + 'static> {
internal: Arc<NodeInternal<D, H>>,
housekeeping_task: JoinHandle<()>,
listener_task: JoinHandle<()>
listener_task: JoinHandle<()>,
}
impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
@ -105,18 +103,21 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
host: host.clone(),
connections: Mutex::new(HashMap::with_capacity(64)),
bind_address,
starting_instant: Instant::now(),
});
Ok(Self {
internal: internal.clone(),
housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()),
listener_task: tokio::spawn(internal.listener_task_main(listener)),
})
Ok(Self { internal: internal.clone(), housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()), listener_task: tokio::spawn(internal.listener_task_main(listener)) })
}
pub fn datastore(&self) -> &Arc<D> { &self.internal.datastore }
#[inline(always)]
pub fn datastore(&self) -> &Arc<D> {
&self.internal.datastore
}
pub fn host(&self) -> &Arc<H> { &self.internal.host }
#[inline(always)]
pub fn host(&self) -> &Arc<H> {
&self.internal.host
}
pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
self.internal.clone().connect(endpoint, Instant::now().add(Duration::from_millis(CONNECTION_TIMEOUT as u64))).await
@ -126,7 +127,7 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
let connections = self.internal.connections.lock().await;
let mut cl: Vec<RemoteNodeInfo> = Vec::with_capacity(connections.len());
for (_, c) in connections.iter() {
cl.push(c.0.info.blocking_lock().clone());
cl.push(c.info.lock().await.clone());
}
cl
}
@ -152,16 +153,22 @@ pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
host: Arc<H>,
// Connections and their task join handles, by remote endpoint address.
connections: Mutex<HashMap<SocketAddr, (Arc<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
connections: Mutex<HashMap<SocketAddr, Arc<Connection>>>,
// Local address to which this node is bound
bind_address: SocketAddr,
// Instant this node started.
starting_instant: Instant,
}
impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
fn ms_monotonic(&self) -> i64 {
Instant::now().duration_since(self.starting_instant).as_millis() as i64
}
/// Loop that constantly runs in the background to do cleanup and service things.
async fn housekeeping_task_main(self: Arc<Self>) {
let mut last_status_sent = ms_monotonic();
let mut tasks: Vec<JoinHandle<()>> = Vec::new();
let mut connected_to_addresses: HashSet<SocketAddr> = HashSet::new();
let mut sleep_until = Instant::now().add(Duration::from_millis(500));
@ -171,50 +178,17 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
tasks.clear();
connected_to_addresses.clear();
let now = ms_monotonic();
let now = self.ms_monotonic();
// Check connection timeouts, send status updates, and garbage collect from the connections map.
// Status message outputs are backgrounded since these can block if TCP buffers are nearly full.
// A timeout based on the service loop interval is used. Usually these sends will finish instantly
// but if they take too long this typically means the link is dead. We wait for all tasks at the
// end of the service loop. The on_connection_closed() method in 'host' is called in sub-tasks to
// prevent the possibility of deadlocks on self.connections.lock() if the Host implementation calls
// something that tries to lock it.
let status = if (now - last_status_sent) >= STATUS_INTERVAL {
last_status_sent = now;
Some(msg::Status {
total_record_count: self.datastore.total_count(),
reference_time: self.datastore.clock()
})
} else {
None
};
self.connections.lock().await.retain(|sa, c| {
let cc = &(c.0);
if !cc.closed.load(Ordering::Relaxed) {
if (now - cc.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT {
if !c.closed.load(Ordering::Relaxed) {
if (now - c.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT {
connected_to_addresses.insert(sa.clone());
let _ = status.as_ref().map(|status| {
let status = status.clone();
let self2 = self.clone();
let cc = cc.clone();
let sa = sa.clone();
tasks.push(tokio::spawn(async move {
if cc.info.lock().await.initialized {
// This almost always completes instantly due to queues, but add a timeout in case connection
// is stalled. In this case the result is a closed connection.
if !tokio::time::timeout_at(sleep_until, cc.send_obj(MESSAGE_TYPE_STATUS, &status, now)).await.map_or(false, |r| r.is_ok()) {
let _ = self2.connections.lock().await.remove(&sa).map(|c| c.1.map(|j| j.abort()));
self2.host.on_connection_closed(&*cc.info.lock().await, "write overflow (timeout)".to_string());
}
}
}));
});
true // keep connection
} else {
let _ = c.1.take().map(|j| j.abort());
let _ = c.read_task.lock().unwrap().take().map(|j| j.abort());
let host = self.host.clone();
let cc = cc.clone();
let cc = c.clone();
tasks.push(tokio::spawn(async move {
host.on_connection_closed(&*cc.info.lock().await, "timeout".to_string());
}));
@ -222,8 +196,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
}
} else {
let host = self.host.clone();
let cc = cc.clone();
let j = c.1.take();
let cc = c.clone();
let j = c.read_task.lock().unwrap().take();
tasks.push(tokio::spawn(async move {
if j.is_some() {
let e = j.unwrap().await;
@ -241,9 +215,10 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
}
});
// Always try to connect to fixed peers.
let fixed_peers = self.host.fixed_peers();
for sa in fixed_peers.iter() {
let config = self.host.node_config();
// Always try to connect to anchor peers.
for sa in config.anchors.iter() {
if !connected_to_addresses.contains(sa) {
let sa = sa.clone();
let self2 = self.clone();
@ -255,18 +230,18 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
}
// Try to connect to more peers until desired connection count is reached.
let desired_connection_count = self.host.desired_connection_count().min(self.host.max_connection_count());
while connected_to_addresses.len() < desired_connection_count {
let sa = self.host.another_peer(&connected_to_addresses);
if sa.is_some() {
let sa = sa.unwrap();
let desired_connection_count = config.desired_connection_count.min(config.max_connection_count);
for sa in config.seeds.iter() {
if connected_to_addresses.len() >= desired_connection_count {
break;
}
if !connected_to_addresses.contains(sa) {
connected_to_addresses.insert(sa.clone());
let self2 = self.clone();
let sa = sa.clone();
tasks.push(tokio::spawn(async move {
let _ = self2.connect(&sa, sleep_until).await;
}));
connected_to_addresses.insert(sa.clone());
} else {
break;
}
}
@ -289,7 +264,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
if socket.is_ok() {
let (stream, address) = socket.unwrap();
if self.host.allow(&address) {
if self.connections.lock().await.len() < self.host.max_connection_count() || self.host.fixed_peers().contains(&address) {
let config = self.host.node_config();
if self.connections.lock().await.len() < config.max_connection_count || config.anchors.contains(&address) {
Self::connection_start(&self, address, stream, true).await;
}
}
@ -316,33 +292,27 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
let mut ok = false;
let _ = self.connections.lock().await.entry(address.clone()).or_insert_with(|| {
ok = true;
let _ = stream.set_nodelay(true);
//let _ = stream.set_nodelay(true);
let (reader, writer) = stream.into_split();
let now = ms_monotonic();
let now = self.ms_monotonic();
let connection = Arc::new(Connection {
writer: Mutex::new(writer),
last_send_time: AtomicI64::new(now),
last_receive_time: AtomicI64::new(now),
bytes_sent: AtomicU64::new(0),
bytes_received: AtomicU64::new(0),
info: Mutex::new(RemoteNodeInfo {
node_name: None,
node_contact: None,
remote_address: address.clone(),
explicit_addresses: Vec::new(),
connect_time: ms_since_epoch(),
connect_instant: ms_monotonic(),
inbound,
initialized: false
}),
closed: AtomicBool::new(false)
info: Mutex::new(RemoteNodeInfo { name: String::new(), contact: String::new(), remote_address: address.clone(), explicit_addresses: Vec::new(), connect_time: ms_since_epoch(), connect_instant: now, inbound, initialized: false }),
read_task: std::sync::Mutex::new(None),
closed: AtomicBool::new(false),
});
let self2 = self.clone();
(connection.clone(), Some(tokio::spawn(async move {
let result = self2.connection_io_task_main(&connection, BufReader::with_capacity(READ_BUFFER_SIZE, reader)).await;
connection.closed.store(true, Ordering::Relaxed);
let c2 = connection.clone();
connection.read_task.lock().unwrap().replace(tokio::spawn(async move {
let result = self2.connection_io_task_main(&c2, reader).await;
c2.closed.store(true, Ordering::Relaxed);
result
})))
}));
connection
});
ok
}
@ -351,53 +321,96 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
///
/// This handles reading from the connection and reacting to what it sends. Killing this
/// task is done when the connection is closed.
async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: BufReader<OwnedReadHalf>) -> std::io::Result<()> {
async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: OwnedReadHalf) -> std::io::Result<()> {
const BUF_CHUNK_SIZE: usize = 4096;
const READ_BUF_INITIAL_SIZE: usize = 65536; // should be a multiple of BUF_CHUNK_SIZE
let background_tasks = AsyncTaskReaper::new();
let mut write_buffer: Vec<u8> = Vec::with_capacity(BUF_CHUNK_SIZE);
let mut read_buffer: Vec<u8> = Vec::new();
read_buffer.resize(READ_BUF_INITIAL_SIZE, 0);
let config = self.host.node_config();
let mut anti_loopback_challenge_sent = [0_u8; 64];
let mut domain_challenge_sent = [0_u8; 64];
let mut auth_challenge_sent = [0_u8; 64];
self.host.get_secure_random(&mut anti_loopback_challenge_sent);
self.host.get_secure_random(&mut domain_challenge_sent);
self.host.get_secure_random(&mut auth_challenge_sent);
connection.send_obj(MESSAGE_TYPE_INIT, &msg::Init {
anti_loopback_challenge: &anti_loopback_challenge_sent,
domain_challenge: &domain_challenge_sent,
auth_challenge: &auth_challenge_sent,
node_name: self.host.name().map(|n| n.to_string()),
node_contact: self.host.contact().map(|c| c.to_string()),
locally_bound_port: self.bind_address.port(),
explicit_ipv4: None,
explicit_ipv6: None
}, ms_monotonic()).await?;
connection
.send_obj(
&mut write_buffer,
MessageType::Init,
&msg::Init {
anti_loopback_challenge: &anti_loopback_challenge_sent,
domain_challenge: &domain_challenge_sent,
auth_challenge: &auth_challenge_sent,
node_name: config.name.as_str(),
node_contact: config.contact.as_str(),
locally_bound_port: self.bind_address.port(),
explicit_ipv4: None,
explicit_ipv6: None,
},
self.ms_monotonic(),
)
.await?;
drop(config);
let max_message_size = ((D::MAX_VALUE_SIZE * 8) + (D::KEY_SIZE * 1024) + 65536) as u64; // sanity limit
let mut initialized = false;
let background_tasks = AsyncTaskReaper::new();
let mut init_received = false;
let mut buf: Vec<u8> = Vec::new();
buf.resize(4096, 0);
let mut buffer_fill = 0_usize;
loop {
let message_type = reader.read_u8().await?;
let message_size = varint::read_async(&mut reader).await?;
let header_size = 1 + message_size.1;
let message_size = message_size.0;
if message_size > (D::MAX_VALUE_SIZE + ((D::KEY_SIZE + 10) * 256) + 65536) as u64 {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
}
let message_type: MessageType;
let message_size: usize;
let header_size: usize;
let total_size: usize;
loop {
buffer_fill += reader.read(&mut read_buffer.as_mut_slice()[buffer_fill..]).await?;
if buffer_fill >= 2 {
// type and at least one byte of varint
let ms = varint::decode(&read_buffer.as_slice()[1..]);
if ms.1 > 0 {
// varint is all there and parsed correctly
if ms.0 > max_message_size {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
}
let now = ms_monotonic();
message_type = MessageType::from(*read_buffer.get(0).unwrap());
message_size = ms.0 as usize;
header_size = 1 + ms.1;
total_size = header_size + message_size;
if read_buffer.len() < total_size {
read_buffer.resize(((total_size / BUF_CHUNK_SIZE) + 1) * BUF_CHUNK_SIZE, 0);
}
while buffer_fill < total_size {
buffer_fill += reader.read(&mut read_buffer.as_mut_slice()[buffer_fill..]).await?;
}
break;
}
}
}
let message = &read_buffer.as_slice()[header_size..total_size];
let now = self.ms_monotonic();
connection.last_receive_time.store(now, Ordering::Relaxed);
match message_type {
MessageType::Nop => {}
MESSAGE_TYPE_INIT => {
MessageType::Init => {
if init_received {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init"));
}
init_received = true;
let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
let msg: msg::Init = decode_msgpack(message)?;
let (anti_loopback_response, domain_challenge_response, auth_challenge_response) = {
let mut info = connection.info.lock().await;
info.node_name = msg.node_name.clone();
info.node_contact = msg.node_contact.clone();
info.name = msg.node_name.to_string();
info.contact = msg.node_contact.to_string();
let _ = msg.explicit_ipv4.map(|pv4| {
info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port)));
});
@ -409,260 +422,121 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
if auth_challenge_response.is_none() {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "authenticate() returned None, connection dropped"));
}
(
H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge),
H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), msg.domain_challenge),
auth_challenge_response.unwrap()
)
(H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge), H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), msg.domain_challenge), auth_challenge_response.unwrap())
};
connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse {
anti_loopback_response: &anti_loopback_response,
domain_response: &domain_challenge_response,
auth_response: &auth_challenge_response
}, now).await?;
connection.send_obj(&mut write_buffer, MessageType::InitResponse, &msg::InitResponse { anti_loopback_response: &anti_loopback_response, domain_response: &domain_challenge_response, auth_response: &auth_challenge_response }, now).await?;
}
init_received = true;
},
MESSAGE_TYPE_INIT_RESPONSE => {
let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
MessageType::InitResponse => {
let msg: msg::InitResponse = decode_msgpack(message)?;
let mut info = connection.info.lock().await;
if info.initialized {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init response"));
}
info.initialized = true;
let info = info.clone();
if msg.anti_loopback_response.eq(&H::hmac_sha512(&self.anti_loopback_secret, &anti_loopback_challenge_sent)) {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "rejected connection to self"));
}
if msg.domain_response.eq(&H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), &domain_challenge_sent)) {
if !msg.domain_response.eq(&H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), &domain_challenge_sent)) {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "domain mismatch"));
}
if !self.host.authenticate(&info, &auth_challenge_sent).map_or(false, |cr| msg.auth_response.eq(&cr)) {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "challenge/response authentication failed"));
}
self.host.on_connect(&info);
info.initialized = true;
initialized = true;
},
let info = info.clone();
self.host.on_connect(&info);
}
// Handle messages other than INIT and INIT_RESPONSE after checking 'initialized' flag.
_ => {
if !initialized {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "init exchange must be completed before other messages are sent"));
}
match message_type {
_ => {}
MESSAGE_TYPE_STATUS => {
let msg: msg::Status = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
self.connection_request_summary(connection, msg.total_record_count, now, msg.reference_time).await?;
},
MessageType::HaveRecords => {
let msg: msg::HaveRecords = decode_msgpack(message)?;
}
MESSAGE_TYPE_GET_SUMMARY => {
//let msg: msg::GetSummary = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
},
MessageType::GetRecords => {
let msg: msg::GetRecords = decode_msgpack(message)?;
}
MESSAGE_TYPE_SUMMARY => {
let mut remaining = message_size as isize;
// Read summary header.
let summary_header_size = varint::read_async(&mut reader).await?;
remaining -= summary_header_size.1 as isize;
let summary_header_size = summary_header_size.0;
if (summary_header_size as i64) > (remaining as i64) {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid summary header"));
}
let summary_header: msg::SummaryHeader = connection.read_obj(&mut reader, &mut buf, summary_header_size as usize).await?;
remaining -= summary_header_size as isize;
// Read and evaluate summary that we were sent.
match summary_header.summary_type {
SUMMARY_TYPE_KEYS => {
self.connection_receive_and_process_remote_hash_list(
connection,
remaining,
&mut reader,
now,
summary_header.reference_time,
&summary_header.prefix[0..summary_header.prefix.len().min((summary_header.prefix_bits / 8) as usize)]).await?;
},
SUMMARY_TYPE_IBLT => {
//let summary = IBLT::new_from_reader(&mut reader, remaining as usize).await?;
},
_ => {} // ignore unknown summary types
}
// Request another summary if needed, keeping a ping-pong game going in a tight loop until synced.
self.connection_request_summary(connection, summary_header.total_record_count, now, summary_header.reference_time).await?;
},
MESSAGE_TYPE_HAVE_RECORDS => {
let mut remaining = message_size as isize;
let reference_time = varint::read_async(&mut reader).await?;
remaining -= reference_time.1 as isize;
if remaining <= 0 {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid message"));
}
self.connection_receive_and_process_remote_hash_list(connection, remaining, &mut reader, now, reference_time.0 as i64, &[]).await?
},
MESSAGE_TYPE_GET_RECORDS => {
},
MESSAGE_TYPE_RECORD => {
let value = connection.read_msg(&mut reader, &mut buf, message_size as usize).await?;
if value.len() > D::MAX_VALUE_SIZE {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "value larger than MAX_VALUE_SIZE"));
}
let key = H::sha512(&[value]);
match self.datastore.store(&key, value) {
StoreResult::Ok(reference_time) => {
let mut have_records_msg = [0_u8; 2 + 10 + ANNOUNCE_HASH_BYTES];
let mut msg_len = varint::encode(&mut have_records_msg, reference_time as u64);
have_records_msg[msg_len] = ANNOUNCE_HASH_BYTES as u8;
msg_len += 1;
have_records_msg[msg_len..(msg_len + ANNOUNCE_HASH_BYTES)].copy_from_slice(&key[..ANNOUNCE_HASH_BYTES]);
msg_len += ANNOUNCE_HASH_BYTES;
let self2 = self.clone();
let connection2 = connection.clone();
background_tasks.spawn(async move {
let connections = self2.connections.lock().await;
let mut recipients = Vec::with_capacity(connections.len());
for (_, c) in connections.iter() {
if !Arc::ptr_eq(&(c.0), &connection2) {
recipients.push(c.0.clone());
}
MessageType::Record => {
let key = H::sha512(&[message]);
match self.datastore.store(&key, message) {
StoreResult::Ok => {
// TODO: probably should not announce if way out of sync
let connections = self.connections.lock().await;
let mut announce_to: Vec<Arc<Connection>> = Vec::with_capacity(connections.len());
for (_, c) in connections.iter() {
if !Arc::ptr_eq(&connection, c) {
announce_to.push(c.clone());
}
drop(connections); // release lock
}
drop(connections); // release lock
for c in recipients.iter() {
// This typically completes instantly due to buffering, as this message is small.
// Add a small timeout in the case that some connections are stalled. Misses will
// not impact the overall network much.
let _ = tokio::time::timeout(Duration::from_millis(250), c.send_msg(MESSAGE_TYPE_HAVE_RECORDS, &have_records_msg[..msg_len], now)).await;
background_tasks.spawn(async move {
for c in announce_to.iter() {
let _ = c.send_msg(MessageType::HaveRecord, &key[0..ANNOUNCE_KEY_LEN], now).await;
}
});
},
}
StoreResult::Rejected => {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid datum received"));
},
_ => {} // duplicate or ignored values are just... ignored
}
},
_ => {
// Skip messages that aren't recognized or don't need to be parsed.
let mut remaining = message_size as usize;
while remaining > 0 {
let s = remaining.min(buf.len());
reader.read_exact(&mut buf.as_mut_slice()[0..s]).await?;
remaining -= s;
return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("record rejected by data store: {}", to_hex_string(&key))));
}
_ => {}
}
}
MessageType::SyncStatus => {
let msg: msg::SyncStatus = decode_msgpack(message)?;
}
MessageType::SyncRequest => {
let msg: msg::SyncRequest = decode_msgpack(message)?;
}
MessageType::SyncResponse => {
let msg: msg::SyncResponse = decode_msgpack(message)?;
}
}
}
}
connection.bytes_received.fetch_add((header_size as u64) + message_size, Ordering::Relaxed);
read_buffer.copy_within(total_size..buffer_fill, 0);
buffer_fill -= total_size;
connection.bytes_received.fetch_add(total_size as u64, Ordering::Relaxed);
}
}
/// Request a summary if needed, or do nothing if not.
///
/// This is where all the logic lives that determines whether to request summaries, choosing a
/// prefix, etc. It's called when the remote node tells us its total record count with an
/// associated reference time, which happens in status announcements and in summaries.
async fn connection_request_summary(&self, connection: &Arc<Connection>, total_record_count: u64, now: i64, reference_time: i64) -> std::io::Result<()> {
let my_total_record_count = self.datastore.total_count();
if my_total_record_count < total_record_count {
// Figure out how many bits need to be in a randomly chosen prefix to choose a slice of
// the data set such that the set difference should be around 4096 records. This assumes
// random distribution, which should be mostly maintained by probing prefixes at random.
let prefix_bits = ((total_record_count - my_total_record_count) as f64) / 4096.0;
let prefix_bits = if prefix_bits > 1.0 {
(prefix_bits.log2().ceil() as usize).min(64)
} else {
0 as usize
};
let prefix_bytes = (prefix_bits / 8) + (((prefix_bits % 8) != 0) as usize);
// Generate a random prefix of this many bits (to the nearest byte).
let mut prefix = [0_u8; 64];
self.host.get_secure_random(&mut prefix[..prefix_bytes]);
// Request a set summary for this prefix, providing our own count for this prefix so
// the remote can decide whether to send something like an IBLT or just hashes.
let (local_range_start, local_range_end) = range_from_prefix(&prefix, prefix_bits);
connection.send_obj(MESSAGE_TYPE_GET_SUMMARY, &msg::GetSummary {
reference_time,
prefix: &prefix[..prefix_bytes],
prefix_bits: prefix_bits as u8,
record_count: self.datastore.count(reference_time, &local_range_start, &local_range_end)
}, now).await
} else {
Ok(())
}
}
/// Read a stream of record hashes (or hash prefixes) from a connection and request records we don't have.
async fn connection_receive_and_process_remote_hash_list(&self, connection: &Arc<Connection>, mut remaining: isize, reader: &mut BufReader<OwnedReadHalf>, now: i64, reference_time: i64, common_prefix: &[u8]) -> std::io::Result<()> {
if remaining > 0 {
// Hash list is prefaced by the number of bytes in each hash, since whole 64 byte hashes do not have to be sent.
let prefix_entry_size = reader.read_u8().await? as usize;
let total_prefix_size = common_prefix.len() + prefix_entry_size;
if prefix_entry_size > 0 && total_prefix_size <= 64 {
remaining -= 1;
if remaining >= (prefix_entry_size as isize) {
let mut get_records_msg: Vec<u8> = Vec::with_capacity(((remaining as usize) / prefix_entry_size) * total_prefix_size);
varint::write(&mut get_records_msg, reference_time as u64)?;
get_records_msg.push(total_prefix_size as u8);
let mut key_prefix_buf = [0_u8; 64];
key_prefix_buf[..common_prefix.len()].copy_from_slice(common_prefix);
while remaining >= (prefix_entry_size as isize) {
remaining -= prefix_entry_size as isize;
reader.read_exact(&mut key_prefix_buf[common_prefix.len()..total_prefix_size]).await?;
if if total_prefix_size < 64 {
let (s, e) = range_from_prefix(&key_prefix_buf[..total_prefix_size], total_prefix_size * 8);
self.datastore.count(reference_time, &s, &e) == 0
} else {
!self.datastore.contains(reference_time, &key_prefix_buf)
} {
let _ = get_records_msg.write_all(&key_prefix_buf[..total_prefix_size]);
}
}
if remaining == 0 {
return connection.send_msg(MESSAGE_TYPE_GET_RECORDS, get_records_msg.as_slice(), now).await;
}
}
}
}
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hash list"));
}
}
impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
fn drop(&mut self) {
let _ = tokio::runtime::Handle::try_current().map_or_else(|_| {
for (_, c) in self.connections.blocking_lock().drain() {
c.1.map(|c| c.abort());
}
}, |h| {
let _ = h.block_on(async {
for (_, c) in self.connections.lock().await.drain() {
c.1.map(|c| c.abort());
let _ = tokio::runtime::Handle::try_current().map_or_else(
|_| {
for (_, c) in self.connections.blocking_lock().drain() {
c.read_task.lock().unwrap().as_mut().map(|c| c.abort());
}
});
});
},
|h| {
let _ = h.block_on(async {
for (_, c) in self.connections.lock().await.drain() {
c.read_task.lock().unwrap().as_mut().map(|c| c.abort());
}
});
},
);
}
}
@ -673,53 +547,30 @@ struct Connection {
bytes_sent: AtomicU64,
bytes_received: AtomicU64,
info: Mutex<RemoteNodeInfo>,
read_task: std::sync::Mutex<Option<JoinHandle<std::io::Result<()>>>>,
closed: AtomicBool,
}
impl Connection {
async fn send_msg(&self, message_type: u8, data: &[u8], now: i64) -> std::io::Result<()> {
let mut type_and_size = [0_u8; 16];
type_and_size[0] = message_type;
let tslen = 1 + varint::encode(&mut type_and_size[1..], data.len() as u64) as usize;
let total_size = tslen + data.len();
if self.writer.lock().await.write_vectored(&[IoSlice::new(&type_and_size[..tslen]), IoSlice::new(data)]).await? == total_size {
async fn send_msg(&self, message_type: MessageType, data: &[u8], now: i64) -> std::io::Result<()> {
let mut header: [u8; 16] = unsafe { MaybeUninit::uninit().assume_init() };
header[0] = message_type as u8;
let header_size = 1 + varint::encode(&mut header[1..], data.len() as u64);
if self.writer.lock().await.write_vectored(&[IoSlice::new(&header[0..header_size]), IoSlice::new(data)]).await? == (data.len() + header_size) {
self.last_send_time.store(now, Ordering::Relaxed);
self.bytes_sent.fetch_add(total_size as u64, Ordering::Relaxed);
self.bytes_sent.fetch_add((header_size + data.len()) as u64, Ordering::Relaxed);
Ok(())
} else {
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "write error"))
}
}
async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> {
let data = rmp_serde::encode::to_vec_named(&obj);
if data.is_ok() {
let data = data.unwrap();
let mut header: [u8; 16] = unsafe { MaybeUninit::uninit().assume_init() };
header[0] = message_type;
let header_size = 1 + varint::encode(&mut header[1..], data.len() as u64);
if self.writer.lock().await.write_vectored(&[IoSlice::new(&header[0..header_size]), IoSlice::new(data.as_slice())]).await? == (data.len() + header_size) {
self.last_send_time.store(now, Ordering::Relaxed);
self.bytes_sent.fetch_add((header_size + data.len()) as u64, Ordering::Relaxed);
Ok(())
} else {
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "write error"))
}
async fn send_obj<O: Serialize>(&self, write_buf: &mut Vec<u8>, message_type: MessageType, obj: &O, now: i64) -> std::io::Result<()> {
write_buf.clear();
if rmp_serde::encode::write_named(write_buf, obj).is_ok() {
self.send_msg(message_type, write_buf.as_slice(), now).await
} else {
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure (internal error)"))
}
}
async fn read_msg<'a, R: AsyncReadExt + Unpin>(&self, reader: &mut R, buf: &'a mut Vec<u8>, message_size: usize) -> std::io::Result<&'a [u8]> {
if message_size > buf.len() {
buf.resize(((message_size / 4096) + 1) * 4096, 0);
}
let b = &mut buf.as_mut_slice()[0..message_size];
reader.read_exact(b).await?;
Ok(b)
}
async fn read_obj<'a, R: AsyncReadExt + Unpin, O: Deserialize<'a>>(&self, reader: &mut R, buf: &'a mut Vec<u8>, message_size: usize) -> std::io::Result<O> {
rmp_serde::from_read_ref(self.read_msg(reader, buf, message_size).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("invalid msgpack: {}", e.to_string())))
}
}

View file

@ -6,55 +6,108 @@
* https://www.zerotier.com/
*/
/// No operation, payload ignored.
pub const MESSAGE_TYPE_NOP: u8 = 0;
/// Number of bytes of SHA512 to announce, should be high enough to make collisions virtually impossible.
pub const ANNOUNCE_KEY_LEN: usize = 24;
/// Sent by both sides of a TCP link when it's established.
pub const MESSAGE_TYPE_INIT: u8 = 1;
/// Send SyncStatus this frequently, in milliseconds.
pub const SYNC_STATUS_PERIOD: i64 = 5000;
/// Reply sent to INIT.
pub const MESSAGE_TYPE_INIT_RESPONSE: u8 = 2;
/// Message type prefix byte sent on the wire before each message.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum MessageType {
    Nop = 0_u8,
    Init = 1_u8,
    InitResponse = 2_u8,
    HaveRecord = 3_u8,
    HaveRecords = 4_u8,
    GetRecords = 5_u8,
    Record = 6_u8,
    SyncStatus = 7_u8,
    SyncRequest = 8_u8,
    SyncResponse = 9_u8,
}

impl From<u8> for MessageType {
    /// Get a type from a byte, returning the Nop type if the byte is out of range.
    #[inline(always)]
    fn from(b: u8) -> Self {
        // An explicit match avoids `unsafe` and cannot silently shrink the
        // valid range: the previous `transmute` guarded by `b <= 7` wrongly
        // mapped SyncRequest (8) and SyncResponse (9) to Nop, causing those
        // messages to be ignored.
        match b {
            1 => Self::Init,
            2 => Self::InitResponse,
            3 => Self::HaveRecord,
            4 => Self::HaveRecords,
            5 => Self::GetRecords,
            6 => Self::Record,
            7 => Self::SyncStatus,
            8 => Self::SyncRequest,
            9 => Self::SyncResponse,
            _ => Self::Nop,
        }
    }
}
/// Get a set summary of a prefix in the data set.
pub const MESSAGE_TYPE_GET_SUMMARY: u8 = 4;
impl MessageType {
    /// Human-readable name of this message type, for logging and diagnostics.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Nop => "NOP",
            Self::Init => "INIT",
            Self::InitResponse => "INIT_RESPONSE",
            Self::HaveRecord => "HAVE_RECORD",
            Self::HaveRecords => "HAVE_RECORDS",
            Self::GetRecords => "GET_RECORDS",
            Self::Record => "RECORD",
            Self::SyncStatus => "SYNC_STATUS",
            Self::SyncRequest => "SYNC_REQUEST",
            Self::SyncResponse => "SYNC_RESPONSE",
        }
    }
}
/// Set summary of a prefix.
pub const MESSAGE_TYPE_SUMMARY: u8 = 5;
/// Format of the payload carried by a sync response, as indicated by its
/// 'response_type' field.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum SyncResponseType {
    /// No response, do nothing.
    None = 0_u8,
    /// Response is a msgpack-encoded HaveRecords message.
    HaveRecords = 1_u8,
    /// Response is a series of records prefixed by varint record sizes.
    Records = 2_u8,
    /// Response is an IBLT set summary.
    IBLT = 3_u8,
}

impl From<u8> for SyncResponseType {
    /// Get response type from a byte, returning None if the byte is out of range.
    #[inline(always)]
    fn from(b: u8) -> Self {
        // Explicit match instead of an unsafe transmute: behavior is identical
        // (bytes 0..=3 map to the corresponding variant, anything else to None)
        // but correctness no longer depends on keeping a magic upper bound in
        // sync with the number of variants.
        match b {
            1 => Self::HaveRecords,
            2 => Self::Records,
            3 => Self::IBLT,
            _ => Self::None,
        }
    }
}
/// An IBLT set summary.
pub const SUMMARY_TYPE_IBLT: u8 = 1;
/// Number of bytes of each SHA512 hash to announce, request, etc. This is okay to change but 16 is plenty.
pub const ANNOUNCE_HASH_BYTES: usize = 16;
impl SyncResponseType {
pub fn as_str(&self) -> &'static str {
match *self {
SyncResponseType::None => "NONE",
SyncResponseType::HaveRecords => "HAVE_RECORDS",
SyncResponseType::Records => "RECORDS",
SyncResponseType::IBLT => "IBLT",
}
}
}
pub mod msg {
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct IPv4 {
pub ip: [u8; 4],
pub port: u16
pub port: u16,
}
#[derive(Serialize, Deserialize)]
pub struct IPv6 {
pub ip: [u8; 16],
pub port: u16
pub port: u16,
}
#[derive(Serialize, Deserialize)]
@ -72,10 +125,10 @@ pub mod msg {
pub auth_challenge: &'a [u8],
/// Optional name to advertise for this node.
pub node_name: Option<String>,
pub node_name: &'a str,
/// Optional contact information for this node, such as a URL or an e-mail address.
pub node_contact: Option<String>,
pub node_contact: &'a str,
/// Port to which this node has locally bound.
/// This is used to try to auto-detect whether a NAT is in the way.
@ -105,63 +158,95 @@ pub mod msg {
pub auth_response: &'a [u8],
}
#[derive(Serialize, Deserialize, Clone)]
pub struct Status {
/// Total number of records in data set.
#[serde(rename = "c")]
pub total_record_count: u64,
#[derive(Serialize, Deserialize)]
pub struct HaveRecords<'a> {
/// Length of each key, chosen to ensure uniqueness.
#[serde(rename = "l")]
pub key_length: usize,
/// Reference wall clock time in milliseconds since Unix epoch.
#[serde(rename = "t")]
pub reference_time: i64,
/// Keys whose existence is being announced, of 'key_length' length.
#[serde(with = "serde_bytes")]
#[serde(rename = "k")]
pub keys: &'a [u8],
}
#[derive(Serialize, Deserialize, Clone)]
pub struct GetSummary<'a> {
/// Reference wall clock time in milliseconds since Unix epoch.
#[serde(rename = "t")]
pub reference_time: i64,
#[derive(Serialize, Deserialize)]
pub struct GetRecords<'a> {
/// Length of each key, chosen to ensure uniqueness.
#[serde(rename = "l")]
pub key_length: usize,
/// Prefix within key space.
#[serde(rename = "p")]
/// Keys to retrieve, of 'key_length' bytes in length.
#[serde(with = "serde_bytes")]
pub prefix: &'a [u8],
#[serde(rename = "k")]
pub keys: &'a [u8],
}
/// Length of prefix in bits (trailing bits in byte array are ignored).
#[serde(rename = "b")]
pub prefix_bits: u8,
/// Number of records in this range the requesting node already has, used to choose summary type.
#[serde(rename = "r")]
/// Periodic announcement of a node's data set size and clock.
/// NOTE(review): purpose inferred from the field contents and the
/// MessageType::SyncStatus handler -- confirm against senders.
#[derive(Serialize, Deserialize)]
pub struct SyncStatus {
    /// Total number of records this node has in its data store.
    #[serde(rename = "c")]
    pub record_count: u64,
    /// Sending node's system clock.
    /// NOTE(review): presumably milliseconds since the Unix epoch, matching
    /// ms_since_epoch() used elsewhere in this crate -- confirm.
    #[serde(rename = "t")]
    pub clock: u64,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct SummaryHeader<'a> {
/// Total number of records in data set, for easy rapid generation of next query.
#[serde(rename = "c")]
pub total_record_count: u64,
/// Reference wall clock time in milliseconds since Unix epoch.
#[serde(rename = "t")]
pub reference_time: i64,
/// Random salt value used by some summary types.
#[serde(rename = "x")]
#[derive(Serialize, Deserialize)]
pub struct SyncRequest<'a> {
/// Query mask, a random string of KEY_SIZE bytes.
#[serde(with = "serde_bytes")]
pub salt: &'a [u8],
#[serde(rename = "q")]
pub query_mask: &'a [u8],
/// Prefix within key space.
#[serde(rename = "p")]
#[serde(with = "serde_bytes")]
pub prefix: &'a [u8],
/// Length of prefix in bits (trailing bits in byte array are ignored).
/// Number of bits to match as a prefix in query_mask (0 for entire data set).
#[serde(rename = "b")]
pub prefix_bits: u8,
pub query_mask_bits: u8,
/// Type of summary that follows this header.
/// Number of records requesting node already holds under query mask prefix.
#[serde(rename = "c")]
pub record_count: u64,
/// Sender's reference time.
#[serde(rename = "t")]
pub reference_time: u64,
/// Random salt
#[serde(rename = "s")]
pub summary_type: u8,
pub salt: u64,
}
/// Response to a sync request. It carries the same query-identifying fields
/// as SyncRequest (query_mask, query_mask_bits, salt) plus a payload whose
/// format is determined by 'response_type'.
#[derive(Serialize, Deserialize)]
pub struct SyncResponse<'a> {
    /// Query mask, a random string of KEY_SIZE bytes.
    #[serde(with = "serde_bytes")]
    #[serde(rename = "q")]
    pub query_mask: &'a [u8],
    /// Number of bits to match as a prefix in query_mask (0 for entire data set).
    #[serde(rename = "b")]
    pub query_mask_bits: u8,
    /// Number of records sender has under this prefix.
    #[serde(rename = "c")]
    pub record_count: u64,
    /// Sender's reference time.
    /// NOTE(review): presumably milliseconds since the Unix epoch -- confirm.
    #[serde(rename = "t")]
    pub reference_time: u64,
    /// Random salt
    #[serde(rename = "s")]
    pub salt: u64,
    /// SyncResponseType determining content of 'data'.
    #[serde(rename = "r")]
    pub response_type: u8,
    /// Data whose meaning depends on the response type.
    #[serde(with = "serde_bytes")]
    #[serde(rename = "d")]
    pub data: &'a [u8],
}
}

View file

@ -8,8 +8,10 @@
use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::SystemTime;
use tokio::task::JoinHandle;
/// Get the real time clock in milliseconds since Unix epoch.
@ -17,11 +19,6 @@ pub fn ms_since_epoch() -> i64 {
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as i64
}
/// Get the current monotonic clock in milliseconds.
///
/// Milliseconds are measured from the first call to this function, which
/// lazily anchors a process-wide reference instant. The previous
/// implementation returned `Instant::now().elapsed()`, which measures the
/// interval since the instant it just created and is therefore always
/// approximately zero, breaking every timestamp comparison built on it.
pub fn ms_monotonic() -> i64 {
    static ANCHOR: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
    ANCHOR.get_or_init(std::time::Instant::now).elapsed().as_millis() as i64
}
/// Encode a byte slice to a hexadecimal string.
pub fn to_hex_string(b: &[u8]) -> String {
const HEX_CHARS: [u8; 16] = [b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'a', b'b', b'c', b'd', b'e', b'f'];
@ -66,6 +63,28 @@ pub fn splitmix64_inverse(mut x: u64) -> u64 {
x.to_le()
}
static RANDOM_STATE_0: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
static RANDOM_STATE_1: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

/// Get a non-cryptographic pseudorandom number.
///
/// State lives in two process-wide atomics, lazily seeded from the system
/// clock and the process ID on first use, so calling this from multiple
/// threads is sound. Concurrent callers may race and produce overlapping
/// sequences, which is acceptable for a non-cryptographic generator; the
/// previous `static mut` version had the same race but as undefined behavior.
pub fn random() -> u64 {
    use std::sync::atomic::Ordering;
    let mut s0 = RANDOM_STATE_0.load(Ordering::Relaxed);
    let mut s1 = RANDOM_STATE_1.load(Ordering::Relaxed);
    if s0 == 0 {
        // First use (or an astronomically unlikely zero state): seed from the wall clock.
        s0 = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos() as u64;
    }
    if s1 == 0 {
        s1 = splitmix64(std::process::id() as u64);
    }
    let s1_new = xorshift64(s1);
    s0 = splitmix64(s0.wrapping_add(s1));
    s1 = s1_new;
    RANDOM_STATE_0.store(s0, Ordering::Relaxed);
    RANDOM_STATE_1.store(s1, Ordering::Relaxed);
    s0
}
/// Wrapper for tokio::spawn() that aborts tasks not yet completed when it is dropped.
pub struct AsyncTaskReaper {
ctr: AtomicUsize,
@ -74,22 +93,23 @@ pub struct AsyncTaskReaper {
impl AsyncTaskReaper {
    /// Create a new, empty task reaper.
    // NOTE(review): the previous text contained two interleaved bodies for new()
    // (merge residue), which did not compile; this keeps the single-line form.
    pub fn new() -> Self {
        Self { ctr: AtomicUsize::new(0), handles: Arc::new(std::sync::Mutex::new(HashMap::new())) }
    }

    /// Spawn a new background task.
    ///
    /// Any task output is ignored. This is for fire-and-forget background
    /// tasks, used similarly to goroutines in Go, that should be collected
    /// (aborted) when this reaper loses scope.
    pub fn spawn<F: Future + Send + 'static>(&self, future: F) {
        // Unique id so the completed task can remove its own handle.
        let id = self.ctr.fetch_add(1, Ordering::Relaxed);
        let handles = self.handles.clone();
        self.handles.lock().unwrap().insert(
            id,
            tokio::spawn(async move {
                let _ = future.await;
                // Task finished on its own: drop its handle so it is not aborted later.
                let _ = handles.lock().unwrap().remove(&id);
            }),
        );
    }
}

View file

@ -6,10 +6,7 @@
* https://www.zerotier.com/
*/
use std::io::{Read, Write};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Maximum possible size in bytes of a varint-encoded u64 (ten 7-bit groups).
// NOTE(review): the const was declared twice (private and pub) by a botched
// merge; the single pub declaration is kept.
pub const VARINT_MAX_SIZE_BYTES: usize = 10;

/// Encode 'v' as a varint into the start of 'b', returning the number of bytes written.
///
/// The format stores seven bits per byte, least significant group first.
/// Bytes with the high bit clear are continuation bytes; the terminating byte
/// has the high bit set (the exact inverse of decode()). 'b' must be at least
/// VARINT_MAX_SIZE_BYTES long or this may panic on large values.
pub fn encode(b: &mut [u8], mut v: u64) -> usize {
    let mut i = 0;
    loop {
        if v > 0x7f {
            b[i] = (v as u8) & 0x7f;
            i += 1;
            v = v.wrapping_shr(7);
        } else {
            b[i] = (v as u8) | 0x80;
            i += 1;
            break;
        }
    }
    i
}
/// Asynchronously write a varint-encoded u64 to an AsyncWrite implementation.
#[inline(always)]
pub async fn write_async<W: AsyncWrite + Unpin>(w: &mut W, v: u64) -> std::io::Result<()> {
    let mut buf = [0_u8; VARINT_MAX_SIZE_BYTES];
    let len = encode(&mut buf, v);
    w.write_all(&buf[..len]).await
}
pub async fn read_async<R: AsyncRead + Unpin>(r: &mut R) -> std::io::Result<(u64, usize)> {
pub fn decode(b: &[u8]) -> (u64, usize) {
let mut v = 0_u64;
let mut pos = 0;
let mut count = 0;
loop {
count += 1;
let b = r.read_u8().await?;
if b <= 0x7f {
v |= (b as u64).wrapping_shl(pos);
let mut l = 0;
let bl = b.len();
while l < bl {
let x = b[l];
l += 1;
if x <= 0x7f {
v |= (x as u64).wrapping_shl(pos);
pos += 7;
} else {
v |= ((b & 0x7f) as u64).wrapping_shl(pos);
return Ok((v, count));
}
}
}
#[inline(always)]
pub fn write<W: Write>(w: &mut W, v: u64) -> std::io::Result<()> {
let mut b = [0_u8; VARINT_MAX_SIZE_BYTES];
let i = encode(&mut b, v);
w.write_all(&b[0..i])
}
pub fn read<R: Read>(r: &mut R) -> std::io::Result<u64> {
let mut v = 0_u64;
let mut buf = [0_u8; 1];
let mut pos = 0;
loop {
let _ = r.read_exact(&mut buf)?;
let b = buf[0];
if b <= 0x7f {
v |= (b as u64).wrapping_shl(pos);
pos += 7;
} else {
v |= ((b & 0x7f) as u64).wrapping_shl(pos);
return Ok(v);
v |= ((x & 0x7f) as u64).wrapping_shl(pos);
return (v, l);
}
}
return (0, 0);
}

View file

@ -39,7 +39,6 @@ pub fn ms_monotonic() -> i64 {
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
pub fn ms_monotonic() -> i64 {
    use std::sync::OnceLock;
    use std::time::Instant;
    // Instant has no absolute value; anchor to a first-call epoch so successive
    // calls yield an advancing millisecond counter. The previous
    // `Instant::now().elapsed()` always returned ~0.
    static EPOCH: OnceLock<Instant> = OnceLock::new();
    EPOCH.get_or_init(Instant::now).elapsed().as_millis() as i64
}
pub fn parse_bool(v: &str) -> Result<bool, String> {