Sync stuff.

This commit is contained in:
Adam Ierymenko 2022-03-01 16:24:49 -05:00
parent 27389825da
commit ee6fc671e4
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
5 changed files with 242 additions and 147 deletions

View file

@ -37,7 +37,7 @@ pub enum StoreResult {
Rejected Rejected
} }
/// API to be implemented by the data store we want to replicate. /// API to be implemented by the data set we want to replicate.
/// ///
/// The API specified here supports temporally subjective data sets. These are data sets /// The API specified here supports temporally subjective data sets. These are data sets
/// where the existence or non-existence of a record may depend on the (real world) time. /// where the existence or non-existence of a record may depend on the (real world) time.
@ -45,11 +45,11 @@ pub enum StoreResult {
/// what time I think it is" value to be considered locally so that data can be replicated /// what time I think it is" value to be considered locally so that data can be replicated
/// as of any given time. /// as of any given time.
/// ///
/// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a pure function of /// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a function of
/// values. If this is true, get_key() must be implemented. /// values. If this is true, get_key() must be implemented.
/// ///
/// The implementation must be thread safe. /// The implementation must be thread safe.
pub trait Database: Sync + Send { pub trait DataStore: Sync + Send {
/// Type to be enclosed in the Ok() enum value in LoadResult. /// Type to be enclosed in the Ok() enum value in LoadResult.
type LoadResultValueType: AsRef<[u8]> + Send; type LoadResultValueType: AsRef<[u8]> + Send;
@ -70,10 +70,13 @@ pub trait Database: Sync + Send {
/// If KEY_IS_COMPUTED is true this must be implemented. The default implementation /// If KEY_IS_COMPUTED is true this must be implemented. The default implementation
/// panics to indicate this. If KEY_IS_COMPUTED is false this is never called. /// panics to indicate this. If KEY_IS_COMPUTED is false this is never called.
#[allow(unused_variables)] #[allow(unused_variables)]
fn get_key(value: &[u8], key: &mut [u8]) { fn get_key(&self, value: &[u8], key: &mut [u8]) {
panic!("get_key() must be implemented if KEY_IS_COMPUTED is true"); panic!("get_key() must be implemented if KEY_IS_COMPUTED is true");
} }
/// Get the domain of this data store, which is just an arbitrary unique identifier.
fn domain(&self) -> &str;
/// Get an item if it exists as of a given reference time. /// Get an item if it exists as of a given reference time.
/// ///
/// The supplied key must be of length KEY_SIZE or this may panic. /// The supplied key must be of length KEY_SIZE or this may panic.
@ -115,24 +118,28 @@ pub trait Database: Sync + Send {
/// A simple in-memory data store backed by a BTreeMap. /// A simple in-memory data store backed by a BTreeMap.
pub struct MemoryDatabase<const KEY_SIZE: usize> { pub struct MemoryDatabase<const KEY_SIZE: usize> {
max_age: i64, max_age: i64,
domain: String,
db: Mutex<BTreeMap<[u8; KEY_SIZE], (i64, Arc<[u8]>)>> db: Mutex<BTreeMap<[u8; KEY_SIZE], (i64, Arc<[u8]>)>>
} }
impl<const KEY_SIZE: usize> MemoryDatabase<KEY_SIZE> { impl<const KEY_SIZE: usize> MemoryDatabase<KEY_SIZE> {
pub fn new(max_age: i64) -> Self { pub fn new(max_age: i64, domain: String) -> Self {
Self { Self {
max_age: if max_age > 0 { max_age } else { i64::MAX }, max_age: if max_age > 0 { max_age } else { i64::MAX },
domain,
db: Mutex::new(BTreeMap::new()) db: Mutex::new(BTreeMap::new())
} }
} }
} }
impl<const KEY_SIZE: usize> Database for MemoryDatabase<KEY_SIZE> { impl<const KEY_SIZE: usize> DataStore for MemoryDatabase<KEY_SIZE> {
type LoadResultValueType = Arc<[u8]>; type LoadResultValueType = Arc<[u8]>;
const KEY_SIZE: usize = KEY_SIZE; const KEY_SIZE: usize = KEY_SIZE;
const MAX_VALUE_SIZE: usize = 65536; const MAX_VALUE_SIZE: usize = 65536;
const KEY_IS_COMPUTED: bool = false; const KEY_IS_COMPUTED: bool = false;
fn domain(&self) -> &str { self.domain.as_str() }
fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType> { fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType> {
let db = self.db.lock().unwrap(); let db = self.db.lock().unwrap();
let e = db.get(key); let e = db.get(key);

View file

@ -10,6 +10,8 @@ use std::collections::HashSet;
use std::error::Error; use std::error::Error;
use std::net::SocketAddr; use std::net::SocketAddr;
use crate::node::RemoteNodeInfo;
/// A trait that users of syncwhole implement to provide configuration information and listen for events. /// A trait that users of syncwhole implement to provide configuration information and listen for events.
pub trait Host: Sync + Send { pub trait Host: Sync + Send {
/// Compute SHA512. /// Compute SHA512.
@ -27,11 +29,16 @@ pub trait Host: Sync + Send {
/// of the set is less than the minimum active link count the host wishes to maintain. /// of the set is less than the minimum active link count the host wishes to maintain.
fn get_more_endpoints(&self, current_endpoints: &HashSet<SocketAddr>) -> Vec<SocketAddr>; fn get_more_endpoints(&self, current_endpoints: &HashSet<SocketAddr>) -> Vec<SocketAddr>;
/// Called whenever we have successfully connected to a remote endpoint to (possibly) remember it. /// Get the maximum number of endpoints allowed.
fn on_connect_success(&self, endpoint: &SocketAddr); ///
/// This is checked on incoming connect and incoming links are refused if the total is over this count.
fn max_endpoints(&self) -> usize;
/// Called whenever an outgoing connection fails. /// Called whenever we have successfully connected to a remote node (after connection is initialized).
fn on_connect_failure(&self, endpoint: &SocketAddr, reason: Box<dyn Error>); fn on_connect(&self, info: &RemoteNodeInfo);
/// Called when an open connection is closed.
fn on_connection_closed(&self, endpoint: &SocketAddr, reason: Option<Box<dyn Error>>);
/// Fill a buffer with secure random bytes. /// Fill a buffer with secure random bytes.
fn get_secure_random(&self, buf: &mut [u8]); fn get_secure_random(&self, buf: &mut [u8]);

View file

@ -6,9 +6,9 @@
* https://www.zerotier.com/ * https://www.zerotier.com/
*/ */
use std::io::{Read, Write};
use std::mem::{size_of, zeroed}; use std::mem::{size_of, zeroed};
use std::ptr::write_bytes; use std::ptr::write_bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::varint; use crate::varint;
@ -103,13 +103,13 @@ impl<const B: usize> IBLT<B> {
unsafe { write_bytes((&mut self.map as *mut IBLTEntry).cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) }; unsafe { write_bytes((&mut self.map as *mut IBLTEntry).cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) };
} }
pub fn read<R: Read>(&mut self, r: &mut R) -> std::io::Result<()> { pub async fn read<R: AsyncReadExt + Unpin>(&mut self, r: &mut R) -> std::io::Result<()> {
r.read_exact(unsafe { &mut *(&mut self.salt as *mut u64).cast::<[u8; 8]>() })?; r.read_exact(unsafe { &mut *(&mut self.salt as *mut u64).cast::<[u8; 8]>() }).await?;
let mut prev_c = 0_i64; let mut prev_c = 0_i64;
for b in self.map.iter_mut() { for b in self.map.iter_mut() {
r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() })?; let _ = r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() }).await?;
r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() })?; let _ = r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() }).await?;
let mut c = varint::read(r)? as i64; let mut c = varint::read_async(r).await? as i64;
if (c & 1) == 0 { if (c & 1) == 0 {
c = c.wrapping_shr(1); c = c.wrapping_shr(1);
} else { } else {
@ -121,18 +121,18 @@ impl<const B: usize> IBLT<B> {
Ok(()) Ok(())
} }
pub fn write<W: Write>(&self, w: &mut W) -> std::io::Result<()> { pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
w.write_all(unsafe { &*(&self.salt as *const u64).cast::<[u8; 8]>() })?; let _ = w.write_all(unsafe { &*(&self.salt as *const u64).cast::<[u8; 8]>() }).await?;
let mut prev_c = 0_i64; let mut prev_c = 0_i64;
for b in self.map.iter() { for b in self.map.iter() {
w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() })?; let _ = w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() }).await?;
w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() })?; let _ = w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() }).await?;
let mut c = (b.count - prev_c).wrapping_shl(1); let mut c = (b.count - prev_c).wrapping_shl(1);
prev_c = b.count; prev_c = b.count;
if c < 0 { if c < 0 {
c = -c | 1; c = -c | 1;
} }
varint::write(w, c as u64)?; let _ = varint::write_async(w, c as u64).await?;
} }
Ok(()) Ok(())
} }

View file

@ -1,7 +1,16 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
*
* (c)2022 ZeroTier, Inc.
* https://www.zerotier.com/
*/
pub(crate) mod varint; pub(crate) mod varint;
pub(crate) mod protocol; pub(crate) mod protocol;
pub(crate) mod iblt; pub(crate) mod iblt;
pub mod database;
pub mod datastore;
pub mod node; pub mod node;
pub mod host; pub mod host;

View file

@ -7,19 +7,19 @@
*/ */
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::atomic::{AtomicI64, Ordering};
use std::time::Duration; use std::time::{Duration, SystemTime};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use crate::database::Database; use crate::datastore::DataStore;
use crate::host::Host; use crate::host::Host;
use crate::ms_monotonic; use crate::ms_monotonic;
use crate::protocol::*; use crate::protocol::*;
@ -27,92 +27,130 @@ use crate::varint;
const CONNECTION_TIMEOUT: i64 = 60000; const CONNECTION_TIMEOUT: i64 = 60000;
const CONNECTION_KEEPALIVE_AFTER: i64 = 20000; const CONNECTION_KEEPALIVE_AFTER: i64 = 20000;
const IO_BUFFER_SIZE: usize = 65536;
struct Connection { #[derive(Clone, PartialEq, Eq)]
writer: Mutex<OwnedWriteHalf>, pub struct RemoteNodeInfo {
last_send_time: AtomicI64, pub node_name: Option<String>,
last_receive_time: AtomicI64, pub node_contact: Option<String>,
io_task: std::sync::Mutex<Option<JoinHandle<std::io::Result<()>>>>, pub endpoint: SocketAddr,
incoming: bool pub preferred_endpoints: Vec<SocketAddr>,
pub connect_time: SystemTime,
pub inbound: bool,
pub initialized: bool,
} }
impl Connection { pub struct Node<D: DataStore + 'static, H: Host + 'static> {
async fn send(&self, data: &[u8], now: i64) -> std::io::Result<()> { internal: Arc<NodeInternal<D, H>>,
self.writer.lock().await.write_all(data).await.map(|_| { housekeeping_task: JoinHandle<()>,
self.last_send_time.store(now, Ordering::Relaxed); listener_task: JoinHandle<()>
}
impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
pub async fn new(db: Arc<D>, host: Arc<H>, bind_address: SocketAddr) -> std::io::Result<Self> {
let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
if listener.set_reuseport(true).is_err() {
listener.set_reuseaddr(true)?;
}
listener.bind(bind_address.clone())?;
let listener = listener.listen(1024)?;
let internal = Arc::new(NodeInternal::<D, H> {
anti_loopback_secret: {
let mut tmp = [0_u8; 16];
host.get_secure_random(&mut tmp);
tmp
},
db: db.clone(),
host: host.clone(),
bind_address,
connections: Mutex::new(HashMap::with_capacity(64)),
});
Ok(Self {
internal: internal.clone(),
housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()),
listener_task: tokio::spawn(internal.listener_task_main(listener)),
}) })
} }
async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> { #[inline(always)]
let data = rmp_serde::encode::to_vec_named(&obj); pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
if data.is_ok() { self.internal.connect(endpoint).await
let data = data.unwrap();
let mut tmp = [0_u8; 16];
tmp[0] = message_type;
let len = 1 + varint::encode(&mut tmp[1..], data.len() as u64);
let mut writer = self.writer.lock().await;
writer.write_all(&tmp[0..len]).await?;
writer.write_all(data.as_slice()).await?;
self.last_send_time.store(now, Ordering::Relaxed);
Ok(())
} else {
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure"))
}
} }
fn kill(&self) { pub fn list_connections(&self) -> Vec<RemoteNodeInfo> {
let _ = self.io_task.lock().unwrap().take().map(|h| h.abort()); let mut connections = self.internal.connections.blocking_lock();
} let mut cl: Vec<RemoteNodeInfo> = Vec::with_capacity(connections.len());
connections.retain(|_, c| {
async fn read_msg<'a>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<&'a [u8]> { c.0.upgrade().map_or(false, |c| {
if message_size > buf.len() { cl.push(c.info.lock().unwrap().clone());
buf.resize(((message_size / 4096) + 1) * 4096, 0); true
} })
let b = &mut buf.as_mut_slice()[0..message_size]; });
reader.read_exact(b).await?; cl
self.last_receive_time.store(now, Ordering::Relaxed);
Ok(b)
}
async fn read_obj<'a, O: Deserialize<'a>>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<O> {
rmp_serde::from_slice(self.read_msg(reader, buf, message_size, now).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
} }
} }
pub struct NodeInternal<D: Database + 'static, H: Host + 'static> { impl<D: DataStore + 'static, H: Host + 'static> Drop for Node<D, H> {
anti_loopback_secret: [u8; 64], fn drop(&mut self) {
self.housekeeping_task.abort();
self.listener_task.abort();
}
}
pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
anti_loopback_secret: [u8; 16],
db: Arc<D>, db: Arc<D>,
host: Arc<H>, host: Arc<H>,
bind_address: SocketAddr, bind_address: SocketAddr,
connections: Mutex<HashMap<SocketAddr, Weak<Connection>>>, connections: Mutex<HashMap<SocketAddr, (Weak<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
} }
impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> { impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
async fn housekeeping_task_main(self: Arc<Self>) { async fn housekeeping_task_main(self: Arc<Self>) {
loop { loop {
tokio::time::sleep(Duration::from_millis(CONNECTION_KEEPALIVE_AFTER as u64)).await; tokio::time::sleep(Duration::from_millis((CONNECTION_KEEPALIVE_AFTER / 2) as u64)).await;
let mut to_ping: Vec<Arc<Connection>> = Vec::new(); let mut to_ping: Vec<Arc<Connection>> = Vec::new();
let mut dead: Vec<(SocketAddr, Option<JoinHandle<std::io::Result<()>>>)> = Vec::new();
let mut connections = self.connections.lock().await; let mut connections = self.connections.lock().await;
let now = ms_monotonic(); let now = ms_monotonic();
connections.retain(|_, c| { connections.retain(|sa, c| {
c.upgrade().map_or(false, |c| { let cc = c.0.upgrade();
if (now - c.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT { if cc.is_some() {
if (now - c.last_send_time.load(Ordering::Relaxed)) > CONNECTION_KEEPALIVE_AFTER { let cc = cc.unwrap();
to_ping.push(c); if (now - cc.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT {
if (now - cc.last_send_time.load(Ordering::Relaxed)) >= CONNECTION_KEEPALIVE_AFTER {
to_ping.push(cc);
} }
true
} else { } else {
c.kill(); c.1.take().map(|j| j.abort());
return false; false
} }
return true; } else {
}) let _ = c.1.take().map(|j| dead.push((sa.clone(), Some(j))));
false
}
}); });
drop(connections); // release lock drop(connections); // release lock
for d in dead.iter_mut() {
d.1.take().unwrap().await.map_or_else(|e| {
self.host.on_connection_closed(&d.0, Some(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "timed out"))));
}, |r| {
if r.is_ok() {
self.host.on_connection_closed(&d.0, None);
} else {
self.host.on_connection_closed(&d.0, Some(Box::new(r.unwrap_err())));
}
});
}
for c in to_ping.iter() { for c in to_ping.iter() {
if c.send(&[MESSAGE_TYPE_NOP, 0], now).await.is_err() { let _ = c.send(&[MESSAGE_TYPE_NOP, 0], now).await;
c.kill();
}
} }
} }
} }
@ -120,7 +158,7 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
async fn listener_task_main(self: Arc<Self>, listener: TcpListener) { async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
loop { loop {
let socket = listener.accept().await; let socket = listener.accept().await;
if socket.is_ok() { if self.connections.lock().await.len() < self.host.max_endpoints() && socket.is_ok() {
let (stream, endpoint) = socket.unwrap(); let (stream, endpoint) = socket.unwrap();
Self::connection_start(&self, endpoint, stream, true).await; Self::connection_start(&self, endpoint, stream, true).await;
} }
@ -141,8 +179,9 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
preferred_ipv6: None preferred_ipv6: None
}, ms_monotonic()).await?; }, ms_monotonic()).await?;
let mut init_received = false;
let mut initialized = false; let mut initialized = false;
let mut reader = BufReader::with_capacity(65536, reader); let mut reader = BufReader::with_capacity(IO_BUFFER_SIZE, reader);
let mut buf: Vec<u8> = Vec::new(); let mut buf: Vec<u8> = Vec::new();
buf.resize(4096, 0); buf.resize(4096, 0);
loop { loop {
@ -156,20 +195,56 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
match message_type { match message_type {
MESSAGE_TYPE_INIT => { MESSAGE_TYPE_INIT => {
if initialized { if init_received {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init")); return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init"));
} }
let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize, now).await?; let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize, now).await?;
if !msg.domain.as_str().eq(self.db.domain()) {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set domain mismatch"));
}
if msg.key_size != D::KEY_SIZE as u16 || msg.max_value_size > D::MAX_VALUE_SIZE as u64 {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set key/value sizing mismatch"));
}
let mut antiloop = msg.anti_loopback_challenge.to_vec(); let mut antiloop = msg.anti_loopback_challenge.to_vec();
let _ = std::io::Write::write_all(&mut antiloop, &self.anti_loopback_secret); let _ = std::io::Write::write_all(&mut antiloop, &self.anti_loopback_secret);
let antiloop = H::sha512(antiloop.as_slice()); let antiloop = H::sha512(antiloop.as_slice());
connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse { connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse {
anti_loopback_response: &antiloop[0..16] anti_loopback_response: &antiloop[0..16]
}, now).await?; }, now).await?;
init_received = true;
let mut info = connection.info.lock().unwrap();
info.node_name = msg.node_name.clone();
info.node_contact = msg.node_contact.clone();
let _ = msg.preferred_ipv4.map(|pv4| {
info.preferred_endpoints.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(pv4.ip[0], pv4.ip[1], pv4.ip[2], pv4.ip[3]), pv4.port)));
});
let _ = msg.preferred_ipv6.map(|pv6| {
info.preferred_endpoints.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
});
},
MESSAGE_TYPE_INIT_RESPONSE => {
if initialized {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init response"));
}
let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize, now).await?;
let mut antiloop = challenge.to_vec();
let _ = std::io::Write::write_all(&mut antiloop, &self.anti_loopback_secret);
let antiloop = H::sha512(antiloop.as_slice());
if msg.anti_loopback_response.eq(&antiloop[0..16]) {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "rejected connection to self"));
}
initialized = true; initialized = true;
let mut info = connection.info.lock().unwrap();
info.initialized = true;
let info = info.clone();
self.host.on_connect(&info);
}, },
_ => { _ => {
// Skip messages that aren't recognized or don't need to be parsed like NOP. // Skip messages that aren't recognized or don't need to be parsed like NOP.
@ -185,21 +260,29 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
} }
} }
async fn connection_start(self: &Arc<Self>, endpoint: SocketAddr, stream: TcpStream, incoming: bool) -> bool { async fn connection_start(self: &Arc<Self>, endpoint: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
let _ = stream.set_nodelay(true);
let (reader, writer) = stream.into_split(); let (reader, writer) = stream.into_split();
let mut ok = false; let mut ok = false;
let _ = self.connections.lock().await.entry(endpoint.clone()).or_insert_with(|| { let _ = self.connections.lock().await.entry(endpoint.clone()).or_insert_with(|| {
ok = true; ok = true;
let now = ms_monotonic(); let now = ms_monotonic();
let connection = Arc::new(Connection { let connection = Arc::new(Connection {
writer: Mutex::new(writer), writer: Mutex::new(BufWriter::with_capacity(IO_BUFFER_SIZE, writer)),
last_send_time: AtomicI64::new(now), last_send_time: AtomicI64::new(now),
last_receive_time: AtomicI64::new(now), last_receive_time: AtomicI64::new(now),
io_task: std::sync::Mutex::new(None), info: std::sync::Mutex::new(RemoteNodeInfo {
incoming node_name: None,
node_contact: None,
endpoint: endpoint.clone(),
preferred_endpoints: Vec::new(),
connect_time: SystemTime::now(),
inbound,
initialized: false
}),
}); });
let _ = connection.io_task.lock().unwrap().insert(tokio::spawn(Self::connection_io_task_main(self.clone(), connection.clone(), reader))); (Arc::downgrade(&connection), Some(tokio::spawn(self.clone().connection_io_task_main(connection.clone(), reader))))
Arc::downgrade(&connection)
}); });
ok ok
} }
@ -219,70 +302,59 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
} }
} }
impl<D: Database + 'static, H: Host + 'static> Drop for NodeInternal<D, H> { impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
fn drop(&mut self) { fn drop(&mut self) {
for (_, c) in self.connections.blocking_lock().drain() { for (_, c) in self.connections.blocking_lock().drain() {
let _ = c.upgrade().map(|c| c.kill()); c.1.map(|c| c.abort());
} }
} }
} }
pub struct Node<D: Database + 'static, H: Host + 'static> { struct Connection {
internal: Arc<NodeInternal<D, H>>, writer: Mutex<BufWriter<OwnedWriteHalf>>,
housekeeping_task: JoinHandle<()>, last_send_time: AtomicI64,
listener_task: JoinHandle<()> last_receive_time: AtomicI64,
info: std::sync::Mutex<RemoteNodeInfo>,
} }
impl<D: Database + 'static, H: Host + 'static> Node<D, H> { impl Connection {
pub async fn new(db: Arc<D>, host: Arc<H>, bind_address: SocketAddr) -> std::io::Result<Self> { async fn send(&self, data: &[u8], now: i64) -> std::io::Result<()> {
let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?; let mut writer = self.writer.lock().await;
if listener.set_reuseport(true).is_err() { writer.write_all(data).await?;
listener.set_reuseaddr(true)?; writer.flush().await?;
self.last_send_time.store(now, Ordering::Relaxed);
Ok(())
}
async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> {
let data = rmp_serde::encode::to_vec_named(&obj);
if data.is_ok() {
let data = data.unwrap();
let mut tmp = [0_u8; 16];
tmp[0] = message_type;
let len = 1 + varint::encode(&mut tmp[1..], data.len() as u64);
let mut writer = self.writer.lock().await;
writer.write_all(&tmp[0..len]).await?;
writer.write_all(data.as_slice()).await?;
writer.flush().await?;
self.last_send_time.store(now, Ordering::Relaxed);
Ok(())
} else {
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure"))
} }
listener.bind(bind_address.clone())?;
let listener = listener.listen(1024)?;
let internal = Arc::new(NodeInternal::<D, H> {
anti_loopback_secret: {
let mut tmp = [0_u8; 64];
host.get_secure_random(&mut tmp);
tmp
},
db: db.clone(),
host: host.clone(),
bind_address,
connections: Mutex::new(HashMap::with_capacity(64)),
});
Ok(Self {
internal: internal.clone(),
housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()),
listener_task: tokio::spawn(internal.listener_task_main(listener)),
})
} }
#[inline(always)] async fn read_msg<'a>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<&'a [u8]> {
pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> { if message_size > buf.len() {
self.internal.connect(endpoint).await buf.resize(((message_size / 4096) + 1) * 4096, 0);
}
let b = &mut buf.as_mut_slice()[0..message_size];
reader.read_exact(b).await?;
self.last_receive_time.store(now, Ordering::Relaxed);
Ok(b)
} }
pub fn list_connections(&self) -> Vec<SocketAddr> { async fn read_obj<'a, O: Deserialize<'a>>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<O> {
let mut connections = self.internal.connections.blocking_lock(); rmp_serde::from_slice(self.read_msg(reader, buf, message_size, now).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
let mut cl: Vec<SocketAddr> = Vec::with_capacity(connections.len());
connections.retain(|sa, c| {
if c.strong_count() > 0 {
cl.push(sa.clone());
true
} else {
false
}
});
cl
}
}
impl<D: Database + 'static, H: Host + 'static> Drop for Node<D, H> {
fn drop(&mut self) {
self.housekeeping_task.abort();
self.listener_task.abort();
} }
} }