diff --git a/syncwhole/src/database.rs b/syncwhole/src/datastore.rs
similarity index 92%
rename from syncwhole/src/database.rs
rename to syncwhole/src/datastore.rs
index bf8c4a1c2..f282d3021 100644
--- a/syncwhole/src/database.rs
+++ b/syncwhole/src/datastore.rs
@@ -37,7 +37,7 @@ pub enum StoreResult {
     Rejected
 }
 
-/// API to be implemented by the data store we want to replicate.
+/// API to be implemented by the data set we want to replicate.
 ///
 /// The API specified here supports temporally subjective data sets. These are data sets
 /// where the existence or non-existence of a record may depend on the (real world) time.
@@ -45,11 +45,11 @@ pub enum StoreResult {
 /// what time I think it is" value to be considered locally so that data can be replicated
 /// as of any given time.
 ///
-/// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a pure function of
+/// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a function of
 /// values. If this is true, get_key() must be implemented.
 ///
 /// The implementation must be thread safe.
-pub trait Database: Sync + Send {
+pub trait DataStore: Sync + Send {
     /// Type to be enclosed in the Ok() enum value in LoadResult.
     type LoadResultValueType: AsRef<[u8]> + Send;
 
@@ -70,10 +70,13 @@ pub trait Database: Sync + Send {
     /// If KEY_IS_COMPUTED is true this must be implemented. The default implementation
     /// panics to indicate this. If KEY_IS_COMPUTED is false this is never called.
     #[allow(unused_variables)]
-    fn get_key(value: &[u8], key: &mut [u8]) {
+    fn get_key(&self, value: &[u8], key: &mut [u8]) {
         panic!("get_key() must be implemented if KEY_IS_COMPUTED is true");
     }
 
+    /// Get the domain of this data store, which is just an arbitrary unique identifier.
+    fn domain(&self) -> &str;
+
     /// Get an item if it exists as of a given reference time.
     ///
     /// The supplied key must be of length KEY_SIZE or this may panic.
@@ -115,24 +118,28 @@ pub trait Database: Sync + Send {
 /// A simple in-memory data store backed by a BTreeMap.
 pub struct MemoryDatabase<const KEY_SIZE: usize> {
     max_age: i64,
+    domain: String,
     db: Mutex<BTreeMap<[u8; KEY_SIZE], (i64, Arc<[u8]>)>>
 }
 
 impl<const KEY_SIZE: usize> MemoryDatabase<KEY_SIZE> {
-    pub fn new(max_age: i64) -> Self {
+    pub fn new(max_age: i64, domain: String) -> Self {
         Self {
             max_age: if max_age > 0 { max_age } else { i64::MAX },
+            domain,
             db: Mutex::new(BTreeMap::new())
         }
     }
 }
 
-impl<const KEY_SIZE: usize> Database for MemoryDatabase<KEY_SIZE> {
+impl<const KEY_SIZE: usize> DataStore for MemoryDatabase<KEY_SIZE> {
     type LoadResultValueType = Arc<[u8]>;
     const KEY_SIZE: usize = KEY_SIZE;
     const MAX_VALUE_SIZE: usize = 65536;
     const KEY_IS_COMPUTED: bool = false;
 
+    fn domain(&self) -> &str { self.domain.as_str() }
+
     fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType> {
         let db = self.db.lock().unwrap();
         let e = db.get(key);
diff --git a/syncwhole/src/host.rs b/syncwhole/src/host.rs
index b205fcafa..5085096b5 100644
--- a/syncwhole/src/host.rs
+++ b/syncwhole/src/host.rs
@@ -10,6 +10,8 @@ use std::collections::HashSet;
 use std::error::Error;
 use std::net::SocketAddr;
 
+use crate::node::RemoteNodeInfo;
+
 /// A trait that users of syncwhole implement to provide configuration information and listen for events.
 pub trait Host: Sync + Send {
     /// Compute SHA512.
@@ -27,11 +29,16 @@ pub trait Host: Sync + Send {
     /// of the set is less than the minimum active link count the host wishes to maintain.
     fn get_more_endpoints(&self, current_endpoints: &HashSet<SocketAddr>) -> Vec<SocketAddr>;
 
-    /// Called whenever we have successfully connected to a remote endpoint to (possibly) remember it.
-    fn on_connect_success(&self, endpoint: &SocketAddr);
+    /// Get the maximum number of endpoints allowed.
+    ///
+    /// This is checked on incoming connect and incoming links are refused if the total is over this count.
+    fn max_endpoints(&self) -> usize;
 
-    /// Called whenever an outgoing connection fails.
-    fn on_connect_failure(&self, endpoint: &SocketAddr, reason: Box<dyn Error>);
+    /// Called whenever we have successfully connected to a remote node (after connection is initialized).
+    fn on_connect(&self, info: &RemoteNodeInfo);
+
+    /// Called when an open connection is closed.
+    fn on_connection_closed(&self, endpoint: &SocketAddr, reason: Option<Box<dyn Error>>);
 
     /// Fill a buffer with secure random bytes.
     fn get_secure_random(&self, buf: &mut [u8]);
diff --git a/syncwhole/src/iblt.rs b/syncwhole/src/iblt.rs
index f12cd7454..7db4bb24c 100644
--- a/syncwhole/src/iblt.rs
+++ b/syncwhole/src/iblt.rs
@@ -6,9 +6,9 @@
  * https://www.zerotier.com/
  */
 
-use std::io::{Read, Write};
 use std::mem::{size_of, zeroed};
 use std::ptr::write_bytes;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
 use crate::varint;
 
@@ -103,13 +103,13 @@ impl<const B: usize> IBLT<B> {
         unsafe { write_bytes((&mut self.map as *mut IBLTEntry).cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) };
     }
 
-    pub fn read<R: Read>(&mut self, r: &mut R) -> std::io::Result<()> {
-        r.read_exact(unsafe { &mut *(&mut self.salt as *mut u64).cast::<[u8; 8]>() })?;
+    pub async fn read<R: AsyncReadExt + Unpin>(&mut self, r: &mut R) -> std::io::Result<()> {
+        r.read_exact(unsafe { &mut *(&mut self.salt as *mut u64).cast::<[u8; 8]>() }).await?;
         let mut prev_c = 0_i64;
         for b in self.map.iter_mut() {
-            r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() })?;
-            r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() })?;
-            let mut c = varint::read(r)? as i64;
+            let _ = r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() }).await?;
+            let _ = r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() }).await?;
+            let mut c = varint::read_async(r).await? as i64;
             if (c & 1) == 0 {
                 c = c.wrapping_shr(1);
             } else {
@@ -121,18 +121,18 @@ impl<const B: usize> IBLT<B> {
         Ok(())
     }
 
-    pub fn write<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
-        w.write_all(unsafe { &*(&self.salt as *const u64).cast::<[u8; 8]>() })?;
+    pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
+        let _ = w.write_all(unsafe { &*(&self.salt as *const u64).cast::<[u8; 8]>() }).await?;
         let mut prev_c = 0_i64;
         for b in self.map.iter() {
-            w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() })?;
-            w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() })?;
+            let _ = w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() }).await?;
+            let _ = w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() }).await?;
             let mut c = (b.count - prev_c).wrapping_shl(1);
             prev_c = b.count;
             if c < 0 {
                 c = -c | 1;
             }
-            varint::write(w, c as u64)?;
+            let _ = varint::write_async(w, c as u64).await?;
         }
         Ok(())
     }
diff --git a/syncwhole/src/lib.rs b/syncwhole/src/lib.rs
index 38ffc9caa..134b83f95 100644
--- a/syncwhole/src/lib.rs
+++ b/syncwhole/src/lib.rs
@@ -1,7 +1,16 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/.
+ *
+ * (c)2022 ZeroTier, Inc.
+ * https://www.zerotier.com/
+ */
+
 pub(crate) mod varint;
 pub(crate) mod protocol;
 pub(crate) mod iblt;
-pub mod database;
+
+pub mod datastore;
 pub mod node;
 pub mod host;
diff --git a/syncwhole/src/node.rs b/syncwhole/src/node.rs
index 608b8961c..905a42b1a 100644
--- a/syncwhole/src/node.rs
+++ b/syncwhole/src/node.rs
@@ -7,19 +7,19 @@
  */
 
 use std::collections::HashMap;
-use std::net::SocketAddr;
+use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
 use std::sync::{Arc, Weak};
 use std::sync::atomic::{AtomicI64, Ordering};
-use std::time::Duration;
-use serde::{Deserialize, Serialize};
+use std::time::{Duration, SystemTime};
 
-use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
+use serde::{Deserialize, Serialize};
+use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
 use tokio::net::{TcpListener, TcpSocket, TcpStream};
 use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
 use tokio::sync::Mutex;
 use tokio::task::JoinHandle;
 
-use crate::database::Database;
+use crate::datastore::DataStore;
 use crate::host::Host;
 use crate::ms_monotonic;
 use crate::protocol::*;
@@ -27,92 +27,130 @@ use crate::varint;
 
 const CONNECTION_TIMEOUT: i64 = 60000;
 const CONNECTION_KEEPALIVE_AFTER: i64 = 20000;
+const IO_BUFFER_SIZE: usize = 65536;
 
-struct Connection {
-    writer: Mutex<OwnedWriteHalf>,
-    last_send_time: AtomicI64,
-    last_receive_time: AtomicI64,
-    io_task: std::sync::Mutex<Option<JoinHandle<std::io::Result<()>>>>,
-    incoming: bool
+#[derive(Clone, PartialEq, Eq)]
+pub struct RemoteNodeInfo {
+    pub node_name: Option<String>,
+    pub node_contact: Option<String>,
+    pub endpoint: SocketAddr,
+    pub preferred_endpoints: Vec<SocketAddr>,
+    pub connect_time: SystemTime,
+    pub inbound: bool,
+    pub initialized: bool,
 }
 
-impl Connection {
-    async fn send(&self, data: &[u8], now: i64) -> std::io::Result<()> {
-        self.writer.lock().await.write_all(data).await.map(|_| {
-            self.last_send_time.store(now, Ordering::Relaxed);
+pub struct Node<D: DataStore + 'static, H: Host + 'static> {
+    internal: Arc<NodeInternal<D, H>>,
+    housekeeping_task: JoinHandle<()>,
+    listener_task: JoinHandle<()>
+}
+
+impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
+    pub async fn new(db: Arc<D>, host: Arc<H>, bind_address: SocketAddr) -> std::io::Result<Self> {
+        let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
+        if listener.set_reuseport(true).is_err() {
+            listener.set_reuseaddr(true)?;
+        }
+        listener.bind(bind_address.clone())?;
+        let listener = listener.listen(1024)?;
+
+        let internal = Arc::new(NodeInternal::<D, H> {
+            anti_loopback_secret: {
+                let mut tmp = [0_u8; 16];
+                host.get_secure_random(&mut tmp);
+                tmp
+            },
+            db: db.clone(),
+            host: host.clone(),
+            bind_address,
+            connections: Mutex::new(HashMap::with_capacity(64)),
+        });
+
+        Ok(Self {
+            internal: internal.clone(),
+            housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()),
+            listener_task: tokio::spawn(internal.listener_task_main(listener)),
         })
     }
 
-    async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> {
-        let data = rmp_serde::encode::to_vec_named(&obj);
-        if data.is_ok() {
-            let data = data.unwrap();
-            let mut tmp = [0_u8; 16];
-            tmp[0] = message_type;
-            let len = 1 + varint::encode(&mut tmp[1..], data.len() as u64);
-            let mut writer = self.writer.lock().await;
-            writer.write_all(&tmp[0..len]).await?;
-            writer.write_all(data.as_slice()).await?;
-            self.last_send_time.store(now, Ordering::Relaxed);
-            Ok(())
-        } else {
-            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure"))
-        }
+    #[inline(always)]
+    pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
+        self.internal.connect(endpoint).await
    }
 
-    fn kill(&self) {
-        let _ = self.io_task.lock().unwrap().take().map(|h| h.abort());
-    }
-
-    async fn read_msg<'a>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<&'a [u8]> {
-        if message_size > buf.len() {
-            buf.resize(((message_size / 4096) + 1) * 4096, 0);
-        }
-        let b = &mut buf.as_mut_slice()[0..message_size];
-        reader.read_exact(b).await?;
-        self.last_receive_time.store(now, Ordering::Relaxed);
-        Ok(b)
-    }
-
-    async fn read_obj<'a, O: Deserialize<'a>>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<O> {
-        rmp_serde::from_slice(self.read_msg(reader, buf, message_size, now).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
+    pub fn list_connections(&self) -> Vec<RemoteNodeInfo> {
+        let mut connections = self.internal.connections.blocking_lock();
+        let mut cl: Vec<RemoteNodeInfo> = Vec::with_capacity(connections.len());
+        connections.retain(|_, c| {
+            c.0.upgrade().map_or(false, |c| {
+                cl.push(c.info.lock().unwrap().clone());
+                true
+            })
+        });
+        cl
     }
 }
 
-pub struct NodeInternal<D: Database + 'static, H: Host + 'static> {
-    anti_loopback_secret: [u8; 64],
+impl<D: DataStore + 'static, H: Host + 'static> Drop for Node<D, H> {
+    fn drop(&mut self) {
+        self.housekeeping_task.abort();
+        self.listener_task.abort();
+    }
+}
+
+pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
+    anti_loopback_secret: [u8; 16],
     db: Arc<D>,
     host: Arc<H>,
     bind_address: SocketAddr,
-    connections: Mutex<HashMap<SocketAddr, Weak<Connection>>>,
+    connections: Mutex<HashMap<SocketAddr, (Weak<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
 }
 
-impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
+impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
     async fn housekeeping_task_main(self: Arc<Self>) {
         loop {
-            tokio::time::sleep(Duration::from_millis(CONNECTION_KEEPALIVE_AFTER as u64)).await;
+            tokio::time::sleep(Duration::from_millis((CONNECTION_KEEPALIVE_AFTER / 2) as u64)).await;
 
             let mut to_ping: Vec<Arc<Connection>> = Vec::new();
+            let mut dead: Vec<(SocketAddr, Option<JoinHandle<std::io::Result<()>>>)> = Vec::new();
+
             let mut connections = self.connections.lock().await;
             let now = ms_monotonic();
-            connections.retain(|_, c| {
-                c.upgrade().map_or(false, |c| {
-                    if (now - c.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT {
-                        if (now - c.last_send_time.load(Ordering::Relaxed)) > CONNECTION_KEEPALIVE_AFTER {
-                            to_ping.push(c);
+            connections.retain(|sa, c| {
+                let cc = c.0.upgrade();
+                if cc.is_some() {
+                    let cc = cc.unwrap();
+                    if (now - cc.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT {
+                        if (now - cc.last_send_time.load(Ordering::Relaxed)) >= CONNECTION_KEEPALIVE_AFTER {
+                            to_ping.push(cc);
                        }
+                        true
                    } else {
-                        c.kill();
-                        return false;
+                        c.1.take().map(|j| j.abort());
+                        false
                    }
-                    return true;
-                })
+                } else {
+                    let _ = c.1.take().map(|j| dead.push((sa.clone(), Some(j))));
+                    false
+                }
            });
            drop(connections); // release lock
+
+            for d in dead.iter_mut() {
+                d.1.take().unwrap().await.map_or_else(|e| {
+                    self.host.on_connection_closed(&d.0, Some(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "timed out"))));
+                }, |r| {
+                    if r.is_ok() {
+                        self.host.on_connection_closed(&d.0, None);
+                    } else {
+                        self.host.on_connection_closed(&d.0, Some(Box::new(r.unwrap_err())));
+                    }
+                });
+            }
+
            for c in to_ping.iter() {
-                if c.send(&[MESSAGE_TYPE_NOP, 0], now).await.is_err() {
-                    c.kill();
-                }
+                let _ = c.send(&[MESSAGE_TYPE_NOP, 0], now).await;
            }
        }
    }
@@ -120,7 +158,7 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
     async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
         loop {
             let socket = listener.accept().await;
-            if socket.is_ok() {
+            if self.connections.lock().await.len() < self.host.max_endpoints() && socket.is_ok() {
                 let (stream, endpoint) = socket.unwrap();
                 Self::connection_start(&self, endpoint, stream, true).await;
             }
@@ -141,8 +179,9 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
             preferred_ipv6: None
         }, ms_monotonic()).await?;
 
+        let mut init_received = false;
         let mut initialized = false;
-        let mut reader = BufReader::with_capacity(65536, reader);
+        let mut reader = BufReader::with_capacity(IO_BUFFER_SIZE, reader);
         let mut buf: Vec<u8> = Vec::new();
         buf.resize(4096, 0);
         loop {
@@ -156,20 +195,56 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
             match message_type {
                 MESSAGE_TYPE_INIT => {
-                    if initialized {
+                    if init_received {
                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init"));
                     }
+
                     let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize, now).await?;
+                    if !msg.domain.as_str().eq(self.db.domain()) {
+                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set domain mismatch"));
+                    }
+                    if msg.key_size != D::KEY_SIZE as u16 || msg.max_value_size > D::MAX_VALUE_SIZE as u64 {
+                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set key/value sizing mismatch"));
+                    }
+
                     let mut antiloop = msg.anti_loopback_challenge.to_vec();
                     let _ = std::io::Write::write_all(&mut antiloop, &self.anti_loopback_secret);
                     let antiloop = H::sha512(antiloop.as_slice());
                     connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse { anti_loopback_response: &antiloop[0..16] }, now).await?;
+
+                    init_received = true;
+
+                    let mut info = connection.info.lock().unwrap();
+                    info.node_name = msg.node_name.clone();
+                    info.node_contact = msg.node_contact.clone();
+                    let _ = msg.preferred_ipv4.map(|pv4| {
+                        info.preferred_endpoints.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(pv4.ip[0], pv4.ip[1], pv4.ip[2], pv4.ip[3]), pv4.port)));
+                    });
+                    let _ = msg.preferred_ipv6.map(|pv6| {
+                        info.preferred_endpoints.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
+                    });
+                },
+                MESSAGE_TYPE_INIT_RESPONSE => {
+                    if initialized {
+                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init response"));
+                    }
+
+                    let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize, now).await?;
+                    let mut antiloop = challenge.to_vec();
+                    let _ = std::io::Write::write_all(&mut antiloop, &self.anti_loopback_secret);
+                    let antiloop = H::sha512(antiloop.as_slice());
+                    if msg.anti_loopback_response.eq(&antiloop[0..16]) {
+                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "rejected connection to self"));
+                    }
+
+                    initialized = true;
+                    let mut info = connection.info.lock().unwrap();
+                    info.initialized = true;
+                    let info = info.clone();
+                    self.host.on_connect(&info);
                },
                _ => {
                    // Skip messages that aren't recognized or don't need to be parsed like NOP.
@@ -185,21 +260,29 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
         }
     }
 
-    async fn connection_start(self: &Arc<Self>, endpoint: SocketAddr, stream: TcpStream, incoming: bool) -> bool {
+    async fn connection_start(self: &Arc<Self>, endpoint: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
+        let _ = stream.set_nodelay(true);
         let (reader, writer) = stream.into_split();
+
         let mut ok = false;
         let _ = self.connections.lock().await.entry(endpoint.clone()).or_insert_with(|| {
             ok = true;
             let now = ms_monotonic();
             let connection = Arc::new(Connection {
-                writer: Mutex::new(writer),
+                writer: Mutex::new(BufWriter::with_capacity(IO_BUFFER_SIZE, writer)),
                 last_send_time: AtomicI64::new(now),
                 last_receive_time: AtomicI64::new(now),
-                io_task: std::sync::Mutex::new(None),
-                incoming
+                info: std::sync::Mutex::new(RemoteNodeInfo {
+                    node_name: None,
+                    node_contact: None,
+                    endpoint: endpoint.clone(),
+                    preferred_endpoints: Vec::new(),
+                    connect_time: SystemTime::now(),
+                    inbound,
+                    initialized: false
+                }),
             });
-            let _ = connection.io_task.lock().unwrap().insert(tokio::spawn(Self::connection_io_task_main(self.clone(), connection.clone(), reader)));
-            Arc::downgrade(&connection)
+            (Arc::downgrade(&connection), Some(tokio::spawn(self.clone().connection_io_task_main(connection.clone(), reader))))
         });
         ok
     }
@@ -219,70 +302,59 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
     }
 }
 
-impl<D: Database + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
+impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
     fn drop(&mut self) {
         for (_, c) in self.connections.blocking_lock().drain() {
-            let _ = c.upgrade().map(|c| c.kill());
+            c.1.map(|c| c.abort());
         }
     }
 }
 
-pub struct Node<D: Database + 'static, H: Host + 'static> {
-    internal: Arc<NodeInternal<D, H>>,
-    housekeeping_task: JoinHandle<()>,
-    listener_task: JoinHandle<()>
+struct Connection {
+    writer: Mutex<BufWriter<OwnedWriteHalf>>,
+    last_send_time: AtomicI64,
+    last_receive_time: AtomicI64,
+    info: std::sync::Mutex<RemoteNodeInfo>,
 }
 
-impl<D: Database + 'static, H: Host + 'static> Node<D, H> {
-    pub async fn new(db: Arc<D>, host: Arc<H>, bind_address: SocketAddr) -> std::io::Result<Self> {
-        let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
-        if listener.set_reuseport(true).is_err() {
-            listener.set_reuseaddr(true)?;
+impl Connection {
+    async fn send(&self, data: &[u8], now: i64) -> std::io::Result<()> {
+        let mut writer = self.writer.lock().await;
+        writer.write_all(data).await?;
+        writer.flush().await?;
+        self.last_send_time.store(now, Ordering::Relaxed);
+        Ok(())
+    }
+
+    async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> {
+        let data = rmp_serde::encode::to_vec_named(&obj);
+        if data.is_ok() {
+            let data = data.unwrap();
+            let mut tmp = [0_u8; 16];
+            tmp[0] = message_type;
+            let len = 1 + varint::encode(&mut tmp[1..], data.len() as u64);
+            let mut writer = self.writer.lock().await;
+            writer.write_all(&tmp[0..len]).await?;
+            writer.write_all(data.as_slice()).await?;
+            writer.flush().await?;
+            self.last_send_time.store(now, Ordering::Relaxed);
+            Ok(())
+        } else {
+            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure"))
         }
-        listener.bind(bind_address.clone())?;
-        let listener = listener.listen(1024)?;
-
-        let internal = Arc::new(NodeInternal::<D, H> {
-            anti_loopback_secret: {
-                let mut tmp = [0_u8; 64];
-                host.get_secure_random(&mut tmp);
-                tmp
-            },
-            db: db.clone(),
-            host: host.clone(),
-            bind_address,
-            connections: Mutex::new(HashMap::with_capacity(64)),
-        });
-        Ok(Self {
-            internal: internal.clone(),
-            housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()),
-            listener_task: tokio::spawn(internal.listener_task_main(listener)),
-        })
     }
 
-    #[inline(always)]
-    pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
-        self.internal.connect(endpoint).await
+    async fn read_msg<'a>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<&'a [u8]> {
+        if message_size > buf.len() {
+            buf.resize(((message_size / 4096) + 1) * 4096, 0);
+        }
+        let b = &mut buf.as_mut_slice()[0..message_size];
+        reader.read_exact(b).await?;
+        self.last_receive_time.store(now, Ordering::Relaxed);
+        Ok(b)
     }
 
-    pub fn list_connections(&self) -> Vec<SocketAddr> {
-        let mut connections = self.internal.connections.blocking_lock();
-        let mut cl: Vec<SocketAddr> = Vec::with_capacity(connections.len());
-        connections.retain(|sa, c| {
-            if c.strong_count() > 0 {
-                cl.push(sa.clone());
-                true
-            } else {
-                false
-            }
-        });
-        cl
-    }
-}
-
-impl<D: Database + 'static, H: Host + 'static> Drop for Node<D, H> {
-    fn drop(&mut self) {
-        self.housekeeping_task.abort();
-        self.listener_task.abort();
+    async fn read_obj<'a, O: Deserialize<'a>>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<O> {
+        rmp_serde::from_slice(self.read_msg(reader, buf, message_size, now).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
    }
 }
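Note: the anti-loopback handshake that connection_io_task_main() now performs on MESSAGE_TYPE_INIT / MESSAGE_TYPE_INIT_RESPONSE reduces to the standalone sketch below. The function and parameter names are illustrative only and are not part of this patch; the SHA-512 function is taken as a closure because the real code obtains it from the Host trait.

// Illustrative sketch only: both peers hash (challenge || local anti-loopback
// secret). If the response carried in INIT_RESPONSE equals the first 16 bytes
// of the hash we would have produced for our own challenge, the "remote" node
// is ourselves and the connection is rejected.
fn is_connection_to_self(
    challenge: &[u8],
    anti_loopback_secret: &[u8; 16],
    anti_loopback_response: &[u8],
    sha512: impl Fn(&[u8]) -> [u8; 64],
) -> bool {
    let mut tmp = challenge.to_vec();
    tmp.extend_from_slice(anti_loopback_secret);
    let expected = sha512(tmp.as_slice());
    anti_loopback_response.eq(&expected[0..16])
}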
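Note: the revised housekeeping loop can be summarized by the small decision function below. This is a simplified restatement for clarity (the enum and function are hypothetical, not present in the crate), reusing the two timing constants from node.rs: the task now wakes every CONNECTION_KEEPALIVE_AFTER / 2 ms, drops links that have received nothing for CONNECTION_TIMEOUT ms, and sends a NOP keepalive on links that have sent nothing for CONNECTION_KEEPALIVE_AFTER ms.

const CONNECTION_TIMEOUT: i64 = 60000;
const CONNECTION_KEEPALIVE_AFTER: i64 = 20000;

// Hypothetical summary of the per-connection policy in housekeeping_task_main().
enum LinkAction {
    Keep,
    SendKeepalive,
    Drop,
}

fn housekeeping_action(now: i64, last_send_time: i64, last_receive_time: i64) -> LinkAction {
    if (now - last_receive_time) >= CONNECTION_TIMEOUT {
        LinkAction::Drop
    } else if (now - last_send_time) >= CONNECTION_KEEPALIVE_AFTER {
        LinkAction::SendKeepalive
    } else {
        LinkAction::Keep
    }
}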