mirror of
https://github.com/zerotier/ZeroTierOne.git
synced 2025-04-26 17:03:43 +02:00
A bunch more syncwhole work and self test code.
This commit is contained in:
parent
44f42ef608
commit
2158675fd2
10 changed files with 639 additions and 210 deletions
63
syncwhole/Cargo.lock
generated
63
syncwhole/Cargo.lock
generated
|
@ -16,9 +16,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
|||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.2"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324"
|
||||
checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
@ -51,23 +51,12 @@ dependencies = [
|
|||
]
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.3"
|
||||
name = "digest"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8"
|
||||
checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -80,6 +69,15 @@ dependencies = [
|
|||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.119"
|
||||
|
@ -150,6 +148,22 @@ dependencies = [
|
|||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.0"
|
||||
|
@ -242,6 +256,15 @@ dependencies = [
|
|||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_bytes"
|
||||
version = "0.11.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "16ae07dd2f88a366f15bd0632ba725227018c69a1c8550a927324f8eb8368bb9"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.136"
|
||||
|
@ -255,13 +278,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.2"
|
||||
version = "0.9.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55deaec60f81eefe3cce0dc50bda92d6d8e88f2a27df7c5033b42afeb1ed2676"
|
||||
checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
"opaque-debug",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -298,6 +323,7 @@ dependencies = [
|
|||
"rmp",
|
||||
"rmp-serde",
|
||||
"serde",
|
||||
"serde_bytes",
|
||||
"sha2",
|
||||
"tokio",
|
||||
]
|
||||
|
@ -312,6 +338,7 @@ dependencies = [
|
|||
"libc",
|
||||
"memchr",
|
||||
"mio",
|
||||
"num_cpus",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"socket2",
|
||||
|
|
|
@ -5,11 +5,30 @@ edition = "2021"
|
|||
license = "MPL-2.0"
|
||||
authors = ["Adam Ierymenko <adam.ierymenko@zerotier.com>"]
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
panic = 'abort'
|
||||
|
||||
[lib]
|
||||
name = "syncwhole"
|
||||
path = "src/lib.rs"
|
||||
doc = true
|
||||
|
||||
[[bin]]
|
||||
name = "syncwhole_local_testnet"
|
||||
path = "src/main.rs"
|
||||
doc = false
|
||||
required-features = ["include_sha2_lib"]
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "^1", features = ["net", "rt", "parking_lot", "time", "io-std", "io-util", "sync"], default-features = false }
|
||||
tokio = { version = "^1", features = ["net", "rt", "parking_lot", "time", "io-std", "io-util", "sync", "rt-multi-thread"], default-features = false }
|
||||
serde = { version = "^1", features = ["derive"], default-features = false }
|
||||
serde_bytes = "^0"
|
||||
rmp = "^0"
|
||||
rmp-serde = "^1"
|
||||
sha2 = { version = "^0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
sha2 = "^0"
|
||||
[features]
|
||||
include_sha2_lib = ["sha2"]
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
use std::ops::Bound::Included;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use crate::ms_since_epoch;
|
||||
|
||||
/// Result returned by DB::load().
|
||||
pub enum LoadResult<V: AsRef<[u8]> + Send> {
|
||||
|
@ -50,7 +51,7 @@ pub enum StoreResult {
|
|||
/// are upgraded at once.
|
||||
///
|
||||
/// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a function of
|
||||
/// values. If this is true, get_key() must be implemented.
|
||||
/// values. If this is true, key_from_value() must be implemented.
|
||||
///
|
||||
/// The implementation must be thread safe and may be called concurrently.
|
||||
pub trait DataStore: Sync + Send {
|
||||
|
@ -75,15 +76,19 @@ pub trait DataStore: Sync + Send {
|
|||
|
||||
/// Compute the key corresponding to a value.
|
||||
///
|
||||
/// The 'key' slice must be KEY_SIZE bytes in length.
|
||||
///
|
||||
/// If KEY_IS_COMPUTED is true this must be implemented. The default implementation
|
||||
/// panics to indicate this. If KEY_IS_COMPUTED is false this is never called.
|
||||
#[allow(unused_variables)]
|
||||
fn get_key(&self, value: &[u8], key: &mut [u8]) {
|
||||
panic!("get_key() must be implemented if KEY_IS_COMPUTED is true");
|
||||
fn key_from_value(&self, value: &[u8], key_buffer: &mut [u8]) {
|
||||
panic!("key_from_value() must be implemented if KEY_IS_COMPUTED is true");
|
||||
}
|
||||
|
||||
/// Get the current wall clock in milliseconds since Unix epoch.
|
||||
///
|
||||
/// This is delegated to the data store to support scenarios where you want to fix
|
||||
/// the clock or snapshot at a given time.
|
||||
fn clock(&self) -> i64;
|
||||
|
||||
/// Get the domain of this data store.
|
||||
///
|
||||
/// This is an arbitrary unique identifier that must be the same for all nodes that
|
||||
|
@ -125,6 +130,9 @@ pub trait DataStore: Sync + Send {
|
|||
cnt
|
||||
}
|
||||
|
||||
/// Get the total number of records in this data store.
|
||||
fn total_count(&self) -> u64;
|
||||
|
||||
/// Iterate through keys beneath a key prefix, stopping if the function returns false.
|
||||
///
|
||||
/// The default implementation uses for_each(). This can be specialized if it's faster to
|
||||
|
@ -138,13 +146,13 @@ pub trait DataStore: Sync + Send {
|
|||
}
|
||||
|
||||
/// A simple in-memory data store backed by a BTreeMap.
|
||||
pub struct MemoryDatabase<const KEY_SIZE: usize> {
|
||||
pub struct MemoryDataStore<const KEY_SIZE: usize> {
|
||||
max_age: i64,
|
||||
domain: String,
|
||||
db: Mutex<BTreeMap<[u8; KEY_SIZE], (i64, Arc<[u8]>)>>
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> MemoryDatabase<KEY_SIZE> {
|
||||
impl<const KEY_SIZE: usize> MemoryDataStore<KEY_SIZE> {
|
||||
pub fn new(max_age: i64, domain: String) -> Self {
|
||||
Self {
|
||||
max_age: if max_age > 0 { max_age } else { i64::MAX },
|
||||
|
@ -154,12 +162,14 @@ impl<const KEY_SIZE: usize> MemoryDatabase<KEY_SIZE> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> DataStore for MemoryDatabase<KEY_SIZE> {
|
||||
impl<const KEY_SIZE: usize> DataStore for MemoryDataStore<KEY_SIZE> {
|
||||
type LoadResultValueType = Arc<[u8]>;
|
||||
const KEY_SIZE: usize = KEY_SIZE;
|
||||
const MAX_VALUE_SIZE: usize = 65536;
|
||||
const KEY_IS_COMPUTED: bool = false;
|
||||
|
||||
fn clock(&self) -> i64 { ms_since_epoch() }
|
||||
|
||||
fn domain(&self) -> &str { self.domain.as_str() }
|
||||
|
||||
fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType> {
|
||||
|
@ -196,6 +206,8 @@ impl<const KEY_SIZE: usize> DataStore for MemoryDatabase<KEY_SIZE> {
|
|||
}
|
||||
}
|
||||
|
||||
fn total_count(&self) -> u64 { self.db.lock().unwrap().len() as u64 }
|
||||
|
||||
fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, reference_time: i64, key_prefix: &[u8], mut f: F) {
|
||||
let mut r_start = [0_u8; KEY_SIZE];
|
||||
let mut r_end = [0xff_u8; KEY_SIZE];
|
||||
|
@ -211,10 +223,20 @@ impl<const KEY_SIZE: usize> DataStore for MemoryDatabase<KEY_SIZE> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> PartialEq for MemoryDatabase<KEY_SIZE> {
|
||||
impl<const KEY_SIZE: usize> PartialEq for MemoryDataStore<KEY_SIZE> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.max_age == other.max_age && self.db.lock().unwrap().eq(&*other.db.lock().unwrap())
|
||||
self.max_age == other.max_age && self.domain == other.domain && self.db.lock().unwrap().eq(&*other.db.lock().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> Eq for MemoryDatabase<KEY_SIZE> {}
|
||||
impl<const KEY_SIZE: usize> Eq for MemoryDataStore<KEY_SIZE> {}
|
||||
|
||||
impl<const KEY_SIZE: usize> Clone for MemoryDataStore<KEY_SIZE> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
max_age: self.max_age,
|
||||
domain: self.domain.clone(),
|
||||
db: Mutex::new(self.db.lock().unwrap().clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,9 +7,11 @@
|
|||
*/
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::error::Error;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
#[cfg(feature = "include_sha2_lib")]
|
||||
use sha2::digest::{Digest, FixedOutput};
|
||||
|
||||
use crate::node::RemoteNodeInfo;
|
||||
|
||||
/// A trait that users of syncwhole implement to provide configuration information and listen for events.
|
||||
|
@ -20,13 +22,21 @@ pub trait Host: Sync + Send {
|
|||
fn fixed_peers(&self) -> &[SocketAddr];
|
||||
|
||||
/// Get a random peer address not in the supplied set.
|
||||
fn another_peer(&self, exclude: &HashSet<SocketAddr>) -> Option<SocketAddr>;
|
||||
///
|
||||
/// The default implementation just returns None.
|
||||
fn another_peer(&self, exclude: &HashSet<SocketAddr>) -> Option<SocketAddr> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the maximum number of endpoints allowed.
|
||||
///
|
||||
/// This is checked on incoming connect and incoming links are refused if the total is
|
||||
/// over this count. Fixed endpoints will be contacted even if the total is over this limit.
|
||||
fn max_connection_count(&self) -> usize;
|
||||
///
|
||||
/// The default implementation returns 1024.
|
||||
fn max_connection_count(&self) -> usize {
|
||||
1024
|
||||
}
|
||||
|
||||
/// Get the number of connections we ideally want.
|
||||
///
|
||||
|
@ -35,10 +45,47 @@ pub trait Host: Sync + Send {
|
|||
/// total is over the desired count.
|
||||
///
|
||||
/// This should always be less than max_connection_count().
|
||||
fn desired_connection_count(&self) -> usize;
|
||||
///
|
||||
/// The default implementation returns 128.
|
||||
fn desired_connection_count(&self) -> usize {
|
||||
128
|
||||
}
|
||||
|
||||
/// Get an optional name that this node should advertise.
|
||||
///
|
||||
/// The default implementation returns None.
|
||||
fn name(&self) -> Option<&str> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Get an optional contact info string that this node should advertise.
|
||||
///
|
||||
/// The default implementation returns None.
|
||||
fn contact(&self) -> Option<&str> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Test whether an inbound connection should be allowed from an address.
|
||||
fn allow(&self, remote_address: &SocketAddr) -> bool;
|
||||
///
|
||||
/// This is called on first incoming connection before any init is received. The authenticate()
|
||||
/// method is called once init has been received and is another decision point. The default
|
||||
/// implementation of this always returns true.
|
||||
fn allow(&self, remote_address: &SocketAddr) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA512(secret, challenge).
|
||||
///
|
||||
/// A return of None indicates that the connection should be dropped. If authentication is
|
||||
/// not enabled, the response should be computed using an all-zero secret key. This is
|
||||
/// what the default implementation does, so if you don't want authentication there is no
|
||||
/// need to override and implement this.
|
||||
///
|
||||
/// This actually gets called twice per link: once when Init is received to compute the
|
||||
/// response, and once when InitResponse is received to verify the response to our challenge.
|
||||
fn authenticate(&self, info: &RemoteNodeInfo, challenge: &[u8]) -> Option<[u8; 64]> {
|
||||
Some(Self::hmac_sha512(&[0_u8; 64], challenge))
|
||||
}
|
||||
|
||||
/// Called when an attempt is made to connect to a remote address.
|
||||
fn on_connect_attempt(&self, address: &SocketAddr);
|
||||
|
@ -50,7 +97,7 @@ pub trait Host: Sync + Send {
|
|||
fn on_connect(&self, info: &RemoteNodeInfo);
|
||||
|
||||
/// Called when an open connection is closed for any reason.
|
||||
fn on_connection_closed(&self, address: &SocketAddr, reason: Option<Box<dyn Error>>);
|
||||
fn on_connection_closed(&self, info: &RemoteNodeInfo, reason: String);
|
||||
|
||||
/// Fill a buffer with secure random bytes.
|
||||
///
|
||||
|
@ -59,6 +106,38 @@ pub trait Host: Sync + Send {
|
|||
|
||||
/// Compute a SHA512 digest of the input.
|
||||
///
|
||||
/// This is supplied to reduce inherent dependencies and allow the user to choose the implementation.
|
||||
fn sha512(msg: &[u8]) -> [u8; 64];
|
||||
/// Input can consist of one or more slices that will be processed in order.
|
||||
///
|
||||
/// If the feature "include_sha2_lib" is enabled a default implementation in terms of the
|
||||
/// Rust sha2 crate is generated. Otherwise the user must supply their own implementation.
|
||||
#[cfg(not(feature = "include_sha2_lib"))]
|
||||
fn sha512(msg: &[&[u8]]) -> [u8; 64];
|
||||
#[cfg(feature = "include_sha2_lib")]
|
||||
fn sha512(msg: &[&[u8]]) -> [u8; 64] {
|
||||
let mut h = sha2::Sha512::new();
|
||||
for b in msg.iter() {
|
||||
h.update(*b);
|
||||
}
|
||||
h.finalize_fixed().as_ref().try_into().unwrap()
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA512 using key and input.
|
||||
///
|
||||
/// Supplied key will always be 64 bytes in length.
|
||||
///
|
||||
/// The default implementation is a basic HMAC implemented in terms of sha512() above. This
|
||||
/// can be specialized if the user wishes to provide their own implementation.
|
||||
fn hmac_sha512(key: &[u8], msg: &[u8]) -> [u8; 64] {
|
||||
let mut opad = [0x5c_u8; 128];
|
||||
let mut ipad = [0x36_u8; 128];
|
||||
assert!(key.len() >= 64);
|
||||
for i in 0..64 {
|
||||
opad[i] ^= key[i];
|
||||
}
|
||||
for i in 0..64 {
|
||||
ipad[i] ^= key[i];
|
||||
}
|
||||
let s1 = Self::sha512(&[&ipad, msg]);
|
||||
Self::sha512(&[&opad, &s1])
|
||||
}
|
||||
}
|
||||
|
|
|
@ -97,7 +97,7 @@ impl<const B: usize> IBLT<B> {
|
|||
for b in unsafe { (*self.map).iter_mut() } {
|
||||
let _ = r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() }).await?;
|
||||
let _ = r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() }).await?;
|
||||
let mut c = varint::read_async(r).await? as i64;
|
||||
let mut c = varint::read_async(r).await?.0 as i64;
|
||||
if (c & 1) == 0 {
|
||||
c = c.wrapping_shr(1);
|
||||
} else {
|
||||
|
@ -174,7 +174,7 @@ impl<const B: usize> IBLT<B> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn list<F: FnMut(&[u8; 8])>(mut self, mut f: F) -> bool {
|
||||
pub fn list<F: FnMut(&[u8; 8])>(self, mut f: F) -> bool {
|
||||
let mut queue: Vec<u32> = Vec::with_capacity(B);
|
||||
|
||||
for b in 0..B {
|
||||
|
|
|
@ -14,10 +14,10 @@ pub mod datastore;
|
|||
pub mod node;
|
||||
pub mod host;
|
||||
|
||||
pub(crate) fn ms_since_epoch() -> i64 {
|
||||
pub fn ms_since_epoch() -> i64 {
|
||||
std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as i64
|
||||
}
|
||||
|
||||
pub(crate) fn ms_monotonic() -> i64 {
|
||||
pub fn ms_monotonic() -> i64 {
|
||||
std::time::Instant::now().elapsed().as_millis() as i64
|
||||
}
|
||||
|
|
136
syncwhole/src/main.rs
Normal file
136
syncwhole/src/main.rs
Normal file
|
@ -0,0 +1,136 @@
|
|||
extern crate core;
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::ops::Bound::Included;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use sha2::digest::Digest;
|
||||
use sha2::Sha512;
|
||||
|
||||
use syncwhole::datastore::{DataStore, LoadResult, StoreResult};
|
||||
use syncwhole::host::Host;
|
||||
use syncwhole::ms_since_epoch;
|
||||
use syncwhole::node::{Node, RemoteNodeInfo};
|
||||
|
||||
const TEST_NODE_COUNT: usize = 16;
|
||||
const TEST_PORT_RANGE_START: u16 = 21384;
|
||||
|
||||
struct TestNodeHost {
|
||||
name: String,
|
||||
peers: Vec<SocketAddr>,
|
||||
db: Mutex<BTreeMap<[u8; 64], Arc<[u8]>>>,
|
||||
}
|
||||
|
||||
impl Host for TestNodeHost {
|
||||
fn fixed_peers(&self) -> &[SocketAddr] { self.peers.as_slice() }
|
||||
|
||||
fn name(&self) -> Option<&str> { Some(self.name.as_str()) }
|
||||
|
||||
fn on_connect_attempt(&self, _address: &SocketAddr) {
|
||||
//println!("{:5}: connecting to {}", self.name, _address.to_string());
|
||||
}
|
||||
|
||||
fn on_connect(&self, info: &RemoteNodeInfo) {
|
||||
println!("{:5}: connected to {} ({}, {})", self.name, info.remote_address.to_string(), info.node_name.as_ref().map_or("null", |s| s.as_str()), if info.inbound { "inbound" } else { "outbound" });
|
||||
}
|
||||
|
||||
fn on_connection_closed(&self, info: &RemoteNodeInfo, reason: String) {
|
||||
println!("{:5}: closed connection to {}: {} ({}, {})", self.name, info.remote_address.to_string(), reason, if info.inbound { "inbound" } else { "outbound" }, if info.initialized { "initialized" } else { "not initialized" });
|
||||
}
|
||||
|
||||
fn get_secure_random(&self, mut buf: &mut [u8]) {
|
||||
// This is only for testing and is not really secure.
|
||||
let mut ctr = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos();
|
||||
while !buf.is_empty() {
|
||||
let l = buf.len().min(64);
|
||||
ctr = ctr.wrapping_add(1);
|
||||
buf[0..l].copy_from_slice(&Self::sha512(&[&ctr.to_ne_bytes()])[0..l]);
|
||||
buf = &mut buf[l..];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DataStore for TestNodeHost {
|
||||
type LoadResultValueType = Arc<[u8]>;
|
||||
|
||||
const KEY_SIZE: usize = 64;
|
||||
const MAX_VALUE_SIZE: usize = 1024;
|
||||
const KEY_IS_COMPUTED: bool = true;
|
||||
|
||||
fn key_from_value(&self, value: &[u8], key_buffer: &mut [u8]) {
|
||||
key_buffer.copy_from_slice(Sha512::digest(value).as_slice());
|
||||
}
|
||||
|
||||
fn clock(&self) -> i64 { ms_since_epoch() }
|
||||
|
||||
fn domain(&self) -> &str { "test" }
|
||||
|
||||
fn load(&self, _: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType> {
|
||||
self.db.lock().unwrap().get(key).map_or(LoadResult::NotFound, |r| LoadResult::Ok(r.clone()))
|
||||
}
|
||||
|
||||
fn store(&self, key: &[u8], value: &[u8]) -> StoreResult {
|
||||
assert_eq!(key.len(), 64);
|
||||
let mut res = StoreResult::Ok;
|
||||
self.db.lock().unwrap().entry(key.try_into().unwrap()).and_modify(|e| {
|
||||
if e.as_ref().eq(value) {
|
||||
res = StoreResult::Duplicate;
|
||||
} else {
|
||||
*e = Arc::from(value)
|
||||
}
|
||||
}).or_insert_with(|| {
|
||||
Arc::from(value)
|
||||
});
|
||||
res
|
||||
}
|
||||
|
||||
fn total_count(&self) -> u64 { self.db.lock().unwrap().len() as u64 }
|
||||
|
||||
fn for_each<F: FnMut(&[u8], &[u8]) -> bool>(&self, _: i64, key_prefix: &[u8], mut f: F) {
|
||||
let mut r_start = [0_u8; Self::KEY_SIZE];
|
||||
let mut r_end = [0xff_u8; Self::KEY_SIZE];
|
||||
(&mut r_start[0..key_prefix.len()]).copy_from_slice(key_prefix);
|
||||
(&mut r_end[0..key_prefix.len()]).copy_from_slice(key_prefix);
|
||||
for (k, v) in self.db.lock().unwrap().range((Included(r_start), Included(r_end))) {
|
||||
if !f(k, v.as_ref()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(async {
|
||||
println!("Running syncwhole local self-test network with {} nodes starting at 127.0.0.1:{}", TEST_NODE_COUNT, TEST_PORT_RANGE_START);
|
||||
println!();
|
||||
|
||||
let mut nodes: Vec<Node<TestNodeHost, TestNodeHost>> = Vec::with_capacity(TEST_NODE_COUNT);
|
||||
for port in TEST_PORT_RANGE_START..(TEST_PORT_RANGE_START + (TEST_NODE_COUNT as u16)) {
|
||||
let mut peers: Vec<SocketAddr> = Vec::with_capacity(TEST_NODE_COUNT);
|
||||
for port2 in TEST_PORT_RANGE_START..(TEST_PORT_RANGE_START + (TEST_NODE_COUNT as u16)) {
|
||||
if port != port2 {
|
||||
peers.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port2)));
|
||||
}
|
||||
}
|
||||
let nh = Arc::new(TestNodeHost {
|
||||
name: format!("{}", port),
|
||||
peers,
|
||||
db: Mutex::new(BTreeMap::new())
|
||||
});
|
||||
println!("Starting node on 127.0.0.1:{} with {} records in data store...", port, nh.db.lock().unwrap().len());
|
||||
nodes.push(Node::new(nh.clone(), nh.clone(), SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))).await.unwrap());
|
||||
}
|
||||
println!();
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let mut count = 0;
|
||||
for n in nodes.iter() {
|
||||
count += n.connection_count().await;
|
||||
}
|
||||
println!("{}", count);
|
||||
}
|
||||
});
|
||||
}
|
|
@ -7,17 +7,21 @@
|
|||
*/
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::io::IoSlice;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::ops::Add;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering};
|
||||
use std::time::SystemTime;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::{Instant, Duration};
|
||||
|
||||
use crate::datastore::DataStore;
|
||||
use crate::host::Host;
|
||||
|
@ -25,10 +29,17 @@ use crate::ms_monotonic;
|
|||
use crate::protocol::*;
|
||||
use crate::varint;
|
||||
|
||||
const CONNECTION_TIMEOUT: i64 = 60000;
|
||||
const CONNECTION_KEEPALIVE_AFTER: i64 = CONNECTION_TIMEOUT / 3;
|
||||
const HOUSEKEEPING_INTERVAL: i64 = CONNECTION_KEEPALIVE_AFTER / 2;
|
||||
const IO_BUFFER_SIZE: usize = 65536;
|
||||
/// Inactivity timeout for connections in milliseconds.
|
||||
const CONNECTION_TIMEOUT: i64 = 120000;
|
||||
|
||||
/// How often to send STATUS messages in milliseconds.
|
||||
const STATUS_INTERVAL: i64 = 10000;
|
||||
|
||||
/// How often to run the housekeeping task's loop in milliseconds.
|
||||
const HOUSEKEEPING_INTERVAL: i64 = STATUS_INTERVAL / 2;
|
||||
|
||||
/// Size of read buffer, which is used to reduce the number of syscalls.
|
||||
const READ_BUFFER_SIZE: usize = 16384;
|
||||
|
||||
/// Information about a remote node to which we are connected.
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
@ -51,15 +62,17 @@ pub struct RemoteNodeInfo {
|
|||
/// True if this is an inbound TCP connection.
|
||||
pub inbound: bool,
|
||||
|
||||
/// True if this connection has exchanged init messages.
|
||||
/// True if this connection has exchanged init messages successfully.
|
||||
pub initialized: bool,
|
||||
}
|
||||
|
||||
fn configure_tcp_socket(socket: &TcpSocket) -> std::io::Result<()> {
|
||||
if socket.set_reuseport(true).is_err() {
|
||||
socket.set_reuseaddr(true)?;
|
||||
}
|
||||
let _ = socket.set_linger(None);
|
||||
if socket.set_reuseport(true).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
socket.set_reuseaddr(true)
|
||||
}
|
||||
}
|
||||
|
||||
/// An instance of the syncwhole data set synchronization engine.
|
||||
|
@ -81,15 +94,14 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
|
|||
|
||||
let internal = Arc::new(NodeInternal::<D, H> {
|
||||
anti_loopback_secret: {
|
||||
let mut tmp = [0_u8; 16];
|
||||
let mut tmp = [0_u8; 64];
|
||||
host.get_secure_random(&mut tmp);
|
||||
tmp
|
||||
},
|
||||
db: db.clone(),
|
||||
datastore: db.clone(),
|
||||
host: host.clone(),
|
||||
bind_address,
|
||||
connections: Mutex::new(HashMap::with_capacity(64)),
|
||||
attempts: Mutex::new(HashMap::with_capacity(64)),
|
||||
bind_address
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
|
@ -99,22 +111,27 @@ impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn datastore(&self) -> &Arc<D> { &self.internal.datastore }
|
||||
|
||||
pub fn host(&self) -> &Arc<H> { &self.internal.host }
|
||||
|
||||
#[inline(always)]
|
||||
pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
|
||||
self.internal.clone().connect(endpoint).await
|
||||
self.internal.clone().connect(endpoint, Instant::now().add(Duration::from_millis(CONNECTION_TIMEOUT as u64))).await
|
||||
}
|
||||
|
||||
pub fn list_connections(&self) -> Vec<RemoteNodeInfo> {
|
||||
let mut connections = self.internal.connections.blocking_lock();
|
||||
pub async fn list_connections(&self) -> Vec<RemoteNodeInfo> {
|
||||
let connections = self.internal.connections.lock().await;
|
||||
let mut cl: Vec<RemoteNodeInfo> = Vec::with_capacity(connections.len());
|
||||
connections.retain(|_, c| {
|
||||
c.0.upgrade().map_or(false, |c| {
|
||||
cl.push(c.info.lock().unwrap().clone());
|
||||
true
|
||||
})
|
||||
});
|
||||
for (_, c) in connections.iter() {
|
||||
cl.push(c.0.info.blocking_lock().clone());
|
||||
}
|
||||
cl
|
||||
}
|
||||
|
||||
pub async fn connection_count(&self) -> usize {
|
||||
self.internal.connections.lock().await.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataStore + 'static, H: Host + 'static> Drop for Node<D, H> {
|
||||
|
@ -125,84 +142,127 @@ impl<D: DataStore + 'static, H: Host + 'static> Drop for Node<D, H> {
|
|||
}
|
||||
|
||||
pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
|
||||
anti_loopback_secret: [u8; 16],
|
||||
db: Arc<D>,
|
||||
anti_loopback_secret: [u8; 64],
|
||||
datastore: Arc<D>,
|
||||
host: Arc<H>,
|
||||
connections: Mutex<HashMap<SocketAddr, (Arc<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
|
||||
bind_address: SocketAddr,
|
||||
connections: Mutex<HashMap<SocketAddr, (Weak<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
|
||||
attempts: Mutex<HashMap<SocketAddr, JoinHandle<std::io::Result<bool>>>>,
|
||||
}
|
||||
|
||||
impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
||||
async fn housekeeping_task_main(self: Arc<Self>) {
|
||||
let mut last_status_sent = ms_monotonic();
|
||||
let mut tasks: Vec<JoinHandle<()>> = Vec::new();
|
||||
let mut connected_to_addresses: HashSet<SocketAddr> = HashSet::new();
|
||||
let mut sleep_until = Instant::now().add(Duration::from_millis(500));
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_millis(HOUSEKEEPING_INTERVAL as u64)).await;
|
||||
tokio::time::sleep_until(sleep_until).await;
|
||||
sleep_until = sleep_until.add(Duration::from_millis(HOUSEKEEPING_INTERVAL as u64));
|
||||
|
||||
let mut to_ping: Vec<Arc<Connection>> = Vec::new();
|
||||
let mut dead: Vec<(SocketAddr, Option<JoinHandle<std::io::Result<()>>>)> = Vec::new();
|
||||
let mut current_endpoints: HashSet<SocketAddr> = HashSet::new();
|
||||
|
||||
let mut connections = self.connections.lock().await;
|
||||
current_endpoints.reserve(connections.len() + 1);
|
||||
tasks.clear();
|
||||
connected_to_addresses.clear();
|
||||
let now = ms_monotonic();
|
||||
connections.retain(|sa, c| {
|
||||
let cc = c.0.upgrade();
|
||||
if cc.is_some() {
|
||||
let cc = cc.unwrap();
|
||||
|
||||
// Check connection timeouts, send status updates, and garbage collect from the connections map.
|
||||
// Status message outputs are backgrounded since these can block of TCP buffers are nearly full.
|
||||
// A timeout based on the service loop interval is used. Usually these sends will finish instantly
|
||||
// but if they take too long this typically means the link is dead. We wait for all tasks at the
|
||||
// end of the service loop. The on_connection_closed() method in 'host' is called in sub-tasks to
|
||||
// prevent the possibility of deadlocks on self.connections.lock() if the Host implementation calls
|
||||
// something that tries to lock it.
|
||||
let status = if (now - last_status_sent) >= STATUS_INTERVAL {
|
||||
last_status_sent = now;
|
||||
Some(msg::Status {
|
||||
record_count: self.datastore.total_count(),
|
||||
clock: self.datastore.clock() as u64
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.connections.lock().await.retain(|sa, c| {
|
||||
let cc = &(c.0);
|
||||
if !cc.closed.load(Ordering::Relaxed) {
|
||||
if (now - cc.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT {
|
||||
if (now - cc.last_send_time.load(Ordering::Relaxed)) >= CONNECTION_KEEPALIVE_AFTER {
|
||||
to_ping.push(cc);
|
||||
connected_to_addresses.insert(sa.clone());
|
||||
let _ = status.as_ref().map(|status| {
|
||||
let status = status.clone();
|
||||
let self2 = self.clone();
|
||||
let cc = cc.clone();
|
||||
let sa = sa.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
if cc.info.lock().await.initialized {
|
||||
if !tokio::time::timeout_at(sleep_until, cc.send_obj(MESSAGE_TYPE_STATUS, &status, now)).await.map_or(false, |r| r.is_ok()) {
|
||||
let _ = self2.connections.lock().await.remove(&sa).map(|c| c.1.map(|j| j.abort()));
|
||||
self2.host.on_connection_closed(&*cc.info.lock().await, "write overflow (timeout)".to_string());
|
||||
}
|
||||
current_endpoints.insert(sa.clone());
|
||||
}
|
||||
}));
|
||||
});
|
||||
true
|
||||
} else {
|
||||
c.1.take().map(|j| j.abort());
|
||||
let _ = c.1.take().map(|j| j.abort());
|
||||
let host = self.host.clone();
|
||||
let cc = cc.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
host.on_connection_closed(&*cc.info.lock().await, "timeout".to_string());
|
||||
}));
|
||||
false
|
||||
}
|
||||
} else {
|
||||
let _ = c.1.take().map(|j| dead.push((sa.clone(), Some(j))));
|
||||
let host = self.host.clone();
|
||||
let cc = cc.clone();
|
||||
let j = c.1.take();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
if j.is_some() {
|
||||
let e = j.unwrap().await;
|
||||
if e.is_ok() {
|
||||
let e = e.unwrap();
|
||||
host.on_connection_closed(&*cc.info.lock().await, e.map_or_else(|e| e.to_string(), |_| "unknown error".to_string()));
|
||||
} else {
|
||||
host.on_connection_closed(&*cc.info.lock().await, "remote host closed connection".to_string());
|
||||
}
|
||||
} else {
|
||||
host.on_connection_closed(&*cc.info.lock().await, "remote host closed connection".to_string());
|
||||
}
|
||||
}));
|
||||
false
|
||||
}
|
||||
});
|
||||
drop(connections); // release lock
|
||||
|
||||
for d in dead.iter_mut() {
|
||||
d.1.take().unwrap().await.map_or_else(|e| {
|
||||
self.host.on_connection_closed(&d.0, Some(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "timed out"))));
|
||||
}, |r| {
|
||||
if r.is_ok() {
|
||||
self.host.on_connection_closed(&d.0, None);
|
||||
// Always try to connect to fixed peers.
|
||||
let fixed_peers = self.host.fixed_peers();
|
||||
for sa in fixed_peers.iter() {
|
||||
if !connected_to_addresses.contains(sa) {
|
||||
let sa = sa.clone();
|
||||
let self2 = self.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let _ = self2.connect(&sa, sleep_until).await;
|
||||
}));
|
||||
connected_to_addresses.insert(sa.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Try to connect to more peers until desired connection count is reached.
|
||||
let desired_connection_count = self.host.desired_connection_count();
|
||||
while connected_to_addresses.len() < desired_connection_count {
|
||||
let sa = self.host.another_peer(&connected_to_addresses);
|
||||
if sa.is_some() {
|
||||
let sa = sa.unwrap();
|
||||
let self2 = self.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let _ = self2.connect(&sa, sleep_until).await;
|
||||
}));
|
||||
connected_to_addresses.insert(sa.clone());
|
||||
} else {
|
||||
self.host.on_connection_closed(&d.0, Some(Box::new(r.unwrap_err())));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for c in to_ping.iter() {
|
||||
let _ = c.send(&[MESSAGE_TYPE_NOP, 0], now).await;
|
||||
}
|
||||
|
||||
let desired = self.host.desired_connection_count();
|
||||
let fixed = self.host.fixed_peers();
|
||||
|
||||
let mut attempts = self.attempts.lock().await;
|
||||
|
||||
for ep in fixed.iter() {
|
||||
if !current_endpoints.contains(ep) {
|
||||
let self2 = self.clone();
|
||||
let ep2 = ep.clone();
|
||||
attempts.insert(ep.clone(), tokio::spawn(async move { self2.connect(&ep2).await }));
|
||||
current_endpoints.insert(ep.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
while current_endpoints.len() < desired {
|
||||
let ep = self.host.another_peer(¤t_endpoints);
|
||||
if ep.is_some() {
|
||||
let ep = ep.unwrap();
|
||||
current_endpoints.insert(ep.clone());
|
||||
let self2 = self.clone();
|
||||
attempts.insert(ep.clone(), tokio::spawn(async move { self2.connect(&ep).await }));
|
||||
// Wait for this iteration's batched background tasks to complete.
|
||||
loop {
|
||||
let s = tasks.pop();
|
||||
if s.is_some() {
|
||||
let _ = s.unwrap().await;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
@ -213,42 +273,51 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
|
||||
loop {
|
||||
let socket = listener.accept().await;
|
||||
if self.connections.lock().await.len() < self.host.max_connection_count() && socket.is_ok() {
|
||||
if socket.is_ok() {
|
||||
let (stream, endpoint) = socket.unwrap();
|
||||
if self.host.allow(&endpoint) {
|
||||
let num_conn = self.connections.lock().await.len();
|
||||
if num_conn < self.host.max_connection_count() || self.host.fixed_peers().contains(&endpoint) {
|
||||
Self::connection_start(&self, endpoint, stream, true).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn connection_io_task_main(self: Arc<Self>, connection: Arc<Connection>, reader: OwnedReadHalf) -> std::io::Result<()> {
|
||||
let mut challenge = [0_u8; 16];
|
||||
self.host.get_secure_random(&mut challenge);
|
||||
#[inline(always)]
|
||||
async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: BufReader<OwnedReadHalf>) -> std::io::Result<()> {
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.resize(4096, 0);
|
||||
|
||||
let mut anti_loopback_challenge_sent = [0_u8; 64];
|
||||
let mut challenge_sent = [0_u8; 64];
|
||||
self.host.get_secure_random(&mut anti_loopback_challenge_sent);
|
||||
self.host.get_secure_random(&mut challenge_sent);
|
||||
connection.send_obj(MESSAGE_TYPE_INIT, &msg::Init {
|
||||
anti_loopback_challenge: &challenge,
|
||||
domain: String::new(), // TODO
|
||||
anti_loopback_challenge: &anti_loopback_challenge_sent,
|
||||
challenge: &challenge_sent,
|
||||
domain: self.datastore.domain().to_string(),
|
||||
key_size: D::KEY_SIZE as u16,
|
||||
max_value_size: D::MAX_VALUE_SIZE as u64,
|
||||
node_name: None,
|
||||
node_contact: None,
|
||||
node_name: self.host.name().map(|n| n.to_string()),
|
||||
node_contact: self.host.contact().map(|c| c.to_string()),
|
||||
locally_bound_port: self.bind_address.port(),
|
||||
explicit_ipv4: None,
|
||||
explicit_ipv6: None
|
||||
}, ms_monotonic()).await?;
|
||||
|
||||
let mut init_received = false;
|
||||
let mut initialized = false;
|
||||
let mut reader = BufReader::with_capacity(IO_BUFFER_SIZE, reader);
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.resize(4096, 0);
|
||||
loop {
|
||||
reader.read_exact(&mut buf.as_mut_slice()[0..1]).await?;
|
||||
let message_type = unsafe { *buf.get_unchecked(0) };
|
||||
let message_size = varint::read_async(&mut reader).await?;
|
||||
let header_size = 1 + message_size.1;
|
||||
let message_size = message_size.0;
|
||||
if message_size > (D::MAX_VALUE_SIZE + ((D::KEY_SIZE + 10) * 256) + 65536) as u64 {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
|
||||
}
|
||||
|
||||
let now = ms_monotonic();
|
||||
connection.last_receive_time.store(now, Ordering::Relaxed);
|
||||
|
||||
match message_type {
|
||||
|
||||
|
@ -257,25 +326,17 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init"));
|
||||
}
|
||||
|
||||
let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize, now).await?;
|
||||
let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
|
||||
|
||||
if !msg.domain.as_str().eq(self.db.domain()) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set domain mismatch"));
|
||||
if !msg.domain.as_str().eq(self.datastore.domain()) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("data set domain mismatch: '{}' != '{}'", msg.domain, self.datastore.domain())));
|
||||
}
|
||||
if msg.key_size != D::KEY_SIZE as u16 || msg.max_value_size > D::MAX_VALUE_SIZE as u64 {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set key/value sizing mismatch"));
|
||||
}
|
||||
|
||||
let mut antiloop = msg.anti_loopback_challenge.to_vec();
|
||||
let _ = std::io::Write::write_all(&mut antiloop, &self.anti_loopback_secret);
|
||||
let antiloop = H::sha512(antiloop.as_slice());
|
||||
connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse {
|
||||
anti_loopback_response: &antiloop[0..16]
|
||||
}, now).await?;
|
||||
|
||||
init_received = true;
|
||||
|
||||
let mut info = connection.info.lock().unwrap();
|
||||
let (anti_loopback_response, challenge_response) = {
|
||||
let mut info = connection.info.lock().await;
|
||||
info.node_name = msg.node_name.clone();
|
||||
info.node_contact = msg.node_contact.clone();
|
||||
let _ = msg.explicit_ipv4.map(|pv4| {
|
||||
|
@ -284,55 +345,72 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
let _ = msg.explicit_ipv6.map(|pv6| {
|
||||
info.explicit_addresses.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
|
||||
});
|
||||
|
||||
let challenge_response = self.host.authenticate(&info, msg.challenge);
|
||||
if challenge_response.is_none() {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "authenticate() returned None, connection dropped"));
|
||||
}
|
||||
(H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge), challenge_response.unwrap())
|
||||
};
|
||||
|
||||
connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse {
|
||||
anti_loopback_response: &anti_loopback_response,
|
||||
challenge_response: &challenge_response
|
||||
}, now).await?;
|
||||
|
||||
init_received = true;
|
||||
},
|
||||
|
||||
MESSAGE_TYPE_INIT_RESPONSE => {
|
||||
if initialized {
|
||||
let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize).await?;
|
||||
|
||||
let mut info = connection.info.lock().await;
|
||||
if info.initialized {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init response"));
|
||||
}
|
||||
|
||||
let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize, now).await?;
|
||||
let mut antiloop = challenge.to_vec();
|
||||
let _ = std::io::Write::write_all(&mut antiloop, &self.anti_loopback_secret);
|
||||
let antiloop = H::sha512(antiloop.as_slice());
|
||||
if msg.anti_loopback_response.eq(&antiloop[0..16]) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "rejected connection to self"));
|
||||
}
|
||||
|
||||
initialized = true;
|
||||
let mut info = connection.info.lock().unwrap();
|
||||
info.initialized = true;
|
||||
let info = info.clone();
|
||||
|
||||
if msg.anti_loopback_response.eq(&H::hmac_sha512(&self.anti_loopback_secret, &anti_loopback_challenge_sent)) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "rejected connection to self"));
|
||||
}
|
||||
if !self.host.authenticate(&info, &challenge_sent).map_or(false, |cr| msg.challenge_response.eq(&cr)) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::Other, "challenge/response authentication failed"));
|
||||
}
|
||||
|
||||
self.host.on_connect(&info);
|
||||
},
|
||||
|
||||
_ => {
|
||||
// Skip messages that aren't recognized or don't need to be parsed like NOP.
|
||||
// Skip messages that aren't recognized or don't need to be parsed.
|
||||
let mut remaining = message_size as usize;
|
||||
while remaining > 0 {
|
||||
let s = remaining.min(buf.len());
|
||||
reader.read_exact(&mut buf.as_mut_slice()[0..s]).await?;
|
||||
remaining -= s;
|
||||
}
|
||||
connection.last_receive_time.store(ms_monotonic(), Ordering::Relaxed);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
connection.bytes_received.fetch_add((header_size as u64) + message_size, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
async fn connection_start(self: &Arc<Self>, address: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
|
||||
let (reader, writer) = stream.into_split();
|
||||
|
||||
let mut ok = false;
|
||||
let _ = self.connections.lock().await.entry(address.clone()).or_insert_with(|| {
|
||||
ok = true;
|
||||
let _ = stream.set_nodelay(true);
|
||||
let (reader, writer) = stream.into_split();
|
||||
let now = ms_monotonic();
|
||||
let connection = Arc::new(Connection {
|
||||
writer: Mutex::new(BufWriter::with_capacity(IO_BUFFER_SIZE, writer)),
|
||||
writer: Mutex::new(writer),
|
||||
last_send_time: AtomicI64::new(now),
|
||||
last_receive_time: AtomicI64::new(now),
|
||||
info: std::sync::Mutex::new(RemoteNodeInfo {
|
||||
bytes_sent: AtomicU64::new(0),
|
||||
bytes_received: AtomicU64::new(0),
|
||||
info: Mutex::new(RemoteNodeInfo {
|
||||
node_name: None,
|
||||
node_contact: None,
|
||||
remote_address: address.clone(),
|
||||
|
@ -341,32 +419,34 @@ impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
inbound,
|
||||
initialized: false
|
||||
}),
|
||||
closed: AtomicBool::new(false)
|
||||
});
|
||||
(Arc::downgrade(&connection), Some(tokio::spawn(self.clone().connection_io_task_main(connection.clone(), reader))))
|
||||
let self2 = self.clone();
|
||||
(connection.clone(), Some(tokio::spawn(async move {
|
||||
let result = self2.connection_io_task_main(&connection, BufReader::with_capacity(READ_BUFFER_SIZE, reader)).await;
|
||||
connection.closed.store(true, Ordering::Relaxed);
|
||||
result
|
||||
})))
|
||||
});
|
||||
ok
|
||||
}
|
||||
|
||||
async fn connect(self: Arc<Self>, address: &SocketAddr) -> std::io::Result<bool> {
|
||||
let mut success = false;
|
||||
if !self.connections.lock().await.contains_key(address) {
|
||||
async fn connect(self: Arc<Self>, address: &SocketAddr, deadline: Instant) -> std::io::Result<bool> {
|
||||
self.host.on_connect_attempt(address);
|
||||
let stream = if address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
|
||||
configure_tcp_socket(&stream)?;
|
||||
stream.bind(self.bind_address.clone())?;
|
||||
let stream = stream.connect(address.clone()).await?;
|
||||
success = self.connection_start(address.clone(), stream, false).await;
|
||||
let stream = tokio::time::timeout_at(deadline, stream.connect(address.clone())).await;
|
||||
if stream.is_ok() {
|
||||
Ok(self.connection_start(address.clone(), stream.unwrap()?, false).await)
|
||||
} else {
|
||||
Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "connect timed out"))
|
||||
}
|
||||
self.attempts.lock().await.remove(address);
|
||||
Ok(success)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
|
||||
fn drop(&mut self) {
|
||||
for a in self.attempts.blocking_lock().iter() {
|
||||
a.1.abort();
|
||||
}
|
||||
for (_, c) in self.connections.blocking_lock().drain() {
|
||||
c.1.map(|c| c.abort());
|
||||
}
|
||||
|
@ -374,18 +454,20 @@ impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
|
|||
}
|
||||
|
||||
struct Connection {
|
||||
writer: Mutex<BufWriter<OwnedWriteHalf>>,
|
||||
writer: Mutex<OwnedWriteHalf>,
|
||||
last_send_time: AtomicI64,
|
||||
last_receive_time: AtomicI64,
|
||||
info: std::sync::Mutex<RemoteNodeInfo>,
|
||||
bytes_sent: AtomicU64,
|
||||
bytes_received: AtomicU64,
|
||||
info: Mutex<RemoteNodeInfo>,
|
||||
closed: AtomicBool,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
async fn send(&self, data: &[u8], now: i64) -> std::io::Result<()> {
|
||||
let mut writer = self.writer.lock().await;
|
||||
writer.write_all(data).await?;
|
||||
writer.flush().await?;
|
||||
self.writer.lock().await.write_all(data).await?;
|
||||
self.last_send_time.store(now, Ordering::Relaxed);
|
||||
self.bytes_sent.fetch_add(data.len() as u64, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -393,31 +475,31 @@ impl Connection {
|
|||
let data = rmp_serde::encode::to_vec_named(&obj);
|
||||
if data.is_ok() {
|
||||
let data = data.unwrap();
|
||||
let mut tmp = [0_u8; 16];
|
||||
tmp[0] = message_type;
|
||||
let len = 1 + varint::encode(&mut tmp[1..], data.len() as u64);
|
||||
let mut writer = self.writer.lock().await;
|
||||
writer.write_all(&tmp[0..len]).await?;
|
||||
writer.write_all(data.as_slice()).await?;
|
||||
writer.flush().await?;
|
||||
let mut header: [u8; 16] = unsafe { MaybeUninit::uninit().assume_init() };
|
||||
header[0] = message_type;
|
||||
let header_size = 1 + varint::encode(&mut header[1..], data.len() as u64);
|
||||
if self.writer.lock().await.write_vectored(&[IoSlice::new(&header[0..header_size]), IoSlice::new(data.as_slice())]).await? == (data.len() + header_size) {
|
||||
self.last_send_time.store(now, Ordering::Relaxed);
|
||||
self.bytes_sent.fetch_add((header_size + data.len()) as u64, Ordering::Relaxed);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure"))
|
||||
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "write error"))
|
||||
}
|
||||
} else {
|
||||
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure (internal error)"))
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_msg<'a>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<&'a [u8]> {
|
||||
async fn read_msg<'a, R: AsyncReadExt + Unpin>(&self, reader: &mut R, buf: &'a mut Vec<u8>, message_size: usize) -> std::io::Result<&'a [u8]> {
|
||||
if message_size > buf.len() {
|
||||
buf.resize(((message_size / 4096) + 1) * 4096, 0);
|
||||
}
|
||||
let b = &mut buf.as_mut_slice()[0..message_size];
|
||||
reader.read_exact(b).await?;
|
||||
self.last_receive_time.store(now, Ordering::Relaxed);
|
||||
Ok(b)
|
||||
}
|
||||
|
||||
async fn read_obj<'a, O: Deserialize<'a>>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<O> {
|
||||
rmp_serde::from_slice(self.read_msg(reader, buf, message_size, now).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
|
||||
async fn read_obj<'a, R: AsyncReadExt + Unpin, O: Deserialize<'a>>(&self, reader: &mut R, buf: &'a mut Vec<u8>, message_size: usize) -> std::io::Result<O> {
|
||||
rmp_serde::from_read_ref(self.read_msg(reader, buf, message_size).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("invalid msgpack: {}", e.to_string())))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,11 +6,26 @@
|
|||
* https://www.zerotier.com/
|
||||
*/
|
||||
|
||||
/// No operation, payload ignored.
|
||||
pub const MESSAGE_TYPE_NOP: u8 = 0;
|
||||
|
||||
/// Sent by both sides of a TCP link when it's established.
|
||||
pub const MESSAGE_TYPE_INIT: u8 = 1;
|
||||
|
||||
/// Reply sent to INIT.
|
||||
pub const MESSAGE_TYPE_INIT_RESPONSE: u8 = 2;
|
||||
pub const MESSAGE_TYPE_HAVE_RECORDS: u8 = 3;
|
||||
pub const MESSAGE_TYPE_GET_RECORDS: u8 = 4;
|
||||
|
||||
/// Sent every few seconds to notify peers of number of records, clock, etc.
|
||||
pub const MESSAGE_TYPE_STATUS: u8 = 3;
|
||||
|
||||
/// Payload is a list of keys of records. Usually sent to advertise recently received new records.
|
||||
pub const MESSAGE_TYPE_HAVE_RECORDS: u8 = 4;
|
||||
|
||||
/// Payload is a list of keys of records the sending node wants.
|
||||
pub const MESSAGE_TYPE_GET_RECORDS: u8 = 5;
|
||||
|
||||
/// Payload is a record, with key being omitted if the data store's KEY_IS_COMPUTED constant is true.
|
||||
pub const MESSAGE_TYPE_RECORD: u8 = 6;
|
||||
|
||||
pub mod msg {
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
@ -33,27 +48,70 @@ pub mod msg {
|
|||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Init<'a> {
|
||||
/// A random challenge to be hashed with a secret to detect and drop connections to self.
|
||||
#[serde(rename = "alc")]
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub anti_loopback_challenge: &'a [u8],
|
||||
|
||||
/// A random challenge for login/authentication.
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub challenge: &'a [u8],
|
||||
|
||||
/// An arbitrary name for this data set to avoid connecting to peers not replicating it.
|
||||
#[serde(rename = "d")]
|
||||
pub domain: String,
|
||||
|
||||
/// Size of keys in this data set in bytes.
|
||||
#[serde(rename = "ks")]
|
||||
pub key_size: u16,
|
||||
|
||||
/// Maximum allowed size of values in this data set in bytes.
|
||||
#[serde(rename = "mvs")]
|
||||
pub max_value_size: u64,
|
||||
|
||||
/// Optional name to advertise for this node.
|
||||
#[serde(rename = "nn")]
|
||||
pub node_name: Option<String>,
|
||||
|
||||
/// Optional contact information for this node, such as a URL or an e-mail address.
|
||||
#[serde(rename = "nc")]
|
||||
pub node_contact: Option<String>,
|
||||
|
||||
/// Port to which this node has locally bound.
|
||||
/// This is used to try to auto-detect whether a NAT is in the way.
|
||||
pub locally_bound_port: u16,
|
||||
|
||||
/// An IPv4 address where this node can be reached.
|
||||
/// If both explicit_ipv4 and explicit_ipv6 are omitted the physical source IP:port may be used.
|
||||
#[serde(rename = "ei4")]
|
||||
pub explicit_ipv4: Option<IPv4>,
|
||||
|
||||
/// An IPv6 address where this node can be reached.
|
||||
/// If both explicit_ipv4 and explicit_ipv6 are omitted the physical source IP:port may be used.
|
||||
#[serde(rename = "ei6")]
|
||||
pub explicit_ipv6: Option<IPv6>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct InitResponse<'a> {
|
||||
/// HMAC-SHA512(local secret, anti_loopback_challenge) to detect and drop loops.
|
||||
#[serde(rename = "alr")]
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub anti_loopback_response: &'a [u8],
|
||||
|
||||
/// HMAC-SHA512(secret, challenge) for authentication. (If auth is not enabled, an all-zero secret is used.)
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub challenge_response: &'a [u8],
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct Status {
|
||||
/// Total number of records in data set.
|
||||
#[serde(rename = "rc")]
|
||||
pub record_count: u64,
|
||||
|
||||
/// Local wall clock time in milliseconds since Unix epoch.
|
||||
#[serde(rename = "c")]
|
||||
pub clock: u64,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,19 +34,25 @@ pub async fn write_async<W: AsyncWrite + Unpin>(w: &mut W, v: u64) -> std::io::R
|
|||
w.write_all(&b[0..i]).await
|
||||
}
|
||||
|
||||
pub async fn read_async<R: AsyncRead + Unpin>(r: &mut R) -> std::io::Result<u64> {
|
||||
pub async fn read_async<R: AsyncRead + Unpin>(r: &mut R) -> std::io::Result<(u64, usize)> {
|
||||
let mut v = 0_u64;
|
||||
let mut buf = [0_u8; 1];
|
||||
let mut pos = 0;
|
||||
let mut count = 0;
|
||||
loop {
|
||||
let _ = r.read_exact(&mut buf).await?;
|
||||
loop {
|
||||
if r.read(&mut buf).await? == 1 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
count += 1;
|
||||
let b = buf[0];
|
||||
if b <= 0x7f {
|
||||
v |= (b as u64).wrapping_shl(pos);
|
||||
pos += 7;
|
||||
} else {
|
||||
v |= ((b & 0x7f) as u64).wrapping_shl(pos);
|
||||
return Ok(v);
|
||||
return Ok((v, count));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue