/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at https://mozilla.org/MPL/2.0/. * * (c)2022 ZeroTier, Inc. * https://www.zerotier.com/ */ use std::collections::{HashMap, HashSet}; use std::io::IoSlice; use std::mem::MaybeUninit; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::ops::Add; use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering}; use std::sync::Arc; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::sync::Mutex; use tokio::task::JoinHandle; use tokio::time::{Duration, Instant}; use crate::datastore::*; use crate::host::Host; use crate::iblt::IBLT; use crate::protocol::*; use crate::utils::*; use crate::varint; /// Inactivity timeout for connections in milliseconds. const CONNECTION_TIMEOUT: i64 = SYNC_STATUS_PERIOD * 4; /// How often to run the housekeeping task's loop in milliseconds. const HOUSEKEEPING_INTERVAL: i64 = SYNC_STATUS_PERIOD; /// Information about a remote node to which we are connected. #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RemoteNodeInfo { /// Optional name advertised by remote node (arbitrary). pub name: String, /// Optional contact information advertised by remote node (arbitrary). pub contact: String, /// Actual remote endpoint address. pub remote_address: SocketAddr, /// Explicitly advertised remote addresses supplied by remote node (not necessarily verified). pub explicit_addresses: Vec, /// Time TCP connection was established (ms since epoch). pub connect_time: i64, /// Time TCP connection was estaablished (ms, monotonic). pub connect_instant: i64, /// True if this is an inbound TCP connection. pub inbound: bool, /// True if this connection has exchanged init messages successfully. 
pub initialized: bool, } fn configure_tcp_socket(socket: &TcpSocket) -> std::io::Result<()> { let _ = socket.set_linger(None); if socket.set_reuseport(true).is_ok() { Ok(()) } else { socket.set_reuseaddr(true) } } fn decode_msgpack<'a, T: Deserialize<'a>>(b: &'a [u8]) -> std::io::Result { rmp_serde::from_slice(b).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("invalid msgpack object: {}", e.to_string()))) } /// An instance of the syncwhole data set synchronization engine. /// /// This holds a number of async tasks that are terminated or aborted if this object /// is dropped. In other words this implements structured concurrency. pub struct Node { internal: Arc>, housekeeping_task: JoinHandle<()>, listener_task: JoinHandle<()>, } impl Node { pub async fn new(db: Arc, host: Arc, bind_address: SocketAddr) -> std::io::Result { let listener = if bind_address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?; configure_tcp_socket(&listener)?; listener.bind(bind_address.clone())?; let listener = listener.listen(1024)?; let internal = Arc::new(NodeInternal:: { anti_loopback_secret: { let mut tmp = [0_u8; 64]; host.get_secure_random(&mut tmp); tmp }, datastore: db.clone(), host: host.clone(), connections: Mutex::new(HashMap::with_capacity(64)), bind_address, starting_instant: Instant::now(), }); Ok(Self { internal: internal.clone(), housekeeping_task: tokio::spawn(internal.clone().housekeeping_task_main()), listener_task: tokio::spawn(internal.listener_task_main(listener)) }) } #[inline(always)] pub fn datastore(&self) -> &Arc { &self.internal.datastore } #[inline(always)] pub fn host(&self) -> &Arc { &self.internal.host } pub async fn connect(&self, endpoint: &SocketAddr) -> std::io::Result { self.internal.clone().connect(endpoint, Instant::now().add(Duration::from_millis(CONNECTION_TIMEOUT as u64))).await } pub async fn list_connections(&self) -> Vec { let connections = self.internal.connections.lock().await; let mut cl: Vec = 
Vec::with_capacity(connections.len()); for (_, c) in connections.iter() { cl.push(c.info.lock().await.clone()); } cl } pub async fn connection_count(&self) -> usize { self.internal.connections.lock().await.len() } } impl Drop for Node { fn drop(&mut self) { self.housekeeping_task.abort(); self.listener_task.abort(); } } pub struct NodeInternal { // Secret used to perform HMAC to detect and drop loopback connections to self. anti_loopback_secret: [u8; 64], // Outside code implementations of DataStore and Host traits. datastore: Arc, host: Arc, // Connections and their task join handles, by remote endpoint address. connections: Mutex>>, // Local address to which this node is bound bind_address: SocketAddr, // Instant this node started. starting_instant: Instant, } impl NodeInternal { fn ms_monotonic(&self) -> i64 { Instant::now().duration_since(self.starting_instant).as_millis() as i64 } /// Loop that constantly runs in the background to do cleanup and service things. async fn housekeeping_task_main(self: Arc) { let mut tasks: Vec> = Vec::new(); let mut connected_to_addresses: HashSet = HashSet::new(); let mut sleep_until = Instant::now().add(Duration::from_millis(500)); loop { tokio::time::sleep_until(sleep_until).await; sleep_until = sleep_until.add(Duration::from_millis(HOUSEKEEPING_INTERVAL as u64)); tasks.clear(); connected_to_addresses.clear(); let now = self.ms_monotonic(); self.connections.lock().await.retain(|sa, c| { if !c.closed.load(Ordering::Relaxed) { if (now - c.last_receive_time.load(Ordering::Relaxed)) < CONNECTION_TIMEOUT { connected_to_addresses.insert(sa.clone()); true // keep connection } else { let _ = c.read_task.lock().unwrap().take().map(|j| j.abort()); let host = self.host.clone(); let cc = c.clone(); tasks.push(tokio::spawn(async move { host.on_connection_closed(&*cc.info.lock().await, "timeout".to_string()); })); false // discard connection } } else { let host = self.host.clone(); let cc = c.clone(); let j = 
c.read_task.lock().unwrap().take(); tasks.push(tokio::spawn(async move { if j.is_some() { let e = j.unwrap().await; if e.is_ok() { let e = e.unwrap(); host.on_connection_closed(&*cc.info.lock().await, e.map_or_else(|e| e.to_string(), |_| "unknown error".to_string())); } else { host.on_connection_closed(&*cc.info.lock().await, "remote host closed connection".to_string()); } } else { host.on_connection_closed(&*cc.info.lock().await, "remote host closed connection".to_string()); } })); false // discard connection } }); let config = self.host.node_config(); // Always try to connect to anchor peers. for sa in config.anchors.iter() { if !connected_to_addresses.contains(sa) { let sa = sa.clone(); let self2 = self.clone(); tasks.push(tokio::spawn(async move { let _ = self2.connect(&sa, sleep_until).await; })); connected_to_addresses.insert(sa.clone()); } } // Try to connect to more peers until desired connection count is reached. let desired_connection_count = config.desired_connection_count.min(config.max_connection_count); for sa in config.seeds.iter() { if connected_to_addresses.len() >= desired_connection_count { break; } if !connected_to_addresses.contains(sa) { connected_to_addresses.insert(sa.clone()); let self2 = self.clone(); let sa = sa.clone(); tasks.push(tokio::spawn(async move { let _ = self2.connect(&sa, sleep_until).await; })); } } // Wait for this iteration's batched background tasks to complete. loop { let s = tasks.pop(); if s.is_some() { let _ = s.unwrap().await; } else { break; } } } } /// Incoming TCP acceptor task. 
///
/// Accepts inbound connections forever. A connection is handed to `connection_start` only if
/// `host.allow()` approves the peer address and either fewer than `max_connection_count`
/// connections are tracked or the peer address is one of the configured anchors; otherwise the
/// accepted stream is simply dropped.
async fn listener_task_main(self: Arc<Self>, listener: TcpListener) {
    loop {
        match listener.accept().await {
            Ok((stream, address)) => {
                if self.host.allow(&address) {
                    let config = self.host.node_config();
                    if self.connections.lock().await.len() < config.max_connection_count || config.anchors.contains(&address) {
                        Self::connection_start(&self, address, stream, true).await;
                    }
                }
            }
            Err(_) => {
                // accept() can fail persistently (e.g. when the process is out of file
                // descriptors); pause briefly so this loop does not spin at full speed.
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        }
    }
}

/// Initiate an outgoing connection with a deadline based timeout.
///
/// The host is notified via `on_connect_attempt()` first. The outgoing socket is bound to this
/// node's own bind address before connecting.
///
/// Returns `Ok(true)` if a new connection entry was created, `Ok(false)` if `address` was
/// already being tracked.
///
/// # Errors
///
/// Returns a `TimedOut` error if `deadline` passes before the connection is established, or
/// the underlying I/O error if socket setup or the connect itself fails.
async fn connect(self: Arc<Self>, address: &SocketAddr, deadline: Instant) -> std::io::Result<bool> {
    self.host.on_connect_attempt(address);
    let socket = if address.is_ipv4() { TcpSocket::new_v4() } else { TcpSocket::new_v6() }?;
    configure_tcp_socket(&socket)?;
    socket.bind(self.bind_address)?;
    match tokio::time::timeout_at(deadline, socket.connect(*address)).await {
        Ok(connect_result) => Ok(self.connection_start(*address, connect_result?, false).await),
        Err(_) => Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "connect timed out")),
    }
}

/// Sets up and spawns the task for a new TCP connection whether inbound or outbound.
async fn connection_start(self: &Arc, address: SocketAddr, stream: TcpStream, inbound: bool) -> bool { let mut ok = false; let _ = self.connections.lock().await.entry(address.clone()).or_insert_with(|| { ok = true; //let _ = stream.set_nodelay(true); let (reader, writer) = stream.into_split(); let now = self.ms_monotonic(); let connection = Arc::new(Connection { writer: Mutex::new(writer), last_send_time: AtomicI64::new(now), last_receive_time: AtomicI64::new(now), bytes_sent: AtomicU64::new(0), bytes_received: AtomicU64::new(0), info: Mutex::new(RemoteNodeInfo { name: String::new(), contact: String::new(), remote_address: address.clone(), explicit_addresses: Vec::new(), connect_time: ms_since_epoch(), connect_instant: now, inbound, initialized: false }), read_task: std::sync::Mutex::new(None), closed: AtomicBool::new(false), }); let self2 = self.clone(); let c2 = connection.clone(); connection.read_task.lock().unwrap().replace(tokio::spawn(async move { let result = self2.connection_io_task_main(&c2, reader).await; c2.closed.store(true, Ordering::Relaxed); result })); connection }); ok } /// Main I/O task launched for each connection. /// /// This handles reading from the connection and reacting to what it sends. Killing this /// task is done when the connection is closed. 
///
/// Sends this node's `Init` message, then loops forever reading framed messages (one type
/// byte, a varint payload length, then the payload) and dispatching on the message type. The
/// function only returns with an error: an I/O failure, the peer closing the connection, or a
/// protocol violation.
async fn connection_io_task_main(self: Arc<Self>, connection: &Arc<Connection>, mut reader: OwnedReadHalf) -> std::io::Result<()> {
    const BUF_CHUNK_SIZE: usize = 4096;
    const READ_BUF_INITIAL_SIZE: usize = 65536; // should be a multiple of BUF_CHUNK_SIZE

    let background_tasks = AsyncTaskReaper::new();
    let mut write_buffer: Vec<u8> = Vec::with_capacity(BUF_CHUNK_SIZE);
    let mut read_buffer: Vec<u8> = vec![0_u8; READ_BUF_INITIAL_SIZE];

    let config = self.host.node_config();
    let mut anti_loopback_challenge_sent = [0_u8; 64];
    let mut domain_challenge_sent = [0_u8; 64];
    let mut auth_challenge_sent = [0_u8; 64];
    self.host.get_secure_random(&mut anti_loopback_challenge_sent);
    self.host.get_secure_random(&mut domain_challenge_sent);
    self.host.get_secure_random(&mut auth_challenge_sent);
    connection
        .send_obj(
            &mut write_buffer,
            MessageType::Init,
            &msg::Init {
                anti_loopback_challenge: &anti_loopback_challenge_sent,
                domain_challenge: &domain_challenge_sent,
                auth_challenge: &auth_challenge_sent,
                node_name: config.name.as_str(),
                node_contact: config.contact.as_str(),
                locally_bound_port: self.bind_address.port(),
                explicit_ipv4: None,
                explicit_ipv6: None,
            },
            self.ms_monotonic(),
        )
        .await?;
    drop(config);

    let max_message_size = ((D::MAX_VALUE_SIZE * 8) + (D::KEY_SIZE * 1024) + 65536) as u64; // sanity limit
    let mut initialized = false;
    let mut init_received = false;
    let mut buffer_fill = 0_usize;
    loop {
        // Parse the header from what is already buffered and only read from the socket when
        // more bytes are actually needed, so that several messages delivered in a single read
        // are all processed without waiting for further traffic.
        let (message_type, header_size, total_size) = loop {
            if buffer_fill >= 2 {
                // type and at least one byte of varint; only look at bytes actually received
                let (size, size_len) = varint::decode(&read_buffer[1..buffer_fill]);
                if size_len > 0 {
                    // varint is all there and parsed correctly
                    if size > max_message_size {
                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"));
                    }
                    let header_size = 1 + size_len;
                    break (MessageType::from(read_buffer[0]), header_size, header_size + (size as usize));
                }
            }
            if buffer_fill >= read_buffer.len() {
                // The whole buffer is full and still no length could be decoded.
                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid message header"));
            }
            buffer_fill += read_more(&mut reader, &mut read_buffer[buffer_fill..]).await?;
        };

        if read_buffer.len() < total_size {
            read_buffer.resize(((total_size / BUF_CHUNK_SIZE) + 1) * BUF_CHUNK_SIZE, 0);
        }
        while buffer_fill < total_size {
            buffer_fill += read_more(&mut reader, &mut read_buffer[buffer_fill..]).await?;
        }

        let message = &read_buffer[header_size..total_size];
        let now = self.ms_monotonic();
        connection.last_receive_time.store(now, Ordering::Relaxed);

        match message_type {
            MessageType::Nop => {}

            MessageType::Init => {
                if init_received {
                    return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init"));
                }
                init_received = true;

                let msg: msg::Init = decode_msgpack(message)?;
                let (anti_loopback_response, domain_challenge_response, auth_challenge_response) = {
                    let mut info = connection.info.lock().await;
                    info.name = msg.node_name.to_string();
                    info.contact = msg.node_contact.to_string();
                    if let Some(pv4) = msg.explicit_ipv4 {
                        info.explicit_addresses.push(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(pv4.ip), pv4.port)));
                    }
                    if let Some(pv6) = msg.explicit_ipv6 {
                        info.explicit_addresses.push(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(pv6.ip), pv6.port, 0, 0)));
                    }

                    let auth_challenge_response = self
                        .host
                        .authenticate(&info, msg.auth_challenge)
                        .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "authenticate() returned None, connection dropped"))?;

                    (
                        H::hmac_sha512(&self.anti_loopback_secret, msg.anti_loopback_challenge),
                        H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), msg.domain_challenge),
                        auth_challenge_response,
                    )
                };

                connection
                    .send_obj(
                        &mut write_buffer,
                        MessageType::InitResponse,
                        &msg::InitResponse {
                            anti_loopback_response: &anti_loopback_response,
                            domain_response: &domain_challenge_response,
                            auth_response: &auth_challenge_response,
                        },
                        now,
                    )
                    .await?;
            }

            MessageType::InitResponse => {
                let msg: msg::InitResponse = decode_msgpack(message)?;
                let mut info = connection.info.lock().await;

                if info.initialized {
                    return Err(std::io::Error::new(std::io::ErrorKind::Other, "duplicate init response"));
                }

                // A peer that can answer our anti-loopback challenge with our own secret is
                // this very node, so equality here is the rejection case.
                if msg.anti_loopback_response.eq(&H::hmac_sha512(&self.anti_loopback_secret, &anti_loopback_challenge_sent)) {
                    return Err(std::io::Error::new(std::io::ErrorKind::Other, "rejected connection to self"));
                }
                if !msg.domain_response.eq(&H::hmac_sha512(&H::sha512(&[self.datastore.domain().as_bytes()]), &domain_challenge_sent)) {
                    return Err(std::io::Error::new(std::io::ErrorKind::Other, "domain mismatch"));
                }
                if !self.host.authenticate(&info, &auth_challenge_sent).map_or(false, |cr| msg.auth_response.eq(&cr)) {
                    return Err(std::io::Error::new(std::io::ErrorKind::Other, "challenge/response authentication failed"));
                }

                info.initialized = true;
                initialized = true;

                let info = info.clone();
                self.host.on_connect(&info);
            }

            // Handle messages other than INIT and INIT_RESPONSE after checking 'initialized' flag.
            _ => {
                if !initialized {
                    return Err(std::io::Error::new(std::io::ErrorKind::Other, "init exchange must be completed before other messages are sent"));
                }

                match message_type {
                    // The payloads below are decoded, so malformed messages drop the
                    // connection, but nothing further is done with them here.
                    MessageType::HaveRecords => {
                        let _msg: msg::HaveRecords = decode_msgpack(message)?;
                    }

                    MessageType::GetRecords => {
                        let _msg: msg::GetRecords = decode_msgpack(message)?;
                    }

                    MessageType::Record => {
                        let key = H::sha512(&[message]);
                        match self.datastore.store(&key, message) {
                            StoreResult::Ok => {
                                // TODO: probably should not announce if way out of sync
                                let connections = self.connections.lock().await;
                                let mut announce_to: Vec<Arc<Connection>> = Vec::with_capacity(connections.len());
                                for c in connections.values() {
                                    if !Arc::ptr_eq(connection, c) {
                                        announce_to.push(c.clone());
                                    }
                                }
                                drop(connections); // release lock

                                // NOTE(review): this sends `HaveRecord` while the arm above
                                // matches `HaveRecords` -- confirm against MessageType.
                                background_tasks.spawn(async move {
                                    for c in announce_to.iter() {
                                        let _ = c.send_msg(MessageType::HaveRecord, &key[0..ANNOUNCE_KEY_LEN], now).await;
                                    }
                                });
                            }
                            StoreResult::Rejected => {
                                return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("record rejected by data store: {}", to_hex_string(&key))));
                            }
                            _ => {}
                        }
                    }

                    MessageType::SyncStatus => {
                        let _msg: msg::SyncStatus = decode_msgpack(message)?;
                    }

                    MessageType::SyncRequest => {
                        let _msg: msg::SyncRequest = decode_msgpack(message)?;
                    }

                    MessageType::SyncResponse => {
                        let _msg: msg::SyncResponse = decode_msgpack(message)?;
                    }

                    // Must stay last: a leading wildcard would make every arm above unreachable.
                    _ => {}
                }
            }
        }

        // Shift any bytes belonging to the next message(s) to the front of the buffer.
        read_buffer.copy_within(total_size..buffer_fill, 0);
        buffer_fill -= total_size;

        connection.bytes_received.fetch_add(total_size as u64, Ordering::Relaxed);
    }
}
} // end of impl NodeInternal

/// Read at least one byte from `reader` into `buf`, returning the number of bytes read.
///
/// A zero-length read means the remote side closed the connection; that is reported as an
/// `UnexpectedEof` error so callers accumulating a message cannot loop forever.
async fn read_more(reader: &mut OwnedReadHalf, buf: &mut [u8]) -> std::io::Result<usize> {
    match reader.read(buf).await? {
        0 => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "remote host closed connection")),
        n => Ok(n),
    }
}

impl<D: DataStore + 'static, H: Host + 'static> Drop for NodeInternal<D, H> {
    /// Abort the I/O task of every remaining connection.
    fn drop(&mut self) {
        // `&mut self` proves exclusive access, so the async mutex can be bypassed with
        // `get_mut()`. Blocking on the lock or on the runtime handle here would panic when the
        // drop happens on a runtime worker thread.
        for (_, c) in self.connections.get_mut().drain() {
            // Never panic in drop: recover the guard even if the mutex was poisoned.
            let read_task = c.read_task.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            if let Some(task) = read_task.as_ref() {
                task.abort();
            }
        }
    }
}

/// State for one TCP connection, shared between its I/O task and the rest of the node.
struct Connection {
    // Write half of the socket; the lock serializes whole messages from concurrent senders.
    writer: Mutex<OwnedWriteHalf>,
    // Monotonic millisecond timestamps of the last successful send / last received message.
    last_send_time: AtomicI64,
    last_receive_time: AtomicI64,
    // Running totals of bytes sent and received, including message headers.
    bytes_sent: AtomicU64,
    bytes_received: AtomicU64,
    // Information about the remote node, filled in as the init exchange progresses.
    info: Mutex<RemoteNodeInfo>,
    // Join handle of the I/O task; taken by whoever joins or aborts it.
    read_task: std::sync::Mutex<Option<JoinHandle<std::io::Result<()>>>>,
    // Set by the I/O task when it exits.
    closed: AtomicBool,
}

impl Connection {
    /// Send one framed message: the type byte, a varint length, then `data`.
    ///
    /// The writer lock is held until the whole message has been written, so messages from
    /// concurrent callers are never interleaved. On success `last_send_time` is set to `now`
    /// and `bytes_sent` is increased by the full frame size.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if any write fails.
    async fn send_msg(&self, message_type: MessageType, data: &[u8], now: i64) -> std::io::Result<()> {
        let mut header = [0_u8; 16];
        header[0] = message_type as u8;
        let header_size = 1 + varint::encode(&mut header[1..], data.len() as u64);
        let total_size = header_size + data.len();

        let mut writer = self.writer.lock().await;
        let written = writer.write_vectored(&[IoSlice::new(&header[..header_size]), IoSlice::new(data)]).await?;
        // A vectored write may legitimately be short; finish whatever did not go out.
        if written < total_size {
            if written < header_size {
                writer.write_all(&header[written..header_size]).await?;
                writer.write_all(data).await?;
            } else {
                writer.write_all(&data[(written - header_size)..]).await?;
            }
        }
        drop(writer);

        self.last_send_time.store(now, Ordering::Relaxed);
        self.bytes_sent.fetch_add(total_size as u64, Ordering::Relaxed);
        Ok(())
    }

    /// Serialize `obj` as named MessagePack into `write_buf` (cleared first) and send it.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error if serialization fails, otherwise whatever `send_msg`
    /// returns.
    async fn send_obj<O: Serialize>(&self, write_buf: &mut Vec<u8>, message_type: MessageType, obj: &O, now: i64) -> std::io::Result<()> {
        write_buf.clear();
        if rmp_serde::encode::write_named(write_buf, obj).is_ok() {
            self.send_msg(message_type, write_buf.as_slice(), now).await
        } else {
            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "serialize failure (internal error)"))
        }
    }
}