mirror of
https://github.com/zerotier/ZeroTierOne.git
synced 2025-04-26 17:03:43 +02:00
Sync stuff.
This commit is contained in:
parent
27389825da
commit
ee6fc671e4
5 changed files with 242 additions and 147 deletions
|
@ -37,7 +37,7 @@ pub enum StoreResult {
|
|||
Rejected
|
||||
}
|
||||
|
||||
/// API to be implemented by the data store we want to replicate.
|
||||
/// API to be implemented by the data set we want to replicate.
|
||||
///
|
||||
/// The API specified here supports temporally subjective data sets. These are data sets
|
||||
/// where the existence or non-existence of a record may depend on the (real world) time.
|
||||
|
@ -45,11 +45,11 @@ pub enum StoreResult {
|
|||
/// what time I think it is" value to be considered locally so that data can be replicated
|
||||
/// as of any given time.
|
||||
///
|
||||
/// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a pure function of
|
||||
/// The KEY_IS_COMPUTED constant must be set to indicate whether keys are a function of
|
||||
/// values. If this is true, get_key() must be implemented.
|
||||
///
|
||||
/// The implementation must be thread safe.
|
||||
pub trait Database: Sync + Send {
|
||||
pub trait DataStore: Sync + Send {
|
||||
/// Type to be enclosed in the Ok() enum value in LoadResult.
|
||||
type LoadResultValueType: AsRef<[u8]> + Send;
|
||||
|
||||
|
@ -70,10 +70,13 @@ pub trait Database: Sync + Send {
|
|||
/// If KEY_IS_COMPUTED is true this must be implemented. The default implementation
|
||||
/// panics to indicate this. If KEY_IS_COMPUTED is false this is never called.
|
||||
#[allow(unused_variables)]
|
||||
fn get_key(value: &[u8], key: &mut [u8]) {
|
||||
fn get_key(&self, value: &[u8], key: &mut [u8]) {
|
||||
panic!("get_key() must be implemented if KEY_IS_COMPUTED is true");
|
||||
}
|
||||
|
||||
/// Get the domain of this data store, which is just an arbitrary unique identifier.
|
||||
fn domain(&self) -> &str;
|
||||
|
||||
/// Get an item if it exists as of a given reference time.
|
||||
///
|
||||
/// The supplied key must be of length KEY_SIZE or this may panic.
|
||||
|
@ -115,24 +118,28 @@ pub trait Database: Sync + Send {
|
|||
/// A simple in-memory data store backed by a BTreeMap.
|
||||
pub struct MemoryDatabase<const KEY_SIZE: usize> {
|
||||
max_age: i64,
|
||||
domain: String,
|
||||
db: Mutex<BTreeMap<[u8; KEY_SIZE], (i64, Arc<[u8]>)>>
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> MemoryDatabase<KEY_SIZE> {
|
||||
pub fn new(max_age: i64) -> Self {
|
||||
pub fn new(max_age: i64, domain: String) -> Self {
|
||||
Self {
|
||||
max_age: if max_age > 0 { max_age } else { i64::MAX },
|
||||
domain,
|
||||
db: Mutex::new(BTreeMap::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const KEY_SIZE: usize> Database for MemoryDatabase<KEY_SIZE> {
|
||||
impl<const KEY_SIZE: usize> DataStore for MemoryDatabase<KEY_SIZE> {
|
||||
type LoadResultValueType = Arc<[u8]>;
|
||||
const KEY_SIZE: usize = KEY_SIZE;
|
||||
const MAX_VALUE_SIZE: usize = 65536;
|
||||
const KEY_IS_COMPUTED: bool = false;
|
||||
|
||||
fn domain(&self) -> &str { self.domain.as_str() }
|
||||
|
||||
fn load(&self, reference_time: i64, key: &[u8]) -> LoadResult<Self::LoadResultValueType> {
|
||||
let db = self.db.lock().unwrap();
|
||||
let e = db.get(key);
|
|
@ -10,6 +10,8 @@ use std::collections::HashSet;
|
|||
use std::error::Error;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use crate::node::RemoteNodeInfo;
|
||||
|
||||
/// A trait that users of syncwhole implement to provide configuration information and listen for events.
|
||||
pub trait Host: Sync + Send {
|
||||
/// Compute SHA512.
|
||||
|
@ -27,11 +29,16 @@ pub trait Host: Sync + Send {
|
|||
/// of the set is less than the minimum active link count the host wishes to maintain.
|
||||
fn get_more_endpoints(&self, current_endpoints: &HashSet<SocketAddr>) -> Vec<SocketAddr>;
|
||||
|
||||
/// Called whenever we have successfully connected to a remote endpoint to (possibly) remember it.
|
||||
fn on_connect_success(&self, endpoint: &SocketAddr);
|
||||
/// Get the maximum number of endpoints allowed.
|
||||
///
|
||||
/// This is checked on incoming connect and incoming links are refused if the total is over this count.
|
||||
fn max_endpoints(&self) -> usize;
|
||||
|
||||
/// Called whenever an outgoing connection fails.
|
||||
fn on_connect_failure(&self, endpoint: &SocketAddr, reason: Box<dyn Error>);
|
||||
/// Called whenever we have successfully connected to a remote node (after connection is initialized).
|
||||
fn on_connect(&self, info: &RemoteNodeInfo);
|
||||
|
||||
/// Called when an open connection is closed.
|
||||
fn on_connection_closed(&self, endpoint: &SocketAddr, reason: Option<Box<dyn Error>>);
|
||||
|
||||
/// Fill a buffer with secure random bytes.
|
||||
fn get_secure_random(&self, buf: &mut [u8]);
|
||||
|
|
|
@ -6,9 +6,9 @@
|
|||
* https://www.zerotier.com/
|
||||
*/
|
||||
|
||||
use std::io::{Read, Write};
|
||||
use std::mem::{size_of, zeroed};
|
||||
use std::ptr::write_bytes;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use crate::varint;
|
||||
|
||||
|
@ -103,13 +103,13 @@ impl<const B: usize> IBLT<B> {
|
|||
unsafe { write_bytes((&mut self.map as *mut IBLTEntry).cast::<u8>(), 0, size_of::<[IBLTEntry; B]>()) };
|
||||
}
|
||||
|
||||
pub fn read<R: Read>(&mut self, r: &mut R) -> std::io::Result<()> {
|
||||
r.read_exact(unsafe { &mut *(&mut self.salt as *mut u64).cast::<[u8; 8]>() })?;
|
||||
pub async fn read<R: AsyncReadExt + Unpin>(&mut self, r: &mut R) -> std::io::Result<()> {
|
||||
r.read_exact(unsafe { &mut *(&mut self.salt as *mut u64).cast::<[u8; 8]>() }).await?;
|
||||
let mut prev_c = 0_i64;
|
||||
for b in self.map.iter_mut() {
|
||||
r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() })?;
|
||||
r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() })?;
|
||||
let mut c = varint::read(r)? as i64;
|
||||
let _ = r.read_exact(unsafe { &mut *(&mut b.key_sum as *mut u64).cast::<[u8; 8]>() }).await?;
|
||||
let _ = r.read_exact(unsafe { &mut *(&mut b.check_hash_sum as *mut u64).cast::<[u8; 8]>() }).await?;
|
||||
let mut c = varint::read_async(r).await? as i64;
|
||||
if (c & 1) == 0 {
|
||||
c = c.wrapping_shr(1);
|
||||
} else {
|
||||
|
@ -121,18 +121,18 @@ impl<const B: usize> IBLT<B> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
|
||||
w.write_all(unsafe { &*(&self.salt as *const u64).cast::<[u8; 8]>() })?;
|
||||
pub async fn write<W: AsyncWriteExt + Unpin>(&self, w: &mut W) -> std::io::Result<()> {
|
||||
let _ = w.write_all(unsafe { &*(&self.salt as *const u64).cast::<[u8; 8]>() }).await?;
|
||||
let mut prev_c = 0_i64;
|
||||
for b in self.map.iter() {
|
||||
w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() })?;
|
||||
w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() })?;
|
||||
let _ = w.write_all(unsafe { &*(&b.key_sum as *const u64).cast::<[u8; 8]>() }).await?;
|
||||
let _ = w.write_all(unsafe { &*(&b.check_hash_sum as *const u64).cast::<[u8; 8]>() }).await?;
|
||||
let mut c = (b.count - prev_c).wrapping_shl(1);
|
||||
prev_c = b.count;
|
||||
if c < 0 {
|
||||
c = -c | 1;
|
||||
}
|
||||
varint::write(w, c as u64)?;
|
||||
let _ = varint::write_async(w, c as u64).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,7 +1,16 @@
|
|||
/* This Source Code Form is subject to the terms of the Mozilla Public
|
||||
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
*
|
||||
* (c)2022 ZeroTier, Inc.
|
||||
* https://www.zerotier.com/
|
||||
*/
|
||||
|
||||
pub(crate) mod varint;
|
||||
pub(crate) mod protocol;
|
||||
pub(crate) mod iblt;
|
||||
pub mod database;
|
||||
|
||||
pub mod datastore;
|
||||
pub mod node;
|
||||
pub mod host;
|
||||
|
||||
|
|
|
@ -7,19 +7,19 @@
|
|||
*/
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
use std::time::Duration;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::database::Database;
|
||||
use crate::datastore::DataStore;
|
||||
use crate::host::Host;
|
||||
use crate::ms_monotonic;
|
||||
use crate::protocol::*;
|
||||
|
@ -27,92 +27,130 @@ use crate::varint;
|
|||
|
||||
const CONNECTION_TIMEOUT: i64 = 60000;
|
||||
const CONNECTION_KEEPALIVE_AFTER: i64 = 20000;
|
||||
const IO_BUFFER_SIZE: usize = 65536;
|
||||
|
||||
struct Connection {
|
||||
writer: Mutex<OwnedWriteHalf>,
|
||||
last_send_time: AtomicI64,
|
||||
last_receive_time: AtomicI64,
|
||||
io_task: std::sync::Mutex<Option<JoinHandle<std::io::Result<()>>>>,
|
||||
incoming: bool
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct RemoteNodeInfo {
|
||||
pub node_name: Option<String>,
|
||||
pub node_contact: Option<String>,
|
||||
pub endpoint: SocketAddr,
|
||||
pub preferred_endpoints: Vec<SocketAddr>,
|
||||
pub connect_time: SystemTime,
|
||||
pub inbound: bool,
|
||||
pub initialized: bool,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
async fn send(&self, data: &[u8], now: i64) -> std::io::Result<()> {
|
||||
self.writer.lock().await.write_all(data).await.map(|_| {
|
||||
self.last_send_time.store(now, Ordering::Relaxed);
|
||||
pub struct Node<D: DataStore + 'static, H: Host + 'static> {
|
||||
internal: Arc<NodeInternal<D, H>>,
|
||||
housekeeping_task: JoinHandle<()>,
|
||||
listener_task: JoinHandle<()>
|
||||
}
|
||||
|
||||
impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
|
||||
pub async fn new(db: Arc<D>, host: Arc<H>, bind_address: SocketAddr) -> std::io::Result<Self> {
|
||||
let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
|
||||
if listener.set_reuseport(true).is_err() {
|
||||
listener.set_reuseaddr(true)?;
|
||||
}
|
||||
listener.bind(bind_address.clone())?;
|
||||
let listener = listener.listen(1024)?;
|
||||
|
||||
let internal = Arc::new(NodeInternal::<D, H> {
|
||||
anti_loopback_secret: {
|
||||
let mut tmp = [0_u8; 16];
|
||||
host.get_secure_random(&mut tmp);
|
||||
tmp
|
||||
},
|
||||
db: db.clone(),
|
||||
host: host.clone(),
|
||||
bind_address,
|
||||
connections: Mutex::new(HashMap::with_capacity(64)),
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
internal: internal.clone(),
|
||||
housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()),
|
||||
listener_task: tokio::spawn(internal.listener_task_main(listener)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> {
|
||||
let data = rmp_serde::encode::to_vec_named(&obj);
|
||||
if data.is_ok() {
|
||||
let data = data.unwrap();
|
||||
let mut tmp = [0_u8; 16];
|
||||
tmp[0] = message_type;
|
||||
let len = 1 + varint::encode(&mut tmp[1..], data.len() as u64);
|
||||
let mut writer = self.writer.lock().await;
|
||||
writer.write_all(&tmp[0..len]).await?;
|
||||
writer.write_all(data.as_slice()).await?;
|
||||
self.last_send_time.store(now, Ordering::Relaxed);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure"))
|
||||
#[inline(always)]
|
||||
pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
|
||||
self.internal.connect(endpoint).await
|
||||
}
|
||||
|
||||
pub fn list_connections(&self) -> Vec<RemoteNodeInfo> {
|
||||
let mut connections = self.internal.connections.blocking_lock();
|
||||
let mut cl: Vec<RemoteNodeInfo> = Vec::with_capacity(connections.len());
|
||||
connections.retain(|_, c| {
|
||||
c.0.upgrade().map_or(false, |c| {
|
||||
cl.push(c.info.lock().unwrap().clone());
|
||||
true
|
||||
})
|
||||
});
|
||||
cl
|
||||
}
|
||||
}
|
||||
|
||||
fn kill(&self) {
|
||||
let _ = self.io_task.lock().unwrap().take().map(|h| h.abort());
|
||||
}
|
||||
|
||||
async fn read_msg<'a>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<&'a [u8]> {
|
||||
if message_size > buf.len() {
|
||||
buf.resize(((message_size / 4096) + 1) * 4096, 0);
|
||||
}
|
||||
let b = &mut buf.as_mut_slice()[0..message_size];
|
||||
reader.read_exact(b).await?;
|
||||
self.last_receive_time.store(now, Ordering::Relaxed);
|
||||
Ok(b)
|
||||
}
|
||||
|
||||
async fn read_obj<'a, O: Deserialize<'a>>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<O> {
|
||||
rmp_serde::from_slice(self.read_msg(reader, buf, message_size, now).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
|
||||
impl<D: DataStore + 'static, H: Host + 'static> Drop for Node<D, H> {
|
||||
fn drop(&mut self) {
|
||||
self.housekeeping_task.abort();
|
||||
self.listener_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NodeInternal<D: Database + 'static, H: Host + 'static> {
|
||||
anti_loopback_secret: [u8; 64],
|
||||
pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
|
||||
anti_loopback_secret: [u8; 16],
|
||||
db: Arc<D>,
|
||||
host: Arc<H>,
|
||||
bind_address: SocketAddr,
|
||||
connections: Mutex<HashMap<SocketAddr, Weak<Connection>>>,
|
||||
connections: Mutex<HashMap<SocketAddr, (Weak<Connection>, Option<JoinHandle<std::io::Result<()>>>)>>,
|
||||
}
|
||||
|
||||
impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
|
||||
impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
|
||||
async fn housekeeping_task_main(self: Arc<Self>) {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_millis(CONNECTION_KEEPALIVE_AFTER as u64)).await;
|
||||
tokio::time::sleep(Duration::from_millis((CONNECTION_KEEPALIVE_AFTER / 2) as u64)).await;
|
||||
|
||||
let mut to_ping: Vec<Arc<Connection>> = Vec::new();
|
||||
let mut dead: Vec<(SocketAddr, Option<JoinHandle<std::io::Result<()>>>)> = Vec::new();
|
||||
|
||||
let mut connections = self.connections.lock().await;
|
||||
let now = ms_monotonic();
|
||||
connections.retain(|_, c| {
|
||||
c.upgrade().map_or(false, |c| {
|
||||
if (now - c.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT {
|
||||
if (now - c.last_send_time.load(Ordering::Relaxed)) > CONNECTION_KEEPALIVE_AFTER {
|
||||
to_ping.push(c);
|
||||
connections.retain(|sa, c| {
|
||||
let cc = c.0.upgrade();
|
||||
if cc.is_some() {
|
||||
let cc = cc.unwrap();
|
||||
if (now - cc.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT {
|
||||
if (now - cc.last_send_time.load(Ordering::Relaxed)) >= CONNECTION_KEEPALIVE_AFTER {
|
||||
to_ping.push(cc);
|
||||
}
|
||||
true
|
||||
} else {
|
||||
c.1.take().map(|j| j.abort());
|
||||
false
|
||||
}
|
||||
} else {
|
||||
c.kill();
|
||||
return false;
|
||||
let _ = c.1.take().map(|j| dead.push((sa.clone(), Some(j))));
|
||||
false
|
||||
}
|
||||
return true;
|
||||
})
|
||||
});
|
||||
drop(connections); // release lock
|
||||
for c in to_ping.iter() {
|
||||
if c.send(&[MESSAGE_TYPE_NOP, 0], now).await.is_err() {
|
||||
c.kill();
|
||||
|
||||
for d in dead.iter_mut() {
|
||||
d.1.take().unwrap().await.map_or_else(|e| {
|
||||
self.host.on_connection_closed(&d.0, Some(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "timed out"))));
|
||||
}, |r| {
|
||||
if r.is_ok() {
|
||||
self.host.on_connection_closed(&d.0, None);
|
||||
} else {
|
||||
self.host.on_connection_closed(&d.0, Some(Box::new(r.unwrap_err())));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for c in to_ping.iter() {
|
||||
let _ = c.send(&[MESSAGE_TYPE_NOP, 0], now).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -120,7 +158,7 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
|
||||
loop {
|
||||
let socket = listener.accept().await;
|
||||
if socket.is_ok() {
|
||||
if self.connections.lock().await.len() < self.host.max_endpoints() && socket.is_ok() {
|
||||
let (stream, endpoint) = socket.unwrap();
|
||||
Self::connection_start(&self, endpoint, stream, true).await;
|
||||
}
|
||||
|
@ -141,8 +179,9 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
preferred_ipv6: None
|
||||
}, ms_monotonic()).await?;
|
||||
|
||||
let mut init_received = false;
|
||||
let mut initialized = false;
|
||||
let mut reader = BufReader::with_capacity(65536, reader);
|
||||
let mut reader = BufReader::with_capacity(IO_BUFFER_SIZE, reader);
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.resize(4096, 0);
|
||||
loop {
|
||||
|
@ -156,20 +195,56 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
|
||||
match message_type {
|
||||
MESSAGE_TYPE_INIT => {
|
||||
if initialized {
|
||||
if init_received {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init"));
|
||||
}
|
||||
|
||||
let msg: msg::Init = connection.read_obj(&mut reader, &mut buf, message_size as usize, now).await?;
|
||||
|
||||
if !msg.domain.as_str().eq(self.db.domain()) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set domain mismatch"));
|
||||
}
|
||||
if msg.key_size != D::KEY_SIZE as u16 || msg.max_value_size > D::MAX_VALUE_SIZE as u64 {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "data set key/value sizing mismatch"));
|
||||
}
|
||||
|
||||
let mut antiloop = msg.anti_loopback_challenge.to_vec();
|
||||
let _ = std::io::Write::write_all(&mut antiloop, &self.anti_loopback_secret);
|
||||
let antiloop = H::sha512(antiloop.as_slice());
|
||||
|
||||
connection.send_obj(MESSAGE_TYPE_INIT_RESPONSE, &msg::InitResponse {
|
||||
anti_loopback_response: &antiloop[0..16]
|
||||
}, now).await?;
|
||||
|
||||
init_received = true;
|
||||
|
||||
let mut info = connection.info.lock().unwrap();
|
||||
info.node_name = msg.node_name.clone();
|
||||
info.node_contact = msg.node_contact.clone();
|
||||
let _ = msg.preferred_ipv4.map(|pv4| {
|
||||
info.preferred_endpoints.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(pv4.ip[0], pv4.ip[1], pv4.ip[2], pv4.ip[3]), pv4.port)));
|
||||
});
|
||||
let _ = msg.preferred_ipv6.map(|pv6| {
|
||||
info.preferred_endpoints.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
|
||||
});
|
||||
},
|
||||
MESSAGE_TYPE_INIT_RESPONSE => {
|
||||
if initialized {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "duplicate init response"));
|
||||
}
|
||||
|
||||
let msg: msg::InitResponse = connection.read_obj(&mut reader, &mut buf, message_size as usize, now).await?;
|
||||
let mut antiloop = challenge.to_vec();
|
||||
let _ = std::io::Write::write_all(&mut antiloop, &self.anti_loopback_secret);
|
||||
let antiloop = H::sha512(antiloop.as_slice());
|
||||
if msg.anti_loopback_response.eq(&antiloop[0..16]) {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "rejected connection to self"));
|
||||
}
|
||||
|
||||
initialized = true;
|
||||
let mut info = connection.info.lock().unwrap();
|
||||
info.initialized = true;
|
||||
let info = info.clone();
|
||||
self.host.on_connect(&info);
|
||||
},
|
||||
_ => {
|
||||
// Skip messages that aren't recognized or don't need to be parsed like NOP.
|
||||
|
@ -185,21 +260,29 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
}
|
||||
}
|
||||
|
||||
async fn connection_start(self: &Arc<Self>, endpoint: SocketAddr, stream: TcpStream, incoming: bool) -> bool {
|
||||
async fn connection_start(self: &Arc<Self>, endpoint: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
|
||||
let _ = stream.set_nodelay(true);
|
||||
let (reader, writer) = stream.into_split();
|
||||
|
||||
let mut ok = false;
|
||||
let _ = self.connections.lock().await.entry(endpoint.clone()).or_insert_with(|| {
|
||||
ok = true;
|
||||
let now = ms_monotonic();
|
||||
let connection = Arc::new(Connection {
|
||||
writer: Mutex::new(writer),
|
||||
writer: Mutex::new(BufWriter::with_capacity(IO_BUFFER_SIZE, writer)),
|
||||
last_send_time: AtomicI64::new(now),
|
||||
last_receive_time: AtomicI64::new(now),
|
||||
io_task: std::sync::Mutex::new(None),
|
||||
incoming
|
||||
info: std::sync::Mutex::new(RemoteNodeInfo {
|
||||
node_name: None,
|
||||
node_contact: None,
|
||||
endpoint: endpoint.clone(),
|
||||
preferred_endpoints: Vec::new(),
|
||||
connect_time: SystemTime::now(),
|
||||
inbound,
|
||||
initialized: false
|
||||
}),
|
||||
});
|
||||
let _ = connection.io_task.lock().unwrap().insert(tokio::spawn(Self::connection_io_task_main(self.clone(), connection.clone(), reader)));
|
||||
Arc::downgrade(&connection)
|
||||
(Arc::downgrade(&connection), Some(tokio::spawn(self.clone().connection_io_task_main(connection.clone(), reader))))
|
||||
});
|
||||
ok
|
||||
}
|
||||
|
@ -219,70 +302,59 @@ impl<D: Database + 'static, H: Host + 'static> NodeInternal<D, H> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<D: Database + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
|
||||
impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
|
||||
fn drop(&mut self) {
|
||||
for (_, c) in self.connections.blocking_lock().drain() {
|
||||
let _ = c.upgrade().map(|c| c.kill());
|
||||
c.1.map(|c| c.abort());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Node<D: Database + 'static, H: Host + 'static> {
|
||||
internal: Arc<NodeInternal<D, H>>,
|
||||
housekeeping_task: JoinHandle<()>,
|
||||
listener_task: JoinHandle<()>
|
||||
struct Connection {
|
||||
writer: Mutex<BufWriter<OwnedWriteHalf>>,
|
||||
last_send_time: AtomicI64,
|
||||
last_receive_time: AtomicI64,
|
||||
info: std::sync::Mutex<RemoteNodeInfo>,
|
||||
}
|
||||
|
||||
impl<D: Database + 'static, H: Host + 'static> Node<D, H> {
|
||||
pub async fn new(db: Arc<D>, host: Arc<H>, bind_address: SocketAddr) -> std::io::Result<Self> {
|
||||
let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
|
||||
if listener.set_reuseport(true).is_err() {
|
||||
listener.set_reuseaddr(true)?;
|
||||
}
|
||||
listener.bind(bind_address.clone())?;
|
||||
let listener = listener.listen(1024)?;
|
||||
|
||||
let internal = Arc::new(NodeInternal::<D, H> {
|
||||
anti_loopback_secret: {
|
||||
let mut tmp = [0_u8; 64];
|
||||
host.get_secure_random(&mut tmp);
|
||||
tmp
|
||||
},
|
||||
db: db.clone(),
|
||||
host: host.clone(),
|
||||
bind_address,
|
||||
connections: Mutex::new(HashMap::with_capacity(64)),
|
||||
});
|
||||
Ok(Self {
|
||||
internal: internal.clone(),
|
||||
housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()),
|
||||
listener_task: tokio::spawn(internal.listener_task_main(listener)),
|
||||
})
|
||||
impl Connection {
|
||||
async fn send(&self, data: &[u8], now: i64) -> std::io::Result<()> {
|
||||
let mut writer = self.writer.lock().await;
|
||||
writer.write_all(data).await?;
|
||||
writer.flush().await?;
|
||||
self.last_send_time.store(now, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
|
||||
self.internal.connect(endpoint).await
|
||||
}
|
||||
|
||||
pub fn list_connections(&self) -> Vec<SocketAddr> {
|
||||
let mut connections = self.internal.connections.blocking_lock();
|
||||
let mut cl: Vec<SocketAddr> = Vec::with_capacity(connections.len());
|
||||
connections.retain(|sa, c| {
|
||||
if c.strong_count() > 0 {
|
||||
cl.push(sa.clone());
|
||||
true
|
||||
async fn send_obj<O: Serialize>(&self, message_type: u8, obj: &O, now: i64) -> std::io::Result<()> {
|
||||
let data = rmp_serde::encode::to_vec_named(&obj);
|
||||
if data.is_ok() {
|
||||
let data = data.unwrap();
|
||||
let mut tmp = [0_u8; 16];
|
||||
tmp[0] = message_type;
|
||||
let len = 1 + varint::encode(&mut tmp[1..], data.len() as u64);
|
||||
let mut writer = self.writer.lock().await;
|
||||
writer.write_all(&tmp[0..len]).await?;
|
||||
writer.write_all(data.as_slice()).await?;
|
||||
writer.flush().await?;
|
||||
self.last_send_time.store(now, Ordering::Relaxed);
|
||||
Ok(())
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
cl
|
||||
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Database + 'static, H: Host + 'static> Drop for Node<D, H> {
|
||||
fn drop(&mut self) {
|
||||
self.housekeeping_task.abort();
|
||||
self.listener_task.abort();
|
||||
async fn read_msg<'a>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<&'a [u8]> {
|
||||
if message_size > buf.len() {
|
||||
buf.resize(((message_size / 4096) + 1) * 4096, 0);
|
||||
}
|
||||
let b = &mut buf.as_mut_slice()[0..message_size];
|
||||
reader.read_exact(b).await?;
|
||||
self.last_receive_time.store(now, Ordering::Relaxed);
|
||||
Ok(b)
|
||||
}
|
||||
|
||||
async fn read_obj<'a, O: Deserialize<'a>>(&self, reader: &mut BufReader<OwnedReadHalf>, buf: &'a mut Vec<u8>, message_size: usize, now: i64) -> std::io::Result<O> {
|
||||
rmp_serde::from_slice(self.read_msg(reader, buf, message_size, now).await?).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue