diff --git a/zssp/src/fragged.rs b/zssp/src/fragged.rs
index f3f5f1eae..47370158b 100644
--- a/zssp/src/fragged.rs
+++ b/zssp/src/fragged.rs
@@ -55,24 +55,11 @@ impl<Fragment, const MAX_FRAGMENTS: usize> Fragged<Fragment, MAX_FRAGMENTS> {
     ///
     /// When a fully assembled packet is returned the internal state is reset and this object can
     /// be reused to assemble another packet.
-    #[inline(always)]
     pub fn assemble(&mut self, counter: u64, fragment: Fragment, fragment_no: u8, fragment_count: u8) -> Option<Assembled<Fragment, MAX_FRAGMENTS>> {
         if fragment_no < fragment_count && (fragment_count as usize) <= MAX_FRAGMENTS {
             // If the counter has changed, reset the structure to receive a new packet.
             if counter != self.counter {
-                if needs_drop::<Fragment>() {
-                    let mut have = self.have;
-                    let mut i = 0;
-                    while have != 0 {
-                        if (have & 1) != 0 {
-                            debug_assert!(i < MAX_FRAGMENTS);
-                            unsafe { self.frags.get_unchecked_mut(i).assume_init_drop() };
-                        }
-                        have = have.wrapping_shr(1);
-                        i += 1;
-                    }
-                }
-                self.have = 0;
+                self.drop_in_place();
                 self.count = fragment_count as u32;
                 self.counter = counter;
             }
@@ -96,11 +83,10 @@
         }
         return None;
     }
-}
 
-impl<Fragment, const MAX_FRAGMENTS: usize> Drop for Fragged<Fragment, MAX_FRAGMENTS> {
+    /// Drops any remaining fragments and resets this object.
     #[inline(always)]
-    fn drop(&mut self) {
+    pub fn drop_in_place(&mut self) {
         if needs_drop::<Fragment>() {
             let mut have = self.have;
             let mut i = 0;
@@ -113,5 +99,15 @@ impl<Fragment, const MAX_FRAGMENTS: usize> Drop for Fragged<Fragment, MAX_FRAGMENTS> {
                 i += 1;
             }
         }
+        self.have = 0;
+        self.count = 0;
+        self.counter = 0;
+    }
+}
+
+impl<Fragment, const MAX_FRAGMENTS: usize> Drop for Fragged<Fragment, MAX_FRAGMENTS> {
+    #[inline(always)]
+    fn drop(&mut self) {
+        self.drop_in_place();
     }
 }
diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs
index f07260799..b8d522b14 100644
--- a/zssp/src/zssp.rs
+++ b/zssp/src/zssp.rs
@@ -14,7 +14,7 @@ use std::collections::hash_map::RandomState;
 use std::collections::HashMap;
 use std::hash::{BuildHasher, Hash, Hasher};
 use std::num::NonZeroU64;
-use std::sync::atomic::{AtomicI64, AtomicU64, AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicI64, AtomicU64, AtomicUsize, Ordering, AtomicBool};
 use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
 
 use zerotier_crypto::aes::{Aes, AesGcm};
@@ -41,7 +41,8 @@ const GCM_CIPHER_POOL_SIZE: usize = 4;
 pub struct Context<Application: ApplicationLayer> {
     default_physical_mtu: AtomicUsize,
     defrag_salt: RandomState,
-    defrag: [Mutex<Fragged<Application::IncomingPacketBuffer, MAX_NOISE_HANDSHAKE_FRAGMENTS>>; MAX_INCOMPLETE_SESSION_QUEUE_SIZE],
+    defrag_has_pending: AtomicBool, // Allowed to be falsely positive
+    defrag: [Mutex<(Fragged<Application::IncomingPacketBuffer, MAX_NOISE_HANDSHAKE_FRAGMENTS>, i64)>; MAX_INCOMPLETE_SESSION_QUEUE_SIZE],
     sessions: RwLock<SessionsById<Application>>,
 }
 
@@ -153,7 +154,8 @@ impl<Application: ApplicationLayer> Context<Application> {
         Self {
             default_physical_mtu: AtomicUsize::new(default_physical_mtu),
             defrag_salt: RandomState::new(),
-            defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
+            defrag_has_pending: AtomicBool::new(false),
+            defrag: std::array::from_fn(|_| Mutex::new((Fragged::new(), i64::MAX))),
             sessions: RwLock::new(SessionsById {
                 active: HashMap::with_capacity(64),
                 incoming: HashMap::with_capacity(64),
             }),
@@ -245,6 +247,24 @@
                     dead_pending.push(*id);
                 }
             }
+
+        }
+        // Only check for expiration if we have a pending packet.
+        // This check is allowed to have false positives for simplicity's sake.
+        if self.defrag_has_pending.swap(false, Ordering::Relaxed) {
+            let mut has_pending = false;
+            for m in &self.defrag {
+                let mut pending = m.lock().unwrap();
+                if pending.1 <= negotiation_timeout_cutoff {
+                    pending.1 = i64::MAX;
+                    pending.0.drop_in_place();
+                } else if pending.0.counter() != 0 {
+                    has_pending = true;
+                }
+            }
+            if has_pending {
+                self.defrag_has_pending.store(true, Ordering::Relaxed);
+            }
         }
 
         if !dead_active.is_empty() || !dead_pending.is_empty() {
@@ -544,15 +564,23 @@ impl<Application: ApplicationLayer> Context<Application> {
             // cannot control which slots their fragments index to. And since Alice's packet header has a randomly
             // generated counter value replaying it in time requires extreme amounts of network control.
             let mut slot0 = self.defrag[idx0].lock().unwrap();
-            if slot0.counter() == hashed_counter {
-                assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
+            if slot0.0.counter() == hashed_counter {
+                assembled = slot0.0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
+                if assembled.is_some() { slot0.1 = i64::MAX }
             } else {
                 let mut slot1 = self.defrag[idx1].lock().unwrap();
-                if slot1.counter() == hashed_counter || slot1.counter() == 0 {
-                    assembled = slot1.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
+                if slot1.0.counter() == hashed_counter || slot1.0.counter() == 0 {
+                    if slot1.0.counter() == 0 {
+                        slot1.1 = current_time;
+                        self.defrag_has_pending.store(true, Ordering::Relaxed);
+                    }
+                    assembled = slot1.0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
+                    if assembled.is_some() { slot1.1 = i64::MAX }
                 } else {
-                    // slot1 is full so kick out whatever is in slot0 to make more room.
-                    assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
+                    // slot0 is either occupied or empty so we overwrite whatever is there to make more room.
+                    slot0.1 = current_time;
+                    self.defrag_has_pending.store(true, Ordering::Relaxed);
+                    assembled = slot0.0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
                 }
             }
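The fragged.rs half of the patch turns the old `Drop` body into a reusable `drop_in_place()` that both releases any buffered fragments and zeroes the slot's bookkeeping, so `service()` can expire a half-assembled handshake without destroying the `Fragged` itself. The sketch below is illustrative only and not part of the patch: it shows the behavior the expiration path relies on, written as if it sat in a `#[cfg(test)]` module next to `Fragged` in zssp/src/fragged.rs (the test name and the `Vec<u8>` fragment type are invented for the example).

```rust
#[test]
fn drop_in_place_resets_a_partially_assembled_packet() {
    // Pretend a two-fragment packet arrives but only fragment 0 ever shows up.
    let mut slot: Fragged<Vec<u8>, 4> = Fragged::new();
    assert!(slot.assemble(42, vec![1, 2, 3], 0, 2).is_none());
    assert_ne!(slot.counter(), 0); // the slot now counts as occupied

    // The sweep in service() can expire the slot without waiting for Drop:
    // the buffered fragment is released and counter() reads 0 again,
    // which is the "empty slot" condition the receive() path checks for.
    slot.drop_in_place();
    assert_eq!(slot.counter(), 0);
}
```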
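On the zssp.rs side, each defrag slot is now paired with an `i64` timestamp (`i64::MAX` meaning "no deadline"), and `defrag_has_pending` is a relaxed hint that lets `service()` skip the sweep entirely when nothing is buffered. A minimal sketch of that flag protocol follows, under the assumption that it is only an optimization hint and the slot contents stay guarded by their per-slot `Mutex`; the `PendingFlag` type and its method names are invented for illustration, not the crate's API.

```rust
use std::sync::atomic::{AtomicBool, Ordering};

struct PendingFlag(AtomicBool);

impl PendingFlag {
    // Called whenever a fragment claims a previously empty slot.
    fn mark_pending(&self) {
        self.0.store(true, Ordering::Relaxed);
    }

    // Periodic sweep: consume the hint, scan the slots, and re-arm the flag
    // only if something is still waiting. A mark_pending() racing with the
    // swap just sets the flag true again, so the worst case is one redundant
    // scan on the next sweep (a harmless false positive).
    fn sweep<F: FnMut() -> bool>(&self, mut scan_slots: F) {
        if self.0.swap(false, Ordering::Relaxed) {
            if scan_slots() {
                self.0.store(true, Ordering::Relaxed);
            }
        }
    }
}
```

Relaxed ordering is enough here because, as in the patch, the flag never protects the slot data itself; the per-slot mutex does, and the flag only decides whether the sweep bothers to take those locks at all.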