diff --git a/core-crypto/src/zssp.rs b/core-crypto/src/zssp.rs index 54eb4cac7..2253d065a 100644 --- a/core-crypto/src/zssp.rs +++ b/core-crypto/src/zssp.rs @@ -12,6 +12,7 @@ use crate::random; use crate::secret::Secret; use zerotier_utils::gatherarray::GatherArray; +use zerotier_utils::memory; use zerotier_utils::ringbuffermap::RingBufferMap; use zerotier_utils::varint; @@ -41,6 +42,7 @@ const OFFER_RATE_LIMIT_MS: i64 = 2000; /// Version 0: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication) const SESSION_PROTOCOL_VERSION: u8 = 0x00; +// Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use const PACKET_TYPE_DATA: u8 = 0; const PACKET_TYPE_NOP: u8 = 1; const PACKET_TYPE_KEY_OFFER: u8 = 2; // "alice" @@ -56,6 +58,10 @@ const AES_GCM_TAG_SIZE: usize = 16; const HMAC_SIZE: usize = 48; const SESSION_ID_SIZE: usize = 6; +const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; +const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A'; +const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; + /// Aribitrary starting value for key derivation chain. /// /// It doesn't matter very much what this is, but it's good for it to be unique. @@ -66,10 +72,6 @@ const KEY_DERIVATION_CHAIN_STARTING_SALT: [u8; 64] = [ 0xb4, 0x32, 0x85, 0xaf, 0x7f, 0x0d, 0xa9, 0x6c, 0x01, 0xfb, 0x72, 0x46, 0xc0, 0x09, 0x58, 0xb8, 0xe0, 0xa8, 0xcf, 0xb1, 0x58, 0x04, 0x6e, 0x32, 0xba, 0xa8, 0xb8, 0xf9, 0x0a, 0xa4, 0xbf, 0x36, ]; -const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; -const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A'; -const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; - pub enum Error { /// The packet was addressed to an unrecognized local session UnknownLocalSessionId(SessionId), @@ -147,16 +149,17 @@ pub enum ReceiveResult<'a, H: Host> { /// Packet is valid, no action needs to be taken. Ok, - /// Packet is valid and contained a data payload. - OkData(&'a [u8]), + /// Packet is valid and a data payload was decoded and authenticated. + /// + /// The returned reference is to the filled parts of the data buffer supplied to receive. + OkData(&'a mut [u8]), - /// Packet is valid and a new session was created, also includes a reply to be sent back. + /// Packet is valid and a new session was created. + /// + /// The session will have already been gated by the accept_new_session() method in the Host trait. OkNewSession(Session), - /// Packet appears valid but was ignored as a duplicate. - Duplicate, - - /// Packet apperas valid but was ignored for another reason. + /// Packet apperas valid but was ignored e.g. as a duplicate. Ignored, } @@ -167,20 +170,15 @@ pub struct SessionId(NonZeroU64); impl SessionId { pub const MAX_BIT_MASK: u64 = 0xffffffffffff; - pub fn new_from_bytes(b: &[u8]) -> Option { - if b.len() >= 6 { - let value = (u32::from_le_bytes(b[..4].try_into().unwrap()) as u64) | (u16::from_le_bytes(b[4..6].try_into().unwrap()) as u64).wrapping_shl(32); - if value > 0 && value <= Self::MAX_BIT_MASK { - return Some(Self(NonZeroU64::new(value).unwrap())); - } - } - return None; + #[inline(always)] + pub fn new_from_u64(i: u64) -> Option { + debug_assert!(i <= Self::MAX_BIT_MASK); + NonZeroU64::new(i).map(|i| Self(i)) } pub fn new_from_reader(r: &mut R) -> std::io::Result> { - let mut tmp = [0_u8; SESSION_ID_SIZE]; - r.read_exact(&mut tmp)?; - Ok(Self::new_from_bytes(&tmp)) + let mut tmp = [0_u8; 8]; + r.read_exact(&mut tmp[..SESSION_ID_SIZE]).map(|_| NonZeroU64::new(u64::from_le_bytes(tmp)).map(|i| Self(i))) } pub fn new_random() -> Self { @@ -193,19 +191,6 @@ impl SessionId { } } -impl TryFrom for SessionId { - type Error = self::Error; - - #[inline(always)] - fn try_from(value: u64) -> Result { - if value > 0 && value <= Self::MAX_BIT_MASK { - Ok(Self(NonZeroU64::new(value).unwrap())) - } else { - Err(Error::InvalidParameter) - } - } -} - impl From for u64 { #[inline(always)] fn from(sid: SessionId) -> Self { @@ -213,21 +198,31 @@ impl From for u64 { } } +/// Trait to implement to integrate the session into an application. pub trait Host: Sized { - type AssociatedObject: Sized; + /// Arbitrary object that can be associated with sessions. + type AssociatedObject; + + /// Arbitrary object that dereferences to the session, such as Arc>. type SessionRef: Deref>; + + /// A buffer containing data read from the network that can be cached. + /// + /// This can be e.g. a pooled buffer that automatically returns itself to the pool when dropped. type IncomingPacketBuffer: AsRef<[u8]>; /// Get a reference to this host's static public key blob. + /// + /// This must contain a NIST P-384 public key but can contain other information. fn get_local_s_public(&self) -> &[u8]; - /// Get SHA384(this host's static public key blob) + /// Get SHA384(this host's static public key blob), included here so we don't have to calculate it each time. fn get_local_s_public_hash(&self) -> &[u8; 48]; /// Get a reference to this hosts' static public key's NIST P-384 secret key pair fn get_local_s_keypair_p384(&self) -> &P384KeyPair; - /// Extract a NIST P-384 ECC public key from a static public key blob. + /// Extract the NIST P-384 ECC public key component from a static public key blob or return None on failure. fn extract_p384_static(static_public: &[u8]) -> Option; /// Look up a local session by local ID. @@ -246,23 +241,23 @@ pub struct Session { pub associated_object: H::AssociatedObject, send_counter: Counter, - psk: Secret<64>, - ss: Secret<48>, - state: RwLock, - remote_s_public_hash: [u8; 48], - remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], - defrag: Mutex, 64, 32>>, + psk: Secret<64>, // Arbitrary PSK provided by external code + ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer + state: RwLock, // Mutable parts of state (other than defrag buffers) + remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob) + remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key + defrag: Mutex, 16, 4>>, } struct MutableState { remote_session_id: Option, - keys: [Option; 2], // current, next + keys: [Option; 2], // current, next (promoted to current on successful decrypt) offer: Option, } /// State information to associate with receiving contexts such as sockets or remote paths/endpoints. pub struct ReceiveContext { - initial_offer_defrag: Mutex, 1024, 256>>, + initial_offer_defrag: Mutex, 1024, 128>>, } impl Session { @@ -362,26 +357,25 @@ impl ReceiveContext { current_time: i64, ) -> Result, Error> { let incoming_packet = incoming_packet_buf.as_ref(); - if incoming_packet.len() < MIN_PACKET_SIZE { unlikely_branch(); return Err(Error::InvalidPacket); } - let type_and_frag_info = u16::from_le_bytes(incoming_packet[0..2].try_into().unwrap()); - let local_session_id = SessionId::new_from_bytes(&incoming_packet[2..8]); - let counter = u32::from_le_bytes(incoming_packet[8..12].try_into().unwrap()); - let packet_type = (type_and_frag_info as u8) & 15; - let fragment_count = type_and_frag_info.wrapping_shr(4) & 63; - let fragment_no = type_and_frag_info.wrapping_shr(10); // & 63 not needed + let header_0_8 = memory::u64_from_le_bytes(incoming_packet); // type, frag info, session ID + let counter = memory::u32_from_le_bytes(&incoming_packet[8..]); + let local_session_id = SessionId::new_from_u64(header_0_8.wrapping_shr(16)); + let packet_type = (header_0_8 as u8) & 15; + let fragment_count = ((header_0_8.wrapping_shr(4) as u8) & 63).wrapping_add(1); + let fragment_no = (header_0_8.wrapping_shr(10) as u8) & 63; if fragment_count > 1 { if let Some(local_session_id) = local_session_id { - if fragment_count < (MAX_FRAGMENTS as u16) && fragment_no < fragment_count { + if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { if let Some(session) = host.session_lookup(local_session_id) { let mut defrag = session.defrag.lock(); - let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32)); - if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) { + let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); + if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { drop(defrag); // release lock return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time); } @@ -394,10 +388,10 @@ impl ReceiveContext { return Err(Error::InvalidPacket); } } else { - if fragment_count < (KEY_EXCHANGE_MAX_FRAGMENTS as u16) && fragment_no < fragment_count { + if fragment_count <= (KEY_EXCHANGE_MAX_FRAGMENTS as u8) && fragment_no < fragment_count { let mut defrag = self.initial_offer_defrag.lock(); - let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32)); - if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) { + let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); + if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { drop(defrag); // release lock return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time); } @@ -444,22 +438,19 @@ impl ReceiveContext { let state = session.state.read(); for ki in 0..2 { if let Some(key) = state.keys[ki].as_ref() { - let head = fragments.first().unwrap().as_ref(); - debug_assert!(head.len() >= MIN_PACKET_SIZE); + let tail = fragments.last().unwrap().as_ref(); + if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { + unlikely_branch(); + return Err(Error::InvalidPacket); + } let mut c = key.get_receive_cipher(); - c.init(&get_aes_gcm_nonce(head)); + c.init(&get_aes_gcm_nonce(fragments.first().unwrap().as_ref())); - let mut data_len = head.len() - HEADER_SIZE; - if data_len > data_buf.len() { - unlikely_branch(); - key.return_receive_cipher(c); - return Err(Error::DataBufferTooSmall); - } - c.crypt(&head[HEADER_SIZE..], &mut data_buf[..data_len]); + let mut data_len = 0; - for fi in 1..(fragments.len() - 1) { - let f = fragments[fi].as_ref(); + for f in fragments[..(fragments.len() - 1)].iter() { + let f = f.as_ref(); debug_assert!(f.len() >= HEADER_SIZE); let current_frag_data_start = data_len; data_len += f.len() - HEADER_SIZE; @@ -471,21 +462,14 @@ impl ReceiveContext { c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); } - let tail = fragments.last().unwrap().as_ref(); - if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { - unlikely_branch(); - key.return_receive_cipher(c); - return Err(Error::InvalidPacket); - } - let tail_data_len = tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); let current_frag_data_start = data_len; - data_len += tail_data_len; + data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); if data_len > data_buf.len() { unlikely_branch(); key.return_receive_cipher(c); return Err(Error::DataBufferTooSmall); } - c.crypt(&tail[HEADER_SIZE..tail_data_len], &mut data_buf[current_frag_data_start..data_len]); + c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]); let tag = c.finish(); key.return_receive_cipher(c); @@ -500,7 +484,7 @@ impl ReceiveContext { } if packet_type == PACKET_TYPE_DATA { - return Ok(ReceiveResult::OkData(&data_buf[..data_len])); + return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); } else { return Ok(ReceiveResult::Ok); } @@ -540,8 +524,6 @@ impl ReceiveContext { return Err(Error::UnknownProtocolVersion); } - let local_s_keypair_p384 = host.get_local_s_keypair_p384(); - match packet_type { PACKET_TYPE_KEY_OFFER => { // alice (remote) -> bob (local) @@ -565,7 +547,7 @@ impl ReceiveContext { } let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) - .and_then(|pk| local_s_keypair_p384.agree(&pk).map(move |s| (pk, s))) + .and_then(|pk| host.get_local_s_keypair_p384().agree(&pk).map(move |s| (pk, s))) .ok_or(Error::FailedAuthentication)?; let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes())); @@ -588,7 +570,7 @@ impl ReceiveContext { } let alice_s_public_p384 = H::extract_p384_static(alice_s_public).ok_or(Error::InvalidPacket)?; - let ss = local_s_keypair_p384.agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?; + let ss = host.get_local_s_keypair_p384().agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?; let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes())); @@ -668,7 +650,7 @@ impl ReceiveContext { } (MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len() }; - let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_OFFER, session.id.into(), reply_counter); + let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter); reply_buf[..HEADER_SIZE].copy_from_slice(&header); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true); @@ -715,7 +697,7 @@ impl ReceiveContext { let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) .and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s))) .ok_or(Error::FailedAuthentication)?; - let se0 = local_s_keypair_p384.agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?; + let se0 = host.get_local_s_keypair_p384().agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?; let key = Secret(hmac_sha512( session.psk.as_bytes(), @@ -956,15 +938,16 @@ impl EphemeralOffer { #[inline(always)] fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) -> [u8; 12] { - let fragment_count = (packet_len / mtu) + (((packet_len % mtu) != 0) as usize); + let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize; debug_assert!(mtu >= MIN_MTU); debug_assert!(packet_len >= HEADER_SIZE); debug_assert!(fragment_count <= MAX_FRAGMENTS); + debug_assert!(fragment_count > 0); debug_assert!(packet_type <= 0x0f); // packet type is 4 bits debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits // Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter - let mut header = (fragment_count.wrapping_shl(4) | (packet_type as usize)) as u128; + let mut header = ((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u128; header |= recipient_session_id.wrapping_shl(16) as u128; header |= (counter.to_u32() as u128).wrapping_shl(64); header.to_le_bytes()[..HEADER_SIZE].try_into().unwrap() @@ -980,7 +963,8 @@ fn send_with_fragmentation(mut send: SendFunctio if fragment_end < packet_len { fragment_start = fragment_end - HEADER_SIZE; fragment_end = (fragment_start + mtu).min(packet_len); - header[1] += 0x04; // increment fragment number, at bit 2 in byte 1 since type/fragment u16 is little-endian + debug_assert!(header[1].wrapping_shr(2) < 63); + header[1] += 0x04; // increment fragment number in least significant 6 bits of byte 1 packet[fragment_start..(fragment_start + HEADER_SIZE)].copy_from_slice(header); } else { debug_assert_eq!(fragment_end, packet_len); @@ -1104,171 +1088,154 @@ fn get_aes_gcm_nonce(packet: &[u8]) -> [u8; 16] { tmp } -#[cfg(test)] -mod tests {} - -/* #[cfg(test)] mod tests { - use std::rc::Rc; + use parking_lot::Mutex; + use std::collections::LinkedList; + use std::sync::Arc; #[allow(unused_imports)] use super::*; - #[test] - fn alice_bob() { - let psk: Secret<64> = Secret::default(); - let mut a_buffer = [0_u8; 1500]; - let mut b_buffer = [0_u8; 1500]; - let alice_static_keypair = P384KeyPair::generate(); - let bob_static_keypair = P384KeyPair::generate(); - let outgoing_obfuscator_to_alice = Obfuscator::new(alice_static_keypair.public_key_bytes()); - let outgoing_obfuscator_to_bob = Obfuscator::new(bob_static_keypair.public_key_bytes()); + struct TestHost { + local_s: P384KeyPair, + local_s_hash: [u8; 48], + psk: Secret<64>, + session: Mutex>>>>, + session_id_counter: Mutex, + pub queue: Mutex>>, + pub this_name: &'static str, + pub other_name: &'static str, + } - let mut from_alice: Vec> = Vec::new(); - let mut from_bob: Vec> = Vec::new(); - - // Session TO Bob, on Alice's side. - let (alice, packet) = Session::new( - &mut a_buffer, - SessionId::new_random(), - alice_static_keypair.public_key_bytes(), - &alice_static_keypair, - bob_static_keypair.public_key_bytes(), - bob_static_keypair.public_key(), - &psk, - 0, - 0, - true, - ) - .unwrap(); - let alice = Rc::new(alice); - from_alice.push(packet.to_vec()); - - // Session FROM Alice, on Bob's side. - let mut bob: Option>> = None; - - for _ in 0..256 { - while !from_alice.is_empty() || !from_bob.is_empty() { - if let Some(packet) = from_alice.pop() { - let r = Session::receive( - packet.as_slice(), - &mut b_buffer, - &bob_static_keypair, - &outgoing_obfuscator_to_bob, - |p: &[u8; P384_PUBLIC_KEY_SIZE]| P384PublicKey::from_bytes(p), - |sid| { - println!("[noise] [bob] session ID: {}", u64::from(sid)); - if let Some(bob) = bob.as_ref() { - if sid == bob.id { - Some(bob.clone()) - } else { - None - } - } else { - None - } - }, - |_: &[u8; P384_PUBLIC_KEY_SIZE]| { - if bob.is_none() { - Some((SessionId::new_random(), psk.clone(), 0)) - } else { - panic!("[noise] [bob] Bob received a second new session request from Alice"); - } - }, - 0, - true, - ); - if let Ok(r) = r { - match r { - ReceiveResult::OkData(data, counter) => { - println!("[noise] [bob] DATA len=={} counter=={}", data.len(), counter); - } - ReceiveResult::OkSendReply(p) => { - println!("[noise] [bob] OK (reply: {} bytes)", p.len()); - from_bob.push(p.to_vec()); - } - ReceiveResult::OkNewSession(ns, p) => { - if bob.is_some() { - panic!("[noise] [bob] attempt to create new session on Bob's side when he already has one"); - } - let id: u64 = ns.id.into(); - let _ = bob.replace(Rc::new(ns)); - from_bob.push(p.to_vec()); - println!("[noise] [bob] NEW SESSION {}", id); - } - ReceiveResult::Ok => { - println!("[noise] [bob] OK"); - } - ReceiveResult::Duplicate => { - println!("[noise] [bob] duplicate packet"); - } - ReceiveResult::Ignored => { - println!("[noise] [bob] ignored packet"); - } - } - } else { - println!("ERROR (bob): {}", r.err().unwrap().to_string()); - panic!(); - } - } - - if let Some(packet) = from_bob.pop() { - let r = Session::receive( - packet.as_slice(), - &mut b_buffer, - &alice_static_keypair, - &outgoing_obfuscator_to_alice, - |p: &[u8; P384_PUBLIC_KEY_SIZE]| P384PublicKey::from_bytes(p), - |sid| { - println!("[noise] [alice] session ID: {}", u64::from(sid)); - if sid == alice.id { - Some(alice.clone()) - } else { - panic!("[noise] [alice] received from Bob addressed to unknown session ID, not Alice"); - } - }, - |_: &[u8; P384_PUBLIC_KEY_SIZE]| { - panic!("[noise] [alice] Alice received an unexpected new session request from Bob"); - }, - 0, - true, - ); - if let Ok(r) = r { - match r { - ReceiveResult::OkData(data, counter) => { - println!("[noise] [alice] DATA len=={} counter=={}", data.len(), counter); - } - ReceiveResult::OkSendReply(p) => { - println!("[noise] [alice] OK (reply: {} bytes)", p.len()); - from_alice.push(p.to_vec()); - } - ReceiveResult::OkNewSession(_, _) => { - panic!("[noise] [alice] attempt to create new session on Alice's side; Bob should not initiate"); - } - ReceiveResult::Ok => { - println!("[noise] [alice] OK"); - } - ReceiveResult::Duplicate => { - println!("[noise] [alice] duplicate packet"); - } - ReceiveResult::Ignored => { - println!("[noise] [alice] ignored packet"); - } - } - } else { - println!("ERROR (alice): {}", r.err().unwrap().to_string()); - panic!(); - } - } + impl TestHost { + fn new(psk: Secret<64>, this_name: &'static str, other_name: &'static str) -> Self { + let local_s = P384KeyPair::generate(); + let local_s_hash = SHA384::hash(local_s.public_key_bytes()); + Self { + local_s, + local_s_hash, + psk, + session: Mutex::new(None), + session_id_counter: Mutex::new(random::next_u64_secure().wrapping_shr(16) | 1), + queue: Mutex::new(LinkedList::new()), + this_name, + other_name, } + } + } - if (random::next_u32_secure() & 1) == 0 { - from_alice.push(alice.send(&mut a_buffer, &[0_u8; 16]).unwrap().to_vec()); - } else if bob.is_some() { - from_bob.push(bob.as_ref().unwrap().send(&mut b_buffer, &[0_u8; 16]).unwrap().to_vec()); + impl Host for Box { + type AssociatedObject = u32; + type SessionRef = Arc>>; + type IncomingPacketBuffer = Vec; + + fn get_local_s_public(&self) -> &[u8] { + self.local_s.public_key_bytes() + } + + fn get_local_s_public_hash(&self) -> &[u8; 48] { + &self.local_s_hash + } + + fn get_local_s_keypair_p384(&self) -> &P384KeyPair { + &self.local_s + } + + fn extract_p384_static(static_public: &[u8]) -> Option { + P384PublicKey::from_bytes(static_public) + } + + fn session_lookup(&self, local_session_id: SessionId) -> Option { + self.session.lock().as_ref().and_then(|s| { + if s.id == local_session_id { + Some(s.clone()) + } else { + None + } + }) + } + + fn accept_new_session(&self, _: &[u8], _: &[u8]) -> Option<(SessionId, Secret<64>, Self::AssociatedObject)> { + loop { + let mut new_id = self.session_id_counter.lock(); + *new_id += 1; + return Some((SessionId::new_from_u64(*new_id).unwrap(), self.psk.clone(), 0)); + } + } + } + + #[allow(unused_variables)] + #[test] + fn establish_session() { + let mut psk: Secret<64> = Secret::default(); + random::fill_bytes_secure(&mut psk.0); + let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob")); + let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice")); + let rc: Box>> = Box::new(ReceiveContext::new()); + let mut data_buf = [0_u8; 4096]; + + //println!("zssp: size of session (bytes): {}", std::mem::size_of::>>()); + + let _ = alice_host.session.lock().insert(Arc::new( + Session::new( + &alice_host, + |data| bob_host.queue.lock().push_front(data.to_vec()), + SessionId::new_random(), + bob_host.local_s.public_key_bytes(), + &[], + &psk, + 1, + 1280, + 1, + true, + ) + .unwrap(), + )); + + let mut ts = 0; + for _ in 0..256 { + for host in [&alice_host, &bob_host] { + let send_to_other = |data: &mut [u8]| { + if std::ptr::eq(host, &alice_host) { + bob_host.queue.lock().push_front(data.to_vec()); + } else { + alice_host.queue.lock().push_front(data.to_vec()); + } + }; + + loop { + if let Some(qi) = host.queue.lock().pop_back() { + let qi_len = qi.len(); + ts += 1; + let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, true, ts); + if r.is_ok() { + let r = r.unwrap(); + match r { + ReceiveResult::Ok => { + println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len); + } + ReceiveResult::OkData(data) => { + println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len()); + } + ReceiveResult::OkNewSession(new_session) => { + println!("zssp: {} => {} ({}): OkNewSession ({})", host.other_name, host.this_name, qi_len, u64::from(new_session.id)); + let mut hs = host.session.lock(); + assert!(hs.is_none()); + let _ = hs.insert(Arc::new(new_session)); + } + ReceiveResult::Ignored => { + println!("zssp: {} => {} ({}): Ignored", host.other_name, host.this_name, qi_len); + } + } + } else { + println!("zssp: {} => {}: error: {}", host.other_name, host.this_name, r.err().unwrap().to_string()); + } + } else { + break; + } + } } } } } -*/ diff --git a/utils/src/gatherarray.rs b/utils/src/gatherarray.rs index 4cf93a230..f4359a678 100644 --- a/utils/src/gatherarray.rs +++ b/utils/src/gatherarray.rs @@ -11,16 +11,16 @@ use crate::arrayvec::ArrayVec; pub struct GatherArray { a: [MaybeUninit; C], have_bits: u64, - have_count: u32, - goal: u32, + have_count: u8, + goal: u8, } impl GatherArray { /// Create a new gather array, which must be initialized prior to use. #[inline(always)] - pub fn new(goal: u32) -> Self { + pub fn new(goal: u8) -> Self { assert!(C <= 64); - assert!(goal <= (C as u32)); + assert!(goal <= (C as u8)); assert_eq!(size_of::<[T; C]>(), size_of::<[MaybeUninit; C]>()); Self { a: unsafe { MaybeUninit::uninit().assume_init() }, @@ -32,10 +32,10 @@ impl GatherArray { /// Add an item to the array if we don't have this index anymore, returning complete array if all parts are here. #[inline(always)] - pub fn add(&mut self, index: u32, value: T) -> Option> { + pub fn add(&mut self, index: u8, value: T) -> Option> { if index < self.goal { let mut have = self.have_bits; - let got = 1u64.wrapping_shl(index); + let got = 1u64.wrapping_shl(index as u32); if (have & got) == 0 { have |= got; self.have_bits = have; @@ -64,7 +64,7 @@ impl Drop for GatherArray { fn drop(&mut self) { let have = self.have_bits; for i in 0..self.goal { - if (have & 1u64.wrapping_shl(i)) != 0 { + if (have & 1u64.wrapping_shl(i as u32)) != 0 { unsafe { self.a.get_unchecked_mut(i as usize).assume_init_drop() }; } } @@ -78,8 +78,8 @@ mod tests { #[test] fn gather_array() { - for goal in 2..64 { - let mut m = GatherArray::::new(goal); + for goal in 2u8..64u8 { + let mut m = GatherArray::::new(goal); for x in 0..(goal - 1) { assert!(m.add(x, x).is_none()); } diff --git a/utils/src/lib.rs b/utils/src/lib.rs index d134532dc..7ea17d81d 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -1,5 +1,6 @@ pub mod arrayvec; pub mod gatherarray; pub mod hex; +pub mod memory; pub mod ringbuffermap; pub mod varint; diff --git a/utils/src/memory.rs b/utils/src/memory.rs new file mode 100644 index 000000000..2404334f6 --- /dev/null +++ b/utils/src/memory.rs @@ -0,0 +1,143 @@ +// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. + +#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))] +#[allow(unused)] +mod fast_int_memory_access { + #[inline(always)] + pub fn u64_from_le_bytes(b: &[u8]) -> u64 { + assert!(b.len() >= 8); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn u32_from_le_bytes(b: &[u8]) -> u32 { + assert!(b.len() >= 4); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn u16_from_le_bytes(b: &[u8]) -> u16 { + assert!(b.len() >= 2); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn i64_from_le_bytes(b: &[u8]) -> i64 { + assert!(b.len() >= 8); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn i32_from_le_bytes(b: &[u8]) -> i32 { + assert!(b.len() >= 4); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn i16_from_le_bytes(b: &[u8]) -> i16 { + assert!(b.len() >= 2); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn u64_from_be_bytes(b: &[u8]) -> u64 { + assert!(b.len() >= 8); + unsafe { *b.as_ptr().cast::() }.swap_bytes() + } + + #[inline(always)] + pub fn u32_from_be_bytes(b: &[u8]) -> u32 { + assert!(b.len() >= 4); + unsafe { *b.as_ptr().cast::() }.swap_bytes() + } + + #[inline(always)] + pub fn u16_from_be_bytes(b: &[u8]) -> u16 { + assert!(b.len() >= 2); + unsafe { *b.as_ptr().cast::() }.swap_bytes() + } + + #[inline(always)] + pub fn i64_from_be_bytes(b: &[u8]) -> i64 { + assert!(b.len() >= 8); + unsafe { *b.as_ptr().cast::() }.swap_bytes() + } + + #[inline(always)] + pub fn i32_from_be_bytes(b: &[u8]) -> i32 { + assert!(b.len() >= 4); + unsafe { *b.as_ptr().cast::() }.swap_bytes() + } + + #[inline(always)] + pub fn i16_from_be_bytes(b: &[u8]) -> i16 { + assert!(b.len() >= 2); + unsafe { *b.as_ptr().cast::() }.swap_bytes() + } +} + +#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))] +#[allow(unused)] +mod fast_int_memory_access { + #[inline(always)] + pub fn u64_from_le_bytes(b: &[u8]) -> u64 { + u64::from_le_bytes(b[..8].try_into().unwrap()) + } + + #[inline(always)] + pub fn u32_from_le_bytes(b: &[u8]) -> u32 { + u32::from_le_bytes(b[..4].try_into().unwrap()) + } + + #[inline(always)] + pub fn u16_from_le_bytes(b: &[u8]) -> u16 { + u16::from_le_bytes(b[..2].try_into().unwrap()) + } + + #[inline(always)] + pub fn i64_from_le_bytes(b: &[u8]) -> i64 { + i64::from_le_bytes(b[..8].try_into().unwrap()) + } + + #[inline(always)] + pub fn i32_from_le_bytes(b: &[u8]) -> i32 { + i32::from_le_bytes(b[..4].try_into().unwrap()) + } + + #[inline(always)] + pub fn i16_from_le_bytes(b: &[u8]) -> i16 { + i16::from_le_bytes(b[..2].try_into().unwrap()) + } + + #[inline(always)] + pub fn u64_from_be_bytes(b: &[u8]) -> u64 { + u64::from_be_bytes(b[..8].try_into().unwrap()) + } + + #[inline(always)] + pub fn u32_from_be_bytes(b: &[u8]) -> u32 { + u32::from_be_bytes(b[..4].try_into().unwrap()) + } + + #[inline(always)] + pub fn u16_from_be_bytes(b: &[u8]) -> u16 { + u16::from_be_bytes(b[..2].try_into().unwrap()) + } + + #[inline(always)] + pub fn i64_from_be_bytes(b: &[u8]) -> i64 { + i64::from_be_bytes(b[..8].try_into().unwrap()) + } + + #[inline(always)] + pub fn i32_from_be_bytes(b: &[u8]) -> i32 { + i32::from_be_bytes(b[..4].try_into().unwrap()) + } + + #[inline(always)] + pub fn i16_from_be_bytes(b: &[u8]) -> i16 { + i16::from_be_bytes(b[..2].try_into().unwrap()) + } +} + +pub use fast_int_memory_access::*; diff --git a/utils/src/ringbuffermap.rs b/utils/src/ringbuffermap.rs index 79508ad10..989bbca47 100644 --- a/utils/src/ringbuffermap.rs +++ b/utils/src/ringbuffermap.rs @@ -106,13 +106,15 @@ impl RingBu #[inline] pub fn new(salt: u32) -> Self { Self { - entries: std::array::from_fn(|_| Entry:: { - key: MaybeUninit::uninit(), - value: MaybeUninit::uninit(), - bucket: -1, - next: -1, - prev: -1, - }), + entries: { + let mut entries: [Entry; C] = unsafe { MaybeUninit::uninit().assume_init() }; + for e in entries.iter_mut() { + e.bucket = -1; + e.next = -1; + e.prev = -1; + } + entries + }, buckets: [-1; B], entry_ptr: 0, salt,