More sync stuff.

Adam Ierymenko 2022-03-02 16:15:34 -05:00
parent ee6fc671e4
commit 44f42ef608
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
7 changed files with 371 additions and 147 deletions

syncwhole/Cargo.lock (generated, 72 lines changed)

@ -14,6 +14,15 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "block-buffer"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324"
dependencies = [
"generic-array",
]
[[package]]
name = "byteorder"
version = "1.4.3"
@ -32,6 +41,45 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cpufeatures"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469"
dependencies = [
"libc",
]
[[package]]
name = "crypto-common"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "digest"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506"
dependencies = [
"block-buffer",
"crypto-common",
]
[[package]]
name = "generic-array"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "libc"
version = "0.2.119"
@ -205,6 +253,17 @@ dependencies = [
"syn",
]
[[package]]
name = "sha2"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55deaec60f81eefe3cce0dc50bda92d6d8e88f2a27df7c5033b42afeb1ed2676"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "smallvec"
version = "1.8.0"
@ -239,6 +298,7 @@ dependencies = [
"rmp",
"rmp-serde",
"serde",
"sha2",
"tokio",
]
@ -258,12 +318,24 @@ dependencies = [
"winapi",
]
[[package]]
name = "typenum"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "unicode-xid"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "winapi"
version = "0.3.9"

syncwhole/Cargo.toml

@ -10,3 +10,6 @@ tokio = { version = "^1", features = ["net", "rt", "parking_lot", "time", "io-st
serde = { version = "^1", features = ["derive"], default-features = false }
rmp = "^0"
rmp-serde = "^1"
[dev-dependencies]
sha2 = "^0"

syncwhole/src/datastore.rs

@ -27,13 +27,13 @@ pub enum StoreResult {
/// Entry was accepted (whether or not an old value was replaced).
Ok,
/// Entry was rejected as a duplicate but was otherwise valid.
/// Entry was a duplicate of one we already have but was otherwise valid.
Duplicate,
/// Entry was rejected as invalid.
///
/// An invalid entry is one that is malformed, fails a signature check, etc., and returning
/// this causes the synchronization service to drop the link to the node that sent it.
/// Entry was valid but was ignored for an unspecified reason.
Ignored,
/// Entry was rejected as malformed or otherwise invalid (e.g. failed signature check).
Rejected
}
@ -45,12 +45,19 @@ pub enum StoreResult {
/// what time I think it is" value to be considered locally so that data can be replicated
/// as of any given time.
///
/// The constants KEY_SIZE, MAX_VALUE_SIZE, and KEY_IS_COMPUTED are protocol constants
/// for your replication domain. They can't be changed once defined unless all nodes
/// are upgraded at once.
///
/// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a function of
/// values. If this is true, get_key() must be implemented.
///
/// The implementation must be thread safe.
/// The implementation must be thread safe and may be called concurrently.
pub trait DataStore: Sync + Send {
/// Type to be enclosed in the Ok() enum value in LoadResult.
///
/// Allowing this type to be defined lets you use any type that makes sense with
/// your implementation. Examples include Box<[u8]>, Arc<[u8]>, Vec<u8>, etc.
type LoadResultValueType: AsRef<[u8]> + Send;
/// Size of keys, which must be fixed in length. These are typically hashes.
@ -61,11 +68,14 @@ pub trait DataStore: Sync + Send {
/// This should be true if the key is computed, such as by hashing the value.
///
/// If this is true only values are sent over the wire and get_key() is used to compute
/// keys from values. If this is false both keys and values are replicated.
/// If this is true then keys do not have to be sent over the wire. Instead they
/// are computed by calling get_key(). If this is false keys are assumed not to
/// be computable from values and are explicitly sent.
const KEY_IS_COMPUTED: bool;
/// Get the key corresponding to a value.
/// Compute the key corresponding to a value.
///
/// The 'key' slice must be KEY_SIZE bytes in length.
///
/// If KEY_IS_COMPUTED is true this must be implemented. The default implementation
/// panics to indicate this. If KEY_IS_COMPUTED is false this is never called.
@ -74,44 +84,56 @@ pub trait DataStore: Sync + Send {
panic!("get_key() must be implemented if KEY_IS_COMPUTED is true");
}
/// Get the domain of this data store, which is just an arbitrary unique identifier.
/// Get the domain of this data store.
///
/// This is an arbitrary unique identifier that must be the same for all nodes that
/// are replicating the same data. It's checked on connect to avoid trying to share
/// data across data sets if this is not desired.
fn domain(&self) -> &str;
/// Get an item if it exists as of a given reference time.
///
/// The supplied key must be of length KEY_SIZE or this may panic.
fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType>;
/// Store an item in the data store and return Ok, Duplicate, or Rejected.
///
/// The supplied key must be of length KEY_SIZE or this may panic.
/// Store an item in the data store and return its status.
///
/// Note that no time is supplied here. The data store must determine this in an implementation
/// dependent manner if this is a temporally subjective data store. It could be determined by
/// the wall clock, from the object itself, etc.
///
/// The implementation is responsible for validating inputs and returning 'Rejected' if they
/// are invalid. A return value of Rejected can be used to do things like drop connections
/// to peers that send invalid data, so it should only be returned if the data is malformed
/// or something like a signature check fails. The 'Ignored' enum value should be returned
/// for inputs that are valid but were not stored for some other reason, such as being
/// expired. It's important to return 'Ok' for accepted values to hint to the replicator
/// that they should be aggressively advertised to other peers.
///
/// If KEY_IS_COMPUTED is true, the key supplied here can be assumed to be correct. It will
/// have been computed via get_key().
fn store(&self, key: &[u8], value: &[u8]) -> StoreResult;
/// Get the number of items under a prefix as of a given reference time.
///
/// The default implementation uses for_each(). This can be specialized if it can be done
/// more efficiently than that.
/// The default implementation uses for_each_key(). This can be specialized if it can
/// be done more efficiently than that.
fn count(&self, reference_time: i64, key_prefix: &[u8]) -> u64 {
let mut cnt: u64 = 0;
self.for_each(reference_time, key_prefix, |_, _| {
self.for_each_key(reference_time, key_prefix, |_| {
cnt += 1;
true
});
cnt
}
/// Iterate through keys beneath a key prefix, stopping at the end or if the function returns false.
/// Iterate through keys beneath a key prefix, stopping if the function returns false.
///
/// The default implementation uses for_each().
/// The default implementation uses for_each(). This can be specialized if it's faster to
/// only load keys.
fn for_each_key<F: FnMut(&[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], mut f: F) {
self.for_each(reference_time, key_prefix, |k, _| f(k));
}
/// Iterate through keys and values beneath a key prefix, stopping at the end or if the function returns false.
/// Iterate through keys and values beneath a key prefix, stopping if the function returns false.
fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], f: F);
}
@ -196,4 +218,3 @@ impl<const KEY_SIZE: usize> PartialEq for MemoryDatabase<KEY_SIZE> {
}
impl<const KEY_SIZE: usize> Eq for MemoryDatabase<KEY_SIZE> {}
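
The revised StoreResult contract above separates Rejected (malformed or failed signature, grounds for dropping the sending peer's link) from Ignored (valid but deliberately not stored) and Duplicate. A minimal sketch of how a store() implementation might map those outcomes, assuming an illustrative HashMap-backed store with a 64-byte key; the ToyStore type, the signature_ok/expired flags, and the locally redeclared enum are all stand-ins, not part of this commit.

use std::collections::HashMap;
use std::sync::Mutex;

// Mirrors the StoreResult variants above; redeclared so the sketch is self-contained.
enum StoreResult { Ok, Duplicate, Ignored, Rejected }

// Illustrative in-memory store (not the MemoryDatabase from this commit),
// assuming a 64-byte key and caller-supplied validity/expiry flags standing in
// for whatever validation a real implementation performs on key and value.
struct ToyStore {
    map: Mutex<HashMap<[u8; 64], Vec<u8>>>,
}

impl ToyStore {
    fn store(&self, key: &[u8], value: &[u8], signature_ok: bool, expired: bool) -> StoreResult {
        // Malformed or unverifiable input: Rejected, which the synchronization
        // service may treat as grounds to drop the link that sent it.
        if key.len() != 64 || !signature_ok {
            return StoreResult::Rejected;
        }
        // Valid but not stored (e.g. expired): Ignored, never Rejected.
        if expired {
            return StoreResult::Ignored;
        }
        let mut map = self.map.lock().unwrap();
        let k: [u8; 64] = key.try_into().unwrap(); // length checked above
        if map.contains_key(&k) {
            StoreResult::Duplicate
        } else {
            map.insert(k, value.to_vec());
            StoreResult::Ok // accepted: worth advertising aggressively to peers
        }
    }
}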

syncwhole/src/host.rs

@ -14,32 +14,51 @@ use crate::node::RemoteNodeInfo;
/// A trait that users of syncwhole implement to provide configuration information and listen for events.
pub trait Host: Sync + Send {
/// Compute SHA512.
fn sha512(msg: &[u8]) -> [u8; 64];
/// Get a list of endpoints to which we always want to try to stay connected.
/// Get a list of peer addresses to which we always want to try to stay connected.
///
/// The node will repeatedly try to connect to these until a link is established and
/// reconnect on link failure. They should be stable well known nodes for this domain.
fn get_static_endpoints(&self) -> &[SocketAddr];
/// These are always contacted until a link is established regardless of anything else.
fn fixed_peers(&self) -> &[SocketAddr];
/// Get additional endpoints to try.
///
/// This should return any endpoints not in the supplied endpoint set if the size
/// of the set is less than the minimum active link count the host wishes to maintain.
fn get_more_endpoints(&self, current_endpoints: &HashSet<SocketAddr>) -> Vec<SocketAddr>;
/// Get a random peer address not in the supplied set.
fn another_peer(&self, exclude: &HashSet<SocketAddr>) -> Option<SocketAddr>;
/// Get the maximum number of endpoints allowed.
///
/// This is checked on incoming connect and incoming links are refused if the total is over this count.
fn max_endpoints(&self) -> usize;
/// This is checked on incoming connect and incoming links are refused if the total is
/// over this count. Fixed endpoints will be contacted even if the total is over this limit.
fn max_connection_count(&self) -> usize;
/// Called whenever we have successfully connected to a remote node (after connection is initialized).
/// Get the number of connections we ideally want.
///
/// Attempts will be made to lazily contact remote endpoints if the total number of links
/// is under this amount. Note that fixed endpoints will still be contacted even if the
/// total is over the desired count.
///
/// This should always be less than max_connection_count().
fn desired_connection_count(&self) -> usize;
/// Test whether an inbound connection should be allowed from an address.
fn allow(&self, remote_address: &SocketAddr) -> bool;
/// Called when an attempt is made to connect to a remote address.
fn on_connect_attempt(&self, address: &SocketAddr);
/// Called when a connection has been successfully established.
///
/// Hosts are encouraged to learn endpoints when a successful outbound connection is made. Check the
/// inbound flag in the remote node info structure.
fn on_connect(&self, info: &RemoteNodeInfo);
/// Called when an open connection is closed.
fn on_connection_closed(&self, endpoint: &SocketAddr, reason: Option<Box<dyn Error>>);
/// Called when an open connection is closed for any reason.
fn on_connection_closed(&self, address: &SocketAddr, reason: Option<Box<dyn Error>>);
/// Fill a buffer with secure random bytes.
///
/// This is supplied to reduce inherent dependencies and allow the user to choose the implementation.
fn get_secure_random(&self, buf: &mut [u8]);
/// Compute a SHA512 digest of the input.
///
/// This is supplied to reduce inherent dependencies and allow the user to choose the implementation.
fn sha512(msg: &[u8]) -> [u8; 64];
}
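
The Host trait now delegates both randomness and hashing to the implementor via get_secure_random() and sha512(), which is why sha2 appears above as a dev-dependency for the crate's own tests. A sketch, assuming the sha2 crate, of how sha512() might be satisfied; the free-function name here is illustrative.

use sha2::{Digest, Sha512};

// One way a Host implementation might back the required sha512() method.
// Sha512::digest() yields a 64-byte digest, copied into a plain array.
fn host_sha512(msg: &[u8]) -> [u8; 64] {
    let mut out = [0u8; 64];
    out.copy_from_slice(Sha512::digest(msg).as_slice());
    out
}

get_secure_random() can be filled in the same way from whatever CSPRNG the embedding application already uses; this commit does not pull in an RNG crate.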

syncwhole/src/iblt.rs

@ -6,21 +6,20 @@
* https://www.zerotier.com/
*/
use std::mem::{size_of, zeroed};
use std::alloc::{alloc_zeroed, dealloc, Layout};
use std::mem::size_of;
use std::ptr::write_bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::varint;
// max value: 6, 5 was determined to be good via empirical testing
const KEY_MAPPING_ITERATIONS: usize = 5;
#[inline(always)]
fn xorshift64(mut x: u64) -> u64 {
x = u64::from_le(x);
x ^= x.wrapping_shl(13);
x ^= x.wrapping_shr(7);
x ^= x.wrapping_shl(17);
x
x.to_le()
}
#[inline(always)]
@ -45,22 +44,9 @@ fn splitmix64_inverse(mut x: u64) -> u64 {
x.to_le()
}
// https://nullprogram.com/blog/2018/07/31/
#[inline(always)]
fn triple32(mut x: u32) -> u32 {
x ^= x.wrapping_shr(17);
x = x.wrapping_mul(0xed5ad4bb);
x ^= x.wrapping_shr(11);
x = x.wrapping_mul(0xac4c1b51);
x ^= x.wrapping_shr(15);
x = x.wrapping_mul(0x31848bab);
x ^= x.wrapping_shr(14);
x
}
#[inline(always)]
fn next_iteration_index(prev_iteration_index: u64, salt: u64) -> u64 {
prev_iteration_index.wrapping_add(triple32(prev_iteration_index.wrapping_shr(32) as u32) as u64).wrapping_add(salt)
fn next_iteration_index(prev_iteration_index: u64) -> u64 {
splitmix64(prev_iteration_index.wrapping_add(1))
}
#[derive(Clone, PartialEq, Eq)]
@ -84,29 +70,31 @@ impl IBLTEntry {
/// An Invertible Bloom Lookup Table for set reconciliation with 64-bit hashes.
#[derive(Clone, PartialEq, Eq)]
pub struct IBLT<const B: usize> {
salt: u64,
map: [IBLTEntry; B]
map: *mut [IBLTEntry; B]
}
impl<const B: usize> IBLT<B> {
/// Number of buckets (capacity) of this IBLT.
pub const BUCKETS: usize = B;
pub fn new(salt: u64) -> Self {
/// This was determined to be effective via empirical testing with random keys. This
/// is a protocol constant that can't be changed without upgrading all nodes in a domain.
const KEY_MAPPING_ITERATIONS: usize = 2;
pub fn new() -> Self {
assert!(B < u32::MAX as usize); // sanity check
Self {
salt,
map: unsafe { zeroed() }
map: unsafe { alloc_zeroed(Layout::new::<[IBLTEntry; B]>()).cast() }
}
}
pub fn reset(&mut self, salt: u64) {
self.salt = salt;
unsafe { write_bytes((&mut self.map as *mut IBLTEntry).cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) };
pub fn reset(&mut self) {
unsafe { write_bytes(self.map.cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) };
}
pub async fn read<R: AsyncReadExt + Unpin>(&mut self, r: &mut R) -> std::io::Result<()> {
r.read_exact(unsafe { &mut *(&mut self.salt as *mut u64).cast::<[u8; 8]>() }).await?;
let mut prev_c = 0_i64;
for b in self.map.iter_mut() {
for b in unsafe { (*self.map).iter_mut() } {
let _ = r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() }).await?;
let _ = r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() }).await?;
let mut c = varint::read_async(r).await? as i64;
@ -122,9 +110,8 @@ impl<const B: usize> IBLT<B> {
}
pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
let _ = w.write_all(unsafe { &*(&self.salt as *const u64).cast::<[u8; 8]>() }).await?;
let mut prev_c = 0_i64;
for b in self.map.iter() {
for b in unsafe { (*self.map).iter() } {
let _ = w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() }).await?;
let _ = w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() }).await?;
let mut c = (b.count - prev_c).wrapping_shl(1);
@ -137,13 +124,25 @@ impl<const B: usize> IBLT<B> {
Ok(())
}
/// Get this IBLT as a byte array.
pub fn to_bytes(&self) -> Vec<u8> {
let mut out = std::io::Cursor::new(Vec::<u8>::with_capacity(B * 20));
let current = tokio::runtime::Handle::try_current();
if current.is_ok() {
assert!(current.unwrap().block_on(self.write(&mut out)).is_ok());
} else {
assert!(tokio::runtime::Builder::new_current_thread().build().unwrap().block_on(self.write(&mut out)).is_ok());
}
out.into_inner()
}
fn ins_rem(&mut self, mut key: u64, delta: i64) {
key = splitmix64(key ^ self.salt);
key = splitmix64(key);
let check_hash = xorshift64(key);
let mut iteration_index = u64::from_le(key);
for _ in 0..KEY_MAPPING_ITERATIONS {
iteration_index = next_iteration_index(iteration_index, self.salt);
let b = unsafe { self.map.get_unchecked_mut((iteration_index as usize) % B) };
for _ in 0..Self::KEY_MAPPING_ITERATIONS {
iteration_index = next_iteration_index(iteration_index);
let b = unsafe { (*self.map).get_unchecked_mut((iteration_index as usize) % B) };
b.key_sum ^= key;
b.check_hash_sum ^= check_hash;
b.count += delta;
@ -166,53 +165,51 @@ impl<const B: usize> IBLT<B> {
self.ins_rem(unsafe { u64::from_ne_bytes(*(key.as_ptr().cast::<[u8; 8]>())) }, -1);
}
/// Subtract another IBLT from this one to get a set difference.
pub fn subtract(&mut self, other: &Self) {
for b in 0..B {
let s = &mut self.map[b];
let o = &other.map[b];
for (s, o) in unsafe { (*self.map).iter_mut().zip((*other.map).iter()) } {
s.key_sum ^= o.key_sum;
s.check_hash_sum ^= o.check_hash_sum;
s.count += o.count;
s.count -= o.count;
}
}
pub fn list<F: FnMut(&[u8; 8])>(mut self, mut f: F) -> bool {
let mut singular_buckets = [0_usize; B];
let mut singular_bucket_count = 0_usize;
let mut queue: Vec<u32> = Vec::with_capacity(B);
for b in 0..B {
if self.map[b].is_singular() {
singular_buckets[singular_bucket_count] = b;
singular_bucket_count += 1;
if unsafe { (*self.map).get_unchecked(b).is_singular() } {
queue.push(b as u32);
}
}
while singular_bucket_count > 0 {
singular_bucket_count -= 1;
let b = &self.map[singular_buckets[singular_bucket_count]];
loop {
let b = queue.pop();
let b = if b.is_some() {
unsafe { (*self.map).get_unchecked_mut(b.unwrap() as usize) }
} else {
break;
};
if b.is_singular() {
let key = b.key_sum;
let check_hash = xorshift64(key);
let mut iteration_index = u64::from_le(key);
f(&(splitmix64_inverse(key) ^ self.salt).to_ne_bytes());
f(&(splitmix64_inverse(key)).to_ne_bytes());
for _ in 0..KEY_MAPPING_ITERATIONS {
iteration_index = next_iteration_index(iteration_index, self.salt);
let b_idx = (iteration_index as usize) % B;
let b = unsafe { self.map.get_unchecked_mut(b_idx) };
for _ in 0..Self::KEY_MAPPING_ITERATIONS {
iteration_index = next_iteration_index(iteration_index);
let b_idx = iteration_index % (B as u64);
let b = unsafe { (*self.map).get_unchecked_mut(b_idx as usize) };
b.key_sum ^= key;
b.check_hash_sum ^= check_hash;
b.count -= 1;
if b.is_singular() {
if singular_bucket_count >= B {
// This would indicate an invalid IBLT.
if queue.len() >= (B * 2) { // sanity check for invalid IBLT
return false;
}
singular_buckets[singular_bucket_count] = b_idx;
singular_bucket_count += 1;
queue.push(b_idx as u32);
}
}
}
@ -222,6 +219,12 @@ impl<const B: usize> IBLT<B> {
}
}
impl<const B: usize> Drop for IBLT<B> {
fn drop(&mut self) {
unsafe { dealloc(self.map.cast(), Layout::new::<[IBLTEntry; B]>()) };
}
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
@ -230,32 +233,75 @@ mod tests {
#[test]
fn splitmix_is_invertiblex() {
for i in 1..1024_u64 {
for i in 1..2000_u64 {
assert_eq!(i, splitmix64_inverse(splitmix64(i)))
}
}
#[test]
fn insert_and_list() {
fn fill_list_performance() {
let mut rn = xorshift64(SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos() as u64);
for _ in 0..256 {
let mut alice: IBLT<1024> = IBLT::new(rn);
rn = xorshift64(rn);
let mut expected: HashSet<u64> = HashSet::with_capacity(1024);
let count = 600;
let mut expected: HashSet<u64> = HashSet::with_capacity(4096);
let mut count = 64;
const CAPACITY: usize = 4096;
while count <= CAPACITY {
let mut test: IBLT<CAPACITY> = IBLT::new();
expected.clear();
for _ in 0..count {
let x = rn;
rn = xorshift64(rn);
rn = splitmix64(rn);
expected.insert(x);
alice.insert(&x.to_ne_bytes());
test.insert(&x.to_ne_bytes());
}
let mut cnt = 0;
alice.list(|x| {
let mut list_count = 0;
test.list(|x| {
let x = u64::from_ne_bytes(*x);
cnt += 1;
list_count += 1;
assert!(expected.contains(&x));
});
assert_eq!(cnt, count);
//println!("inserted: {}\tlisted: {}\tcapacity: {}\tscore: {:.4}\tfill: {:.4}", count, list_count, CAPACITY, (list_count as f64) / (count as f64), (count as f64) / (CAPACITY as f64));
count += 64;
}
}
#[test]
fn merge_sets() {
let mut rn = xorshift64(SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos() as u64);
const CAPACITY: usize = 16384;
const REMOTE_SIZE: usize = 1024 * 1024;
const STEP: usize = 1024;
let mut missing_count = 1024;
let mut missing: HashSet<u64> = HashSet::with_capacity(CAPACITY);
while missing_count <= CAPACITY {
missing.clear();
let mut local: IBLT<CAPACITY> = IBLT::new();
let mut remote: IBLT<CAPACITY> = IBLT::new();
for k in 0..REMOTE_SIZE {
if k >= missing_count {
local.insert(&rn.to_ne_bytes());
} else {
missing.insert(rn);
}
remote.insert(&rn.to_ne_bytes());
rn = splitmix64(rn);
}
local.subtract(&mut remote);
let bytes = local.to_bytes().len();
let mut cnt = 0;
local.list(|k| {
let k = u64::from_ne_bytes(*k);
assert!(missing.contains(&k));
cnt += 1;
});
println!("total: {} missing: {:5} recovered: {:5} size: {:5} score: {:.4} bytes/item: {:.2}", REMOTE_SIZE, missing.len(), cnt, bytes, (cnt as f64) / (missing.len() as f64), (bytes as f64) / (cnt as f64));
missing_count += STEP;
}
}
}
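
Beyond the unit tests, the intended reconciliation round trip is: each side inserts the 8-byte prefixes of the keys it holds, one side ships its sketch with write(), the other reads it, subtracts it from its own sketch, and list() then peels out (most of) the set difference. A sketch of that flow using the API as it appears in this diff; the syncwhole::iblt module path and the use of an in-memory Cursor in place of a TCP stream are assumptions.

use std::io::Cursor;
use syncwhole::iblt::IBLT; // module path assumed

async fn reconcile_sketch() -> std::io::Result<()> {
    const B: usize = 4096; // bucket count is a compile-time protocol choice

    let mut local: IBLT<B> = IBLT::new();
    let mut remote: IBLT<B> = IBLT::new();

    // Both sides insert the first 8 bytes of each key (hash) they hold.
    local.insert(&1111_u64.to_ne_bytes());
    remote.insert(&1111_u64.to_ne_bytes());
    remote.insert(&2222_u64.to_ne_bytes()); // present remotely, missing locally

    // The remote serializes its sketch; a Cursor stands in for the TCP stream.
    let mut wire = Cursor::new(Vec::<u8>::new());
    remote.write(&mut wire).await?;

    let mut received: IBLT<B> = IBLT::new();
    received.read(&mut Cursor::new(wire.into_inner())).await?;

    // Subtract and peel: surviving singular buckets are keys only one side has.
    local.subtract(&received);
    local.list(|k| {
        println!("differs: {:?}", u64::from_ne_bytes(*k));
    });
    Ok(())
}

Note that the bucket count B is part of the type, so both ends of the exchange have to agree on it at compile time.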

syncwhole/src/node.rs

@ -6,7 +6,7 @@
* https://www.zerotier.com/
*/
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::{Arc, Weak};
use std::sync::atomic::{AtomicI64, Ordering};
@ -26,20 +26,46 @@ use crate::protocol::*;
use crate::varint;
const CONNECTION_TIMEOUT: i64 = 60000;
const CONNECTION_KEEPALIVE_AFTER: i64 = 20000;
const CONNECTION_KEEPALIVE_AFTER: i64 = CONNECTION_TIMEOUT / 3;
const HOUSEKEEPING_INTERVAL: i64 = CONNECTION_KEEPALIVE_AFTER / 2;
const IO_BUFFER_SIZE: usize = 65536;
#[derive(Clone, PartialEq, Eq)]
/// Information about a remote node to which we are connected.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RemoteNodeInfo {
/// Optional name advertised by remote node (arbitrary).
pub node_name: Option<String>,
/// Optional contact information advertised by remote node (arbitrary).
pub node_contact: Option<String>,
pub endpoint: SocketAddr,
pub preferred_endpoints: Vec<SocketAddr>,
/// Actual remote endpoint address.
pub remote_address: SocketAddr,
/// Explicitly advertised remote addresses supplied by remote node (not necessarily verified).
pub explicit_addresses: Vec<SocketAddr>,
/// Time TCP connection was established.
pub connect_time: SystemTime,
/// True if this is an inbound TCP connection.
pub inbound: bool,
/// True if this connection has exchanged init messages.
pub initialized: bool,
}
fn configure_tcp_socket(socket: &TcpSocket) -> std::io::Result<()> {
if socket.set_reuseport(true).is_err() {
socket.set_reuseaddr(true)?;
}
Ok(())
}
/// An instance of the syncwhole data set synchronization engine.
///
/// This holds a number of async tasks that are terminated or aborted if this object
/// is dropped. In other words this implements structured concurrency.
pub struct Node<D: DataStore + 'static, H: Host + 'static> {
internal: Arc<NodeInternal<D, H>>,
housekeeping_task: JoinHandle<()>,
@ -49,9 +75,7 @@ pub struct Node<D: DataStore + 'static, H: Host + 'static> {
impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
pub async fn new(db: Arc<D>, host: Arc<H>, bind_address: SocketAddr) -> std::io::Result<Self> {
let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
if listener.set_reuseport(true).is_err() {
listener.set_reuseaddr(true)?;
}
configure_tcp_socket(&listener)?;
listener.bind(bind_address.clone())?;
let listener = listener.listen(1024)?;
@ -65,6 +89,7 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
host: host.clone(),
bind_address,
connections: Mutex::new(HashMap::with_capacity(64)),
attempts: Mutex::new(HashMap::with_capacity(64)),
});
Ok(Self {
@ -76,7 +101,7 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
#[inline(always)]
pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
self.internal.connect(endpoint).await
self.internal.clone().connect(endpoint).await
}
pub fn list_connections(&self) -> Vec<RemoteNodeInfo> {
@ -105,17 +130,20 @@ pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
host: Arc<H>,
bind_address: SocketAddr,
connections: Mutex<HashMap<SocketAddr, (Weak<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
attempts: Mutex<HashMap<SocketAddr, JoinHandle<std::io::Result<bool>>>>,
}
impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
async fn housekeeping_task_main(self: Arc<Self>) {
loop {
tokio::time::sleep(Duration::from_millis((CONNECTION_KEEPALIVE_AFTER / 2) as u64)).await;
tokio::time::sleep(Duration::from_millis(HOUSEKEEPING_INTERVAL as u64)).await;
let mut to_ping: Vec<Arc<Connection>> = Vec::new();
let mut dead: Vec<(SocketAddr, Option<JoinHandle<std::io::Result<()>>>)> = Vec::new();
let mut current_endpoints: HashSet<SocketAddr> = HashSet::new();
let mut connections = self.connections.lock().await;
current_endpoints.reserve(connections.len() + 1);
let now = ms_monotonic();
connections.retain(|sa, c| {
let cc = c.0.upgrade();
@ -125,6 +153,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
if (now - cc.last_send_time.load(Ordering::Relaxed)) >= CONNECTION_KEEPALIVE_AFTER {
to_ping.push(cc);
}
current_endpoints.insert(sa.clone());
true
} else {
c.1.take().map(|j| j.abort());
@ -152,18 +181,46 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
for c in to_ping.iter() {
let _ = c.send(&[MESSAGE_TYPE_NOP, 0], now).await;
}
let desired = self.host.desired_connection_count();
let fixed = self.host.fixed_peers();
let mut attempts = self.attempts.lock().await;
for ep in fixed.iter() {
if !current_endpoints.contains(ep) {
let self2 = self.clone();
let ep2 = ep.clone();
attempts.insert(ep.clone(), tokio::spawn(async move { self2.connect(&ep2).await }));
current_endpoints.insert(ep.clone());
}
}
while current_endpoints.len() < desired {
let ep = self.host.another_peer(&current_endpoints);
if ep.is_some() {
let ep = ep.unwrap();
current_endpoints.insert(ep.clone());
let self2 = self.clone();
attempts.insert(ep.clone(), tokio::spawn(async move { self2.connect(&ep).await }));
} else {
break;
}
}
}
}
async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
loop {
let socket = listener.accept().await;
if self.connections.lock().await.len() < self.host.max_endpoints() && socket.is_ok() {
if self.connections.lock().await.len() < self.host.max_connection_count() && socket.is_ok() {
let (stream, endpoint) = socket.unwrap();
if self.host.allow(&endpoint) {
Self::connection_start(&self, endpoint, stream, true).await;
}
}
}
}
async fn connection_io_task_main(self: Arc<Self>, connection: Arc<Connection>, reader: OwnedReadHalf) -> std::io::Result<()> {
let mut challenge = [0_u8; 16];
@ -175,8 +232,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
max_value_size: D::MAX_VALUE_SIZE as u64,
node_name: None,
node_contact: None,
preferred_ipv4: None,
preferred_ipv6: None
explicit_ipv4: None,
explicit_ipv6: None
}, ms_monotonic()).await?;
let mut init_received = false;
@ -194,6 +251,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
let now = ms_monotonic();
match message_type {
MESSAGE_TYPE_INIT => {
if init_received {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init"));
@ -220,13 +278,14 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
let mut info = connection.info.lock().unwrap();
info.node_name = msg.node_name.clone();
info.node_contact = msg.node_contact.clone();
let _ = msg.preferred_ipv4.map(|pv4| {
info.preferred_endpoints.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(pv4.ip[0], pv4.ip[1], pv4.ip[2], pv4.ip[3]), pv4.port)));
let _ = msg.explicit_ipv4.map(|pv4| {
info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port)));
});
let _ = msg.preferred_ipv6.map(|pv6| {
info.preferred_endpoints.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
let _ = msg.explicit_ipv6.map(|pv6| {
info.explicit_addresses.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
});
},
MESSAGE_TYPE_INIT_RESPONSE => {
if initialized {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init response"));
@ -246,6 +305,7 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
let info = info.clone();
self.host.on_connect(&info);
},
_ => {
// Skip messages that aren't recognized or don't need to be parsed like NOP.
let mut remaining = message_size as usize;
@ -256,16 +316,16 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
}
connection.last_receive_time.store(ms_monotonic(), Ordering::Relaxed);
}
}
}
}
async fn connection_start(self: &Arc<Self>, endpoint: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
let _ = stream.set_nodelay(true);
async fn connection_start(self: &Arc<Self>, address: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
let (reader, writer) = stream.into_split();
let mut ok = false;
let _ = self.connections.lock().await.entry(endpoint.clone()).or_insert_with(|| {
let _ = self.connections.lock().await.entry(address.clone()).or_insert_with(|| {
ok = true;
let now = ms_monotonic();
let connection = Arc::new(Connection {
@ -275,8 +335,8 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
info: std::sync::Mutex::new(RemoteNodeInfo {
node_name: None,
node_contact: None,
endpoint: endpoint.clone(),
preferred_endpoints: Vec::new(),
remote_address: address.clone(),
explicit_addresses: Vec::new(),
connect_time: SystemTime::now(),
inbound,
initialized: false
@ -287,23 +347,26 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
ok
}
async fn connect(self: &Arc<Self>, endpoint: &SocketAddr) -> std::io::Result<bool> {
if !self.connections.lock().await.contains_key(endpoint) {
let stream = if endpoint.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
if stream.set_reuseport(true).is_err() {
stream.set_reuseaddr(true)?;
}
async fn connect(self: Arc<Self>, address: &SocketAddr) -> std::io::Result<bool> {
let mut success = false;
if !self.connections.lock().await.contains_key(address) {
self.host.on_connect_attempt(address);
let stream = if address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
configure_tcp_socket(&stream)?;
stream.bind(self.bind_address.clone())?;
let stream = stream.connect(endpoint.clone()).await?;
Ok(self.connection_start(endpoint.clone(), stream, false).await)
} else {
Ok(false)
let stream = stream.connect(address.clone()).await?;
success = self.connection_start(address.clone(), stream, false).await;
}
self.attempts.lock().await.remove(address);
Ok(success)
}
}
impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
fn drop(&mut self) {
for a in self.attempts.blocking_lock().iter() {
a.1.abort();
}
for (_, c) in self.connections.blocking_lock().drain() {
c.1.map(|c| c.abort());
}

syncwhole/src/protocol.rs

@ -45,10 +45,10 @@ pub mod msg {
pub node_name: Option<String>,
#[serde(rename = "nc")]
pub node_contact: Option<String>,
#[serde(rename = "pi4")]
pub preferred_ipv4: Option<IPv4>,
#[serde(rename = "pi6")]
pub preferred_ipv6: Option<IPv6>,
#[serde(rename = "ei4")]
pub explicit_ipv4: Option<IPv4>,
#[serde(rename = "ei6")]
pub explicit_ipv6: Option<IPv6>,
}
#[derive(Serialize, Deserialize)]
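
The field renames above (pi4/pi6 giving way to ei4/ei6) only matter on the wire if the init messages are encoded as MessagePack maps keyed by those short names, which is what the renames suggest. A self-contained sketch of that encoding using the crate's existing serde and rmp-serde dependencies; the struct below is a stand-in, not the real protocol type.

use serde::{Deserialize, Serialize};

// Stand-in for the explicit-address fields above (the real struct in
// protocol.rs has more members); shown only to illustrate the short wire keys.
#[derive(Serialize, Deserialize)]
struct ExplicitAddrs {
    #[serde(rename = "ei4")]
    explicit_ipv4: Option<[u8; 4]>,
    #[serde(rename = "ei6")]
    explicit_ipv6: Option<[u8; 16]>,
}

fn main() {
    let m = ExplicitAddrs { explicit_ipv4: Some([192, 168, 1, 10]), explicit_ipv6: None };
    // to_vec_named() emits a MessagePack map keyed by "ei4"/"ei6", so the short
    // renamed keys keep every init message a few bytes smaller on the wire.
    let bytes = rmp_serde::to_vec_named(&m).expect("serialization failed");
    println!("{} bytes on the wire", bytes.len());
}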