From f2028ce3a2ca6e0ab663297db8c25593c6cd8328 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Tue, 18 Oct 2022 16:01:16 -0400 Subject: [PATCH] Refactor send method a bit to hide some backward compatibility detail from outside code. --- network-hypervisor/src/vl1/node.rs | 33 ++-- network-hypervisor/src/vl1/peer.rs | 237 +++++++++++++++-------------- 2 files changed, 135 insertions(+), 135 deletions(-) diff --git a/network-hypervisor/src/vl1/node.rs b/network-hypervisor/src/vl1/node.rs index 48eeb9833..f0cb4db53 100644 --- a/network-hypervisor/src/vl1/node.rs +++ b/network-hypervisor/src/vl1/node.rs @@ -1,6 +1,7 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. use std::collections::HashMap; +use std::convert::Infallible; use std::hash::Hash; use std::io::Write; use std::sync::atomic::Ordering; @@ -892,24 +893,22 @@ impl Node { } /// Send a WHOIS query to the current best root. - fn send_whois(&self, host_system: &HostSystemImpl, addresses: &[Address], time_ticks: i64) { + fn send_whois(&self, host_system: &HostSystemImpl, mut addresses: &[Address], time_ticks: i64) { debug_assert!(!addresses.is_empty()); - if !addresses.is_empty() { - if let Some(root) = self.best_root() { - let mut packet = host_system.get_buffer(); - packet.set_size(v1::HEADER_SIZE); - let _ = packet.append_u8(verbs::VL1_WHOIS); - for a in addresses.iter() { - if (packet.len() + ADDRESS_SIZE) > UDP_DEFAULT_MTU { - root.send(host_system, None, self, time_ticks, packet); - packet = host_system.get_buffer(); - packet.set_size(v1::HEADER_SIZE); - let _ = packet.append_u8(verbs::VL1_WHOIS); - } - let _ = packet.append_bytes_fixed(&a.to_bytes()); - } - if packet.len() > (v1::HEADER_SIZE + 1) { - root.send(host_system, None, self, time_ticks, packet); + if let Some(root) = self.best_root() { + while !addresses.is_empty() { + if !root + .send(host_system, self, None, time_ticks, |packet| -> Result<(), Infallible> { + assert!(packet.append_u8(verbs::VL1_WHOIS).is_ok()); + while !addresses.is_empty() && (packet.len() + ADDRESS_SIZE) <= UDP_DEFAULT_MTU { + assert!(packet.append_bytes_fixed(&addresses[0].to_bytes()).is_ok()); + addresses = &addresses[1..]; + } + Ok(()) + }) + .is_some() + { + break; } } } diff --git a/network-hypervisor/src/vl1/peer.rs b/network-hypervisor/src/vl1/peer.rs index 3e406aa9a..8cf2d8ea8 100644 --- a/network-hypervisor/src/vl1/peer.rs +++ b/network-hypervisor/src/vl1/peer.rs @@ -1,6 +1,7 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. use std::collections::HashMap; +use std::convert::Infallible; use std::hash::Hash; use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; use std::sync::{Arc, Mutex, RwLock, Weak}; @@ -259,18 +260,23 @@ impl Peer { } } - /// Send a packet to this peer, returning true on (potential) success. + /// Send a packet to this peer. /// - /// This will go directly if there is an active path, or otherwise indirectly - /// via a root or some other route. - pub(crate) fn send( + /// This sets up a buffer and then invokes the supplied function to actually populate its contents. + /// It's structured this way to handle both V1 and V2 format packets and the need to set them up + /// differently while hiding that from higher level code. + /// + /// The builder function must append the verb (with any verb flags) and packet payload. If it returns + /// an error, the error is returned immediately and the send is aborted. None is returned if the send + /// function itself fails for some reason such as no paths being available. + pub fn send Result>( &self, host_system: &HostSystemImpl, - path: Option<&Arc>>, node: &Node, + path: Option<&Arc>>, time_ticks: i64, - mut packet: PooledPacketBuffer, - ) -> bool { + builder_function: BuilderFunction, + ) -> Option> { let mut _path_arc = None; let path = if let Some(path) = path { path @@ -279,83 +285,92 @@ impl Peer { if let Some(path) = _path_arc.as_ref() { path } else { - return false; + return None; } }; let max_fragment_size = path.endpoint.max_fragment_size(); - if self.remote_node_info.read().unwrap().remote_protocol_version >= 11 { - let flags_cipher_hops = if packet.len() > max_fragment_size { - v1::HEADER_FLAG_FRAGMENTED | v1::CIPHER_AES_GMAC_SIV - } else { - v1::CIPHER_AES_GMAC_SIV - }; + let mut packet = host_system.get_buffer(); + if !self.identity.p384.is_some() { + // For the V1 protocol, leave room for for the header in the buffer. + packet.set_size(v1::HEADER_SIZE); + } - let mut aes_gmac_siv = self.v1_proto_static_secret.aes_gmac_siv.get(); - aes_gmac_siv.encrypt_init(&self.v1_proto_next_message_id().to_be_bytes()); - aes_gmac_siv.encrypt_set_aad(&v1::get_packet_aad_bytes( - self.identity.address, - node.identity.address, - flags_cipher_hops, - )); - let tag = if let Ok(payload) = packet.as_bytes_starting_at_mut(v1::HEADER_SIZE) { - aes_gmac_siv.encrypt_first_pass(payload); - aes_gmac_siv.encrypt_first_pass_finish(); - aes_gmac_siv.encrypt_second_pass_in_place(payload); - aes_gmac_siv.encrypt_second_pass_finish() - } else { - return false; - }; + let r = builder_function(packet.as_mut()); - let header = packet.struct_mut_at::(0).unwrap(); - header.id.copy_from_slice(&tag[0..8]); - header.dest = self.identity.address.to_bytes(); - header.src = node.identity.address.to_bytes(); - header.flags_cipher_hops = flags_cipher_hops; - header.mac.copy_from_slice(&tag[8..16]); - } else { - let packet_len = packet.len(); - let flags_cipher_hops = if packet.len() > max_fragment_size { - v1::HEADER_FLAG_FRAGMENTED | v1::CIPHER_SALSA2012_POLY1305 + if r.is_ok() { + if self.identity.p384.is_some() { + todo!() // TODO: ZSSP / V2 protocol } else { - v1::CIPHER_SALSA2012_POLY1305 - }; + if self.remote_node_info.read().unwrap().remote_protocol_version >= 11 { + let flags_cipher_hops = if packet.len() > max_fragment_size { + v1::HEADER_FLAG_FRAGMENTED | v1::CIPHER_AES_GMAC_SIV + } else { + v1::CIPHER_AES_GMAC_SIV + }; + + let mut aes_gmac_siv = self.v1_proto_static_secret.aes_gmac_siv.get(); + aes_gmac_siv.encrypt_init(&self.v1_proto_next_message_id().to_be_bytes()); + aes_gmac_siv.encrypt_set_aad(&v1::get_packet_aad_bytes( + self.identity.address, + node.identity.address, + flags_cipher_hops, + )); + let payload = packet.as_bytes_starting_at_mut(v1::HEADER_SIZE).unwrap(); + aes_gmac_siv.encrypt_first_pass(payload); + aes_gmac_siv.encrypt_first_pass_finish(); + aes_gmac_siv.encrypt_second_pass_in_place(payload); + let tag = aes_gmac_siv.encrypt_second_pass_finish(); - let (mut salsa, poly1305_otk) = v1_proto_salsa_poly_create( - &self.v1_proto_static_secret, - { let header = packet.struct_mut_at::(0).unwrap(); - header.id = self.v1_proto_next_message_id().to_be_bytes(); + header.id.copy_from_slice(&tag[0..8]); header.dest = self.identity.address.to_bytes(); header.src = node.identity.address.to_bytes(); header.flags_cipher_hops = flags_cipher_hops; - header - }, - packet_len, + header.mac.copy_from_slice(&tag[8..16]); + } else { + let packet_len = packet.len(); + let flags_cipher_hops = if packet.len() > max_fragment_size { + v1::HEADER_FLAG_FRAGMENTED | v1::CIPHER_SALSA2012_POLY1305 + } else { + v1::CIPHER_SALSA2012_POLY1305 + }; + + let (mut salsa, poly1305_otk) = v1_proto_salsa_poly_create( + &self.v1_proto_static_secret, + { + let header = packet.struct_mut_at::(0).unwrap(); + header.id = self.v1_proto_next_message_id().to_be_bytes(); + header.dest = self.identity.address.to_bytes(); + header.src = node.identity.address.to_bytes(); + header.flags_cipher_hops = flags_cipher_hops; + header + }, + packet_len, + ); + + let payload = packet.as_bytes_starting_at_mut(v1::HEADER_SIZE).unwrap(); + salsa.crypt_in_place(payload); + let tag = poly1305::compute(&poly1305_otk, payload); + + packet.as_bytes_mut()[v1::MAC_FIELD_INDEX..(v1::MAC_FIELD_INDEX + 8)].copy_from_slice(&tag[0..8]); + } + } + + self.v1_proto_internal_send( + host_system, + &path.endpoint, + Some(&path.local_socket), + Some(&path.local_interface), + max_fragment_size, + packet, ); - let tag = if let Ok(payload) = packet.as_bytes_starting_at_mut(v1::HEADER_SIZE) { - salsa.crypt_in_place(payload); - poly1305::compute(&poly1305_otk, payload) - } else { - return false; - }; - packet.as_bytes_mut()[v1::MAC_FIELD_INDEX..(v1::MAC_FIELD_INDEX + 8)].copy_from_slice(&tag[0..8]); + self.last_send_time_ticks.store(time_ticks, Ordering::Relaxed); } - self.v1_proto_internal_send( - host_system, - &path.endpoint, - Some(&path.local_socket), - Some(&path.local_interface), - max_fragment_size, - packet, - ); - - self.last_send_time_ticks.store(time_ticks, Ordering::Relaxed); - - return true; + return Some(r); } /// Send a HELLO to this peer. @@ -594,24 +609,28 @@ impl Peer { ); } - let mut packet = host_system.get_buffer(); - packet.set_size(v1::HEADER_SIZE); - { - let f: &mut ( - v1::message_component_structs::OkHeader, - v1::message_component_structs::OkHelloFixedHeaderFields, - ) = packet.append_struct_get_mut().unwrap(); - f.0.verb = verbs::VL1_OK; - f.0.in_re_verb = verbs::VL1_HELLO; - f.0.in_re_message_id = message_id.to_ne_bytes(); - f.1.timestamp_echo = hello_fixed_headers.timestamp; - f.1.version_proto = PROTOCOL_VERSION; - f.1.version_major = VERSION_MAJOR; - f.1.version_minor = VERSION_MINOR; - f.1.version_revision = VERSION_REVISION.to_be_bytes(); - } + self.send( + host_system, + node, + Some(source_path), + time_ticks, + |packet| -> Result<(), Infallible> { + let f: &mut ( + v1::message_component_structs::OkHeader, + v1::message_component_structs::OkHelloFixedHeaderFields, + ) = packet.append_struct_get_mut().unwrap(); + f.0.verb = verbs::VL1_OK; + f.0.in_re_verb = verbs::VL1_HELLO; + f.0.in_re_message_id = message_id.to_ne_bytes(); + f.1.timestamp_echo = hello_fixed_headers.timestamp; + f.1.version_proto = PROTOCOL_VERSION; + f.1.version_major = VERSION_MAJOR; + f.1.version_minor = VERSION_MINOR; + f.1.version_revision = VERSION_REVISION.to_be_bytes(); + Ok(()) + }, + ); - self.send(host_system, Some(source_path), node, time_ticks, packet); return PacketHandlerResult::Ok; } } @@ -762,39 +781,25 @@ impl Peer { payload: &PacketBuffer, ) -> PacketHandlerResult { if node.this_node_is_root() || inner.should_communicate_with(&self.identity) { - let init_packet = |packet: &mut PacketBuffer| { - packet.set_size(v1::HEADER_SIZE); - let mut f: &mut v1::message_component_structs::OkHeader = packet.append_struct_get_mut().unwrap(); - f.verb = verbs::VL1_OK; - f.in_re_verb = verbs::VL1_WHOIS; - f.in_re_message_id = message_id.to_ne_bytes(); - }; - - let mut packet = host_system.get_buffer(); - init_packet(&mut packet); - let mut addresses = payload.as_bytes(); - loop { - if addresses.len() >= ADDRESS_SIZE { - if let Some(zt_address) = Address::from_bytes(&addresses[..ADDRESS_SIZE]) { - if let Some(peer) = node.peer(zt_address) { - if (packet.capacity() - packet.len()) < Identity::MAX_MARSHAL_SIZE { - self.send(host_system, None, node, time_ticks, packet); - packet = host_system.get_buffer(); - init_packet(&mut packet); - } - if !peer.identity.write_public(packet.as_mut(), self.identity.p384.is_none()).is_ok() { - break; + while addresses.len() >= ADDRESS_SIZE { + if !self + .send(host_system, node, None, time_ticks, |packet| { + while addresses.len() >= ADDRESS_SIZE && (packet.len() + Identity::MAX_MARSHAL_SIZE) <= UDP_DEFAULT_MTU { + if let Some(zt_address) = Address::from_bytes(&addresses[..ADDRESS_SIZE]) { + if let Some(peer) = node.peer(zt_address) { + peer.identity.write_public(packet, self.identity.p384.is_none())?; + } } + addresses = &addresses[ADDRESS_SIZE..]; } - } - addresses = &addresses[ADDRESS_SIZE..]; - } else { + Ok(()) + }) + .map_or(false, |r: std::io::Result<()>| r.is_ok()) + { break; } } - - self.send(host_system, None, node, time_ticks, packet); } return PacketHandlerResult::Ok; } @@ -822,17 +827,13 @@ impl Peer { payload: &PacketBuffer, ) -> PacketHandlerResult { if inner.should_communicate_with(&self.identity) || node.is_peer_root(self) { - let mut packet = host_system.get_buffer(); - packet.set_size(v1::HEADER_SIZE); - { + self.send(host_system, node, None, time_ticks, |packet| { let mut f: &mut v1::message_component_structs::OkHeader = packet.append_struct_get_mut().unwrap(); f.verb = verbs::VL1_OK; f.in_re_verb = verbs::VL1_ECHO; f.in_re_message_id = message_id.to_ne_bytes(); - } - if packet.append_bytes(payload.as_bytes()).is_ok() { - self.send(host_system, None, node, time_ticks, packet); - } + packet.append_bytes(payload.as_bytes()) + }); } else { debug_event!( host_system,