/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/.
 *
 * (c)2022 ZeroTier, Inc.
 * https://www.zerotier.com/
 */

use std::collections::{HashMap, HashSet};
use std::io::IoSlice;
use std::mem::MaybeUninit;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::ops::Add;
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio::time::{Duration, Instant};

use crate::datastore::*;
use crate::host::Host;
use crate::iblt::IBLT;
use crate::protocol::*;
use crate::utils::*;
use crate::varint;

/// Period for running main housekeeping pass.
const HOUSEKEEPING_PERIOD: i64 = SYNC_STATUS_PERIOD;

/// Inactivity timeout for connections in milliseconds.
const CONNECTION_TIMEOUT: i64 = SYNC_STATUS_PERIOD * 4;

/// Information about a remote node to which we are connected.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RemoteNodeInfo {
    /// Optional name advertised by remote node (arbitrary).
    pub name: String,

    /// Optional contact information advertised by remote node (arbitrary).
    pub contact: String,

    /// Actual remote endpoint address.
    pub remote_address: SocketAddr,

    /// Explicitly advertised remote addresses supplied by remote node (not necessarily verified).
    pub explicit_addresses: Vec<SocketAddr>,

    /// Time TCP connection was established (ms since epoch).
    pub connect_time: i64,

    /// Time TCP connection was established (ms, monotonic).
    pub connect_instant: i64,

    /// True if this is an inbound TCP connection.
    pub inbound: bool,

    /// True if this connection has exchanged init messages successfully.
    pub initialized: bool,
}

/// An instance of the syncwhole data set synchronization engine.
///
/// This holds a number of async tasks that are terminated or aborted if this object
/// is dropped. In other words this implements structured concurrency.
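///
/// A minimal usage sketch (illustrative only; `MyStore` and `MyHost` are
/// hypothetical implementations of this crate's `DataStore` and `Host` traits):
///
/// ```ignore
/// let node = Node::new(Arc::new(MyStore::new()), Arc::new(MyHost::new()), "127.0.0.1:19993".parse().unwrap()).await?;
/// node.connect(&"192.0.2.1:19993".parse().unwrap()).await?;
/// println!("open connections: {}", node.connection_count().await);
/// ```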
pub struct Node<D: DataStore + 'static, H: Host + 'static> {
    internal: Arc<NodeInternal<D, H>>,
    housekeeping_task: JoinHandle<()>,
    announce_task: JoinHandle<()>,
    listener_task: JoinHandle<()>,
}

impl<D: DataStore + 'static, H: Host + 'static> Node<D, H> {
    pub async fn new(db: Arc<D>, host: Arc<H>, bind_address: SocketAddr) -> std::io::Result<Self> {
        let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
        configure_tcp_socket(&listener)?;
        listener.bind(bind_address.clone())?;
        let listener = listener.listen(1024)?;

        let internal = Arc::new(NodeInternal::<D, H> {
            anti_loopback_secret: {
                let mut tmp = [0_u8; 64];
                host.get_secure_random(&mut tmp);
                tmp
            },
            datastore: db.clone(),
            host: host.clone(),
            connections: Mutex::new(HashMap::with_capacity(64)),
            announce_queue: Mutex::new(HashMap::with_capacity(256)),
            bind_address,
            starting_instant: Instant::now(),
        });

        Ok(Self {
            internal: internal.clone(),
            housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()),
            announce_task: tokio::spawn(internal.clone().announce_task_main()),
            listener_task: tokio::spawn(internal.listener_task_main(listener)),
        })
    }

    #[inline(always)]
    pub fn datastore(&self) -> &Arc<D> {
        &self.internal.datastore
    }

    #[inline(always)]
    pub fn host(&self) -> &Arc<H> {
        &self.internal.host
    }

    /// Attempt to connect to an explicitly specified TCP endpoint.
    pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result<bool> {
        self.internal.clone().connect(endpoint, Instant::now().add(Duration::from_millis(CONNECTION_TIMEOUT as u64))).await
    }

    /// Get open peer to peer connections.
    pub async fn list_connections(&self) -> Vec<RemoteNodeInfo> {
        let connections = self.internal.connections.lock().await;
        let mut cl: Vec<RemoteNodeInfo> = Vec::with_capacity(connections.len());
        for (_, c) in connections.iter() {
            cl.push(c.info.lock().unwrap().clone());
        }
        cl
    }

    /// Get the number of open peer to peer connections.
    pub async fn connection_count(&self) -> usize {
        self.internal.connections.lock().await.len()
    }
}

impl<D: DataStore + 'static, H: Host + 'static> Drop for Node<D, H> {
    fn drop(&mut self) {
        self.housekeeping_task.abort();
        self.announce_task.abort();
        self.listener_task.abort();
    }
}

/********************************************************************************************************************/

fn configure_tcp_socket(socket: &TcpSocket) -> std::io::Result<()> {
    let _ = socket.set_linger(None);
    if socket.set_reuseport(true).is_ok() {
        Ok(())
    } else {
        socket.set_reuseaddr(true)
    }
}

fn decode_msgpack<'a, T: Deserialize<'a>>(b: &'a [u8]) -> std::io::Result<T> {
    rmp_serde::from_slice(b).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("invalid msgpack object: {}", e)))
}

pub struct NodeInternal<D: DataStore + 'static, H: Host + 'static> {
    // Secret used to perform HMAC to detect and drop loopback connections to self.
    anti_loopback_secret: [u8; 64],

    // Outside code implementations of DataStore and Host traits.
    datastore: Arc<D>,
    host: Arc<H>,

    // Connections and their task join handles, by remote endpoint address.
    connections: Mutex<HashMap<SocketAddr, Arc<Connection>>>,

    // Records received since last announce and the endpoints that we know already have them.
    announce_queue: Mutex<HashMap<[u8; ANNOUNCE_KEY_LEN], Vec<SocketAddr>>>,

    // Local address to which this node is bound.
    bind_address: SocketAddr,

    // Instant this node started.
    starting_instant: Instant,
}

impl<D: DataStore + 'static, H: Host + 'static> NodeInternal<D, H> {
    fn ms_monotonic(&self) -> i64 {
        Instant::now().duration_since(self.starting_instant).as_millis() as i64
    }
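
    /// Background task: reap connections that have timed out or closed, then try
    /// to maintain connections to anchors and, up to the desired count, to seed peers.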
    async fn housekeeping_task_main(self: Arc<Self>) {
        let mut tasks: Vec<JoinHandle<()>> = Vec::new();
        let mut counts: Vec<u64> = Vec::new();
        let mut connected_to_addresses: HashSet<SocketAddr> = HashSet::new();
        let mut sleep_until = Instant::now().add(Duration::from_millis(500));
        loop {
            tokio::time::sleep_until(sleep_until).await;
            sleep_until = sleep_until.add(Duration::from_millis(HOUSEKEEPING_PERIOD as u64));

            tasks.clear();
            counts.clear();
            connected_to_addresses.clear();
            let now = self.ms_monotonic();
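
            // Reap connections that are flagged closed or that have gone silent for
            // longer than CONNECTION_TIMEOUT; the Host is notified of each close
            // (with a reason) from detached tasks rather than under the map lock.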
            self.connections.lock().await.retain(|sa, c| {
                if !c.closed.load(Ordering::Relaxed) {
                    let cc = c.clone();
                    if (now - c.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT {
                        connected_to_addresses.insert(sa.clone());
                        true // keep connection
                    } else {
                        let _ = c.read_task.lock().unwrap().take().map(|j| j.abort());
                        let host = self.host.clone();
                        tasks.push(tokio::spawn(async move {
                            host.on_connection_closed(&*cc.info.lock().unwrap(), "timeout".to_string());
                        }));
                        false // discard connection
                    }
                } else {
                    let host = self.host.clone();
                    let cc = c.clone();
                    let j = c.read_task.lock().unwrap().take();
                    tasks.push(tokio::spawn(async move {
                        if j.is_some() {
                            let e = j.unwrap().await;
                            if e.is_ok() {
                                let e = e.unwrap();
                                host.on_connection_closed(&*cc.info.lock().unwrap(), e.map_or_else(|e| e.to_string(), |_| "unknown error".to_string()));
                            } else {
                                host.on_connection_closed(&*cc.info.lock().unwrap(), "remote host closed connection".to_string());
                            }
                        } else {
                            host.on_connection_closed(&*cc.info.lock().unwrap(), "remote host closed connection".to_string());
                        }
                    }));
                    false // discard connection
                }
            });

            let config = self.host.node_config();

            // Always try to connect to anchor peers.
            for sa in config.anchors.iter() {
                if !connected_to_addresses.contains(sa) {
                    let sa = sa.clone();
                    let self2 = self.clone();
                    tasks.push(tokio::spawn(async move {
                        let _ = self2.connect(&sa, sleep_until).await;
                    }));
                    connected_to_addresses.insert(sa.clone());
                }
            }

            // Try to connect to more peers until desired connection count is reached.
            let desired_connection_count = config.desired_connection_count.min(config.max_connection_count);
            for sa in config.seeds.iter() {
                if connected_to_addresses.len() >= desired_connection_count {
                    break;
                }
                if !connected_to_addresses.contains(sa) {
                    connected_to_addresses.insert(sa.clone());
                    let self2 = self.clone();
                    let sa = sa.clone();
                    tasks.push(tokio::spawn(async move {
                        let _ = self2.connect(&sa, sleep_until).await;
                    }));
                }
            }

            // Wait for this iteration's batched background tasks to complete.
            loop {
                let s = tasks.pop();
                if s.is_some() {
                    let _ = s.unwrap().await;
                } else {
                    break;
                }
            }
        }
    }
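
    /// Background task: periodically drain the announce queue and send HaveRecords
    /// announcements for newly received records to connected peers that are not
    /// already known to have them.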
    async fn announce_task_main(self: Arc<Self>) {
        let mut sleep_until = Instant::now().add(Duration::from_millis(ANNOUNCE_PERIOD as u64));
        let mut to_announce: Vec<([u8; ANNOUNCE_KEY_LEN], Vec<SocketAddr>)> = Vec::with_capacity(256);
        let background_tasks = AsyncTaskReaper::new();
        let announce_timeout = Duration::from_millis(CONNECTION_TIMEOUT as u64);
        loop {
            tokio::time::sleep_until(sleep_until).await;
            sleep_until = sleep_until.add(Duration::from_millis(ANNOUNCE_PERIOD as u64));

            for (key, already_has) in self.announce_queue.lock().await.drain() {
                to_announce.push((key, already_has));
            }

            let now = self.ms_monotonic();
            for c in self.connections.lock().await.iter() {
                if c.1.announce_new_records.load(Ordering::Relaxed) {
                    let mut have_records: Vec<u8> = Vec::with_capacity((to_announce.len() * ANNOUNCE_KEY_LEN) + 4);
                    have_records.push(ANNOUNCE_KEY_LEN as u8);
                    for (key, already_has) in to_announce.iter() {
                        if !already_has.contains(c.0) {
                            let _ = std::io::Write::write_all(&mut have_records, key);
                        }
                    }
                    if have_records.len() > 1 {
                        let c2 = c.1.clone();
                        background_tasks.spawn(async move {
                            // If the connection dies this will either fail or time out. Usually these execute
                            // instantly due to write buffering but a short timeout prevents them from building
                            // up too much.
                            let _ = tokio::time::timeout(announce_timeout, c2.send_msg(MessageType::HaveRecords, have_records.as_slice(), now)).await;
                        });
                    }
                }
            }

            to_announce.clear();
        }
    }
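
    /// Background task: accept inbound TCP connections, subject to Host::allow()
    /// and the configured maximum connection count (anchor addresses are always
    /// accepted).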
    async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
        loop {
            let socket = listener.accept().await;
            if socket.is_ok() {
                let (stream, address) = socket.unwrap();
                if self.host.allow(&address) {
                    let config = self.host.node_config();
                    if self.connections.lock().await.len() < config.max_connection_count || config.anchors.contains(&address) {
                        Self::connection_start(&self, address, stream, true).await;
                    }
                }
            }
        }
    }
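
    /// Attempt an outbound connection, failing if it cannot be completed by
    /// `deadline`. The local end is bound to this node's bind address.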
    async fn connect(self: Arc<Self>, address: &SocketAddr, deadline: Instant) -> std::io::Result<bool> {
        self.host.on_connect_attempt(address);
        let stream = if address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
        configure_tcp_socket(&stream)?;
        stream.bind(self.bind_address.clone())?;
        let stream = tokio::time::timeout_at(deadline, stream.connect(address.clone())).await;
        if stream.is_ok() {
            Ok(self.connection_start(address.clone(), stream.unwrap()?, false).await)
        } else {
            Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "connect timed out"))
        }
    }
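
    /// Insert a new connection into the connection map and spawn its I/O task.
    /// Returns false without replacing anything if a connection to this address
    /// already exists.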
    async fn connection_start(self: &Arc<Self>, address: SocketAddr, stream: TcpStream, inbound: bool) -> bool {
        let mut ok = false;
        let _ = self.connections.lock().await.entry(address.clone()).or_insert_with(|| {
            ok = true;
            let _ = stream.set_nodelay(false);
            let (reader, writer) = stream.into_split();
            let now = self.ms_monotonic();
            let connection = Arc::new(Connection {
                writer: Mutex::new(writer),
                last_send_time: AtomicI64::new(now),
                last_receive_time: AtomicI64::new(now),
                info: std::sync::Mutex::new(RemoteNodeInfo {
                    name: String::new(),
                    contact: String::new(),
                    remote_address: address.clone(),
                    explicit_addresses: Vec::new(),
                    connect_time: ms_since_epoch(),
                    connect_instant: now,
                    inbound,
                    initialized: false,
                }),
                read_task: std::sync::Mutex::new(None),
                announce_new_records: AtomicBool::new(false),
                closed: AtomicBool::new(false),
            });
            let self2 = self.clone();
            let c2 = connection.clone();
            connection.read_task.lock().unwrap().replace(tokio::spawn(async move {
                let result = self2.connection_io_task_main(&c2, address, reader).await;
                c2.closed.store(true, Ordering::Relaxed);
                result
            }));
            connection
        });
        ok
    }
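
    /// Per-connection I/O loop: perform the init exchange, then read and dispatch
    /// messages until the connection dies.
    ///
    /// Wire format for each message: a one byte message type, a varint payload
    /// length, then the payload (msgpack for Init/InitResponse/Sync, raw bytes
    /// otherwise).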
    async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, remote_address: SocketAddr, mut reader: OwnedReadHalf) -> std::io::Result<()> {
        const BUF_CHUNK_SIZE: usize = 4096;
        const READ_BUF_INITIAL_SIZE: usize = 65536; // should be a multiple of BUF_CHUNK_SIZE

        let mut write_buffer: Vec<u8> = Vec::with_capacity(BUF_CHUNK_SIZE);
        let mut read_buffer: Vec<u8> = Vec::new();
        read_buffer.resize(READ_BUF_INITIAL_SIZE, 0);

        let config = self.host.node_config();
        let mut anti_loopback_challenge_sent = [0_u8; 64];
        let mut domain_challenge_sent = [0_u8; 64];
        let mut auth_challenge_sent = [0_u8; 64];
        self.host.get_secure_random(&mut anti_loopback_challenge_sent);
        self.host.get_secure_random(&mut domain_challenge_sent);
        self.host.get_secure_random(&mut auth_challenge_sent);
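        // Begin the init exchange: send our Init with fresh random challenges. The
        // peer must prove via HMAC responses that it is not ourselves (anti-loopback),
        // that it shares our data store's domain, and that it passes Host::authenticate().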
        connection
            .send_obj(
                &mut write_buffer,
                MessageType::Init,
                &msg::Init {
                    anti_loopback_challenge: &anti_loopback_challenge_sent,
                    domain_challenge: &domain_challenge_sent,
                    auth_challenge: &auth_challenge_sent,
                    node_name: config.name.as_str(),
                    node_contact: config.contact.as_str(),
                    locally_bound_port: self.bind_address.port(),
                    explicit_ipv4: None,
                    explicit_ipv6: None,
                },
                self.ms_monotonic(),
            )
            .await?;
        drop(config);

        let max_message_size = ((D::MAX_VALUE_SIZE * 8) + (D::KEY_SIZE * 1024) + 65536) as u64; // sanity limit
        let mut initialized = false;
        let mut init_received = false;
        let mut buffer_fill = 0_usize;
        loop {
            let message_type: MessageType;
            let message_size: usize;
            let header_size: usize;
            let total_size: usize;
            loop {
                let n = reader.read(&mut read_buffer.as_mut_slice()[buffer_fill..]).await?;
                if n == 0 {
                    return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "connection closed by remote"));
                }
                buffer_fill += n;
                if buffer_fill >= 2 {
                    // type and at least one byte of varint
                    let ms = varint::decode(&read_buffer.as_slice()[1..]);
                    if ms.1 > 0 {
                        // varint is all there and parsed correctly
                        if ms.0 > max_message_size {
                            return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
                        }

                        message_type = MessageType::from(*read_buffer.get(0).unwrap());
                        message_size = ms.0 as usize;
                        header_size = 1 + ms.1;
                        total_size = header_size + message_size;

                        if read_buffer.len() < total_size {
                            read_buffer.resize(((total_size / BUF_CHUNK_SIZE) + 1) * BUF_CHUNK_SIZE, 0);
                        }
                        while buffer_fill < total_size {
                            let n = reader.read(&mut read_buffer.as_mut_slice()[buffer_fill..]).await?;
                            if n == 0 {
                                return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "connection closed by remote"));
                            }
                            buffer_fill += n;
                        }

                        break;
                    }
                }
            }
            let message = &read_buffer.as_slice()[header_size..total_size];

            let now = self.ms_monotonic();
            connection.last_receive_time.store(now, Ordering::Relaxed);

            match message_type {
                MessageType::Nop => {}

                MessageType::Init => {
                    if init_received {
                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init"));
                    }
                    init_received = true;

                    let msg: msg::Init = decode_msgpack(message)?;
                    let (anti_loopback_response, domain_challenge_response, auth_challenge_response) = {
                        let mut info = connection.info.lock().unwrap();
                        info.name = msg.node_name.to_string();
                        info.contact = msg.node_contact.to_string();
                        let _ = msg.explicit_ipv4.map(|pv4| {
                            info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port)));
                        });
                        let _ = msg.explicit_ipv6.map(|pv6| {
                            info.explicit_addresses.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
                        });
                        let info = info.clone();

                        let auth_challenge_response = self.host.authenticate(&info, msg.auth_challenge);
                        if auth_challenge_response.is_none() {
                            return Err(std::io::Error::new(std::io::ErrorKind::Other, "authenticate() returned None, connection dropped"));
                        }
                        let auth_challenge_response = auth_challenge_response.unwrap();

                        (H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge), H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), msg.domain_challenge), auth_challenge_response)
                    };

                    connection
                        .send_obj(
                            &mut write_buffer,
                            MessageType::InitResponse,
                            &msg::InitResponse {
                                anti_loopback_response: &anti_loopback_response,
                                domain_response: &domain_challenge_response,
                                auth_response: &auth_challenge_response,
                            },
                            now,
                        )
                        .await?;
                }

                MessageType::InitResponse => {
                    let msg: msg::InitResponse = decode_msgpack(message)?;
                    let mut info = connection.info.lock().unwrap();

                    if info.initialized {
                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init response"));
                    }

                    if msg.anti_loopback_response.eq(&H::hmac_sha512(&self.anti_loopback_secret, &anti_loopback_challenge_sent)) {
                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "rejected connection to self"));
                    }
                    if !msg.domain_response.eq(&H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), &domain_challenge_sent)) {
                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "domain mismatch"));
                    }
                    if !self.host.authenticate(&info, &auth_challenge_sent).map_or(false, |cr| msg.auth_response.eq(&cr)) {
                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "challenge/response authentication failed"));
                    }

                    initialized = true;

                    info.initialized = true;

                    let info = info.clone(); // also releases lock since info is replaced/destroyed
                    self.host.on_connect(&info);
                }

                // Handle messages other than INIT and INIT_RESPONSE after checking the 'initialized' flag.
                _ => {
                    if !initialized {
                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "init exchange must be completed before other messages are sent"));
                    }

                    match message_type {
                        MessageType::HaveRecords => {}

                        MessageType::GetRecords => {}

                        MessageType::Record => {
                            let key = H::sha512(&[message]);
                            match self.datastore.store(&key, message).await {
                                StoreResult::Ok => {
                                    let announce_key: [u8; ANNOUNCE_KEY_LEN] = (&key[..ANNOUNCE_KEY_LEN]).try_into().unwrap();
                                    let mut q = self.announce_queue.lock().await;
                                    let ql = q.entry(announce_key).or_insert_with(|| Vec::with_capacity(2));
                                    if !ql.contains(&remote_address) {
                                        ql.push(remote_address.clone());
                                    }
                                }
                                StoreResult::Rejected => {
                                    return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("record rejected by data store: {}", to_hex_string(&key))));
                                }
                                _ => {}
                            }
                        }

                        MessageType::Sync => {
                            let _msg: msg::Sync = decode_msgpack(message)?;
                        }

                        _ => {}
                    }
                }
            }

            read_buffer.copy_within(total_size..buffer_fill, 0);
            buffer_fill -= total_size;
        }
    }
}

impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
    fn drop(&mut self) {
        let _ = tokio::runtime::Handle::try_current().map_or_else(
            |_| {
                for (_, c) in self.connections.blocking_lock().drain() {
                    c.read_task.lock().unwrap().as_mut().map(|c| c.abort());
                }
            },
            |h| {
                let _ = h.block_on(async {
                    for (_, c) in self.connections.lock().await.drain() {
                        c.read_task.lock().unwrap().as_mut().map(|c| c.abort());
                    }
                });
            },
        );
    }
}

struct Connection {
    writer: Mutex<OwnedWriteHalf>,
    last_send_time: AtomicI64,
    last_receive_time: AtomicI64,
    info: std::sync::Mutex<RemoteNodeInfo>,
    read_task: std::sync::Mutex<Option<JoinHandle<std::io::Result<()>>>>,
    announce_new_records: AtomicBool,
    closed: AtomicBool,
}

impl Connection {
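    /// Send a message: a one byte type and a varint payload length followed by the
    /// payload itself, written as a single vectored write.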
    async fn send_msg(&self, message_type: MessageType, data: &[u8], now: i64) -> std::io::Result<()> {
        let mut header: [u8; 16] = unsafe { MaybeUninit::uninit().assume_init() };
        header[0] = message_type as u8;
        let header_size = 1 + varint::encode(&mut header[1..], data.len() as u64);
        if self.writer.lock().await.write_vectored(&[IoSlice::new(&header[0..header_size]), IoSlice::new(data)]).await? == (data.len() + header_size) {
            self.last_send_time.store(now, Ordering::Relaxed);
            Ok(())
        } else {
            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "write error"))
        }
    }
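
    /// Serialize an object to msgpack (named fields) and send it via send_msg().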
    async fn send_obj<O: Serialize>(&self, write_buf: &mut Vec<u8>, message_type: MessageType, obj: &O, now: i64) -> std::io::Result<()> {
        write_buf.clear();
        if rmp_serde::encode::write_named(write_buf, obj).is_ok() {
            self.send_msg(message_type, write_buf.as_slice(), now).await
        } else {
            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure (internal error)"))
        }
    }
}