diff --git a/crypto/src/zssp.rs b/crypto/src/zssp.rs index 5a2fbc19e..bbd358bae 100644 --- a/crypto/src/zssp.rs +++ b/crypto/src/zssp.rs @@ -4,7 +4,6 @@ // FIPS compliant Noise_IK with Jedi powers and built-in attack-resistant large payload (fragmentation) support. use std::io::{Read, Write}; -use std::num::NonZeroU64; use std::ops::Deref; use std::sync::atomic::{AtomicU64, Ordering}; @@ -17,6 +16,7 @@ use crate::secret::Secret; use zerotier_utils::gatherarray::GatherArray; use zerotier_utils::memory; use zerotier_utils::ringbuffermap::RingBufferMap; +use zerotier_utils::unlikely_branch; use zerotier_utils::varint; use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard}; @@ -165,14 +165,10 @@ impl From for Error { } } -#[cold] -#[inline(never)] -extern "C" fn unlikely_branch() {} - impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::UnknownLocalSessionId(id) => f.write_str(format!("UnknownLocalSessionId({})", id.0.get()).as_str()), + Self::UnknownLocalSessionId(id) => f.write_str(format!("UnknownLocalSessionId({})", id.0).as_str()), Self::InvalidPacket => f.write_str("InvalidPacket"), Self::InvalidParameter => f.write_str("InvalidParameter"), Self::FailedAuthentication => f.write_str("FailedAuthentication"), @@ -217,28 +213,37 @@ pub enum ReceiveResult<'a, H: Host> { /// 48-bit session ID (most significant 24 bits of u64 are unused) #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[repr(transparent)] -pub struct SessionId(NonZeroU64); +pub struct SessionId(u64); impl SessionId { - pub const MAX_BIT_MASK: u64 = 0xffffffffffff; + pub const NIL: SessionId = SessionId(0xffffffffffff); - #[inline(always)] + #[inline] pub fn new_from_u64(i: u64) -> Option { - debug_assert!(i <= Self::MAX_BIT_MASK); - NonZeroU64::new(i).map(|i| Self(i)) + if i < Self::NIL.0 { + Some(Self(i)) + } else { + None + } } + #[inline] pub fn new_from_reader(r: &mut R) -> std::io::Result> { let mut tmp = [0_u8; 8]; - r.read_exact(&mut tmp[..SESSION_ID_SIZE]) - .map(|_| NonZeroU64::new(u64::from_le_bytes(tmp)).map(|i| Self(i))) + r.read_exact(&mut tmp[..SESSION_ID_SIZE])?; + Ok(Self::new_from_u64(u64::from_le_bytes(tmp))) + } + + #[inline] + pub fn new_random() -> Self { + Self(random::xorshift64_random() % (Self::NIL.0 - 1)) } } impl From for u64 { #[inline(always)] fn from(sid: SessionId) -> Self { - sid.0.get() + sid.0 } } @@ -578,8 +583,7 @@ impl ReceiveContext { let fragment_count = ((packet_type_fragment_info.wrapping_shr(4) + 1) as u8) & 63; let fragment_no = packet_type_fragment_info.wrapping_shr(10) as u8; - if let Some(local_session_id) = - SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & SessionId::MAX_BIT_MASK) + if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64) { if let Some(session) = host.session_lookup(local_session_id) { if check_header_mac(incoming_packet, &session.header_check_cipher) { @@ -632,7 +636,7 @@ impl ReceiveContext { } else { unlikely_branch(); if check_header_mac(incoming_packet, &self.incoming_init_header_check_cipher) { - let pseudoheader = Pseudoheader::make(0, packet_type, counter); + let pseudoheader = Pseudoheader::make(SessionId::NIL.0, packet_type, counter); if fragment_count > 1 { let mut defrag = self.initial_offer_defrag.lock(); let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); @@ -993,7 +997,7 @@ impl ReceiveContext { rp.write_all(bob_e0_keypair.public_key_bytes())?; rp.write_all(&offer_id)?; - rp.write_all(&session.id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?; + rp.write_all(&session.id.0.to_le_bytes()[..SESSION_ID_SIZE])?; varint::write(&mut rp, 0)?; // they don't need our static public; they have it varint::write(&mut rp, 0)?; // no meta-data in counter-offers (could be used in the future) if let Some(bob_e1_public) = bob_e1_public.as_ref() { @@ -1300,7 +1304,7 @@ fn create_initial_offer( p.write_all(alice_e0_keypair.public_key_bytes())?; p.write_all(&id)?; - p.write_all(&alice_session_id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?; + p.write_all(&alice_session_id.0.to_le_bytes()[..SESSION_ID_SIZE])?; varint::write(&mut p, alice_s_public.len() as u64)?; p.write_all(alice_s_public)?; varint::write(&mut p, alice_metadata.len() as u64)?; @@ -1321,7 +1325,7 @@ fn create_initial_offer( PACKET_BUF_SIZE - p.len() }; - let bob_session_id: u64 = bob_session_id.map_or(0_u64, |i| i.into()); + let bob_session_id: u64 = bob_session_id.map_or(SessionId::NIL.0, |i| i.into()); create_packet_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id, counter)?; let pseudoheader = Pseudoheader::make(bob_session_id, PACKET_TYPE_KEY_OFFER, counter.to_u32()); @@ -1684,7 +1688,7 @@ mod tests { local_s_hash, psk, session: Mutex::new(None), - session_id_counter: Mutex::new(random::next_u64_secure().wrapping_shr(16) | 1), + session_id_counter: Mutex::new(1), queue: Mutex::new(LinkedList::new()), key_id: Mutex::new([0; 16]), this_name, @@ -1765,7 +1769,7 @@ mod tests { Session::new( &alice_host, |data| bob_host.queue.lock().push_front(data.to_vec()), - SessionId::new_from_u64(random::xorshift64_random().wrapping_shr(16)).unwrap(), + SessionId::new_random(), bob_host.local_s.public_key_bytes(), &[], &psk,