From 3cc407cecd884f6d3d51938e163b1f182f59878b Mon Sep 17 00:00:00 2001 From: mamoniot Date: Sun, 25 Dec 2022 11:35:07 -0500 Subject: [PATCH] implemented proper windowing --- zssp/src/counter.rs | 109 ++++++++++++++++++++++++++++++++-------- zssp/src/zssp.rs | 120 ++++++++++++++++++++++---------------------- 2 files changed, 147 insertions(+), 82 deletions(-) diff --git a/zssp/src/counter.rs b/zssp/src/counter.rs index 9aa147670..861fe1f4a 100644 --- a/zssp/src/counter.rs +++ b/zssp/src/counter.rs @@ -1,4 +1,7 @@ -use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Mutex, RwLock, +}; use zerotier_crypto::random; @@ -23,7 +26,7 @@ impl Counter { /// Get the value most recently used to send a packet. #[inline(always)] pub fn previous(&self) -> CounterValue { - CounterValue(self.0.load(Ordering::SeqCst)) + CounterValue(self.0.load(Ordering::SeqCst).wrapping_sub(1)) } /// Get a counter value for the next packet being sent. @@ -56,33 +59,95 @@ impl CounterValue { } /// Incoming packet deduplication and replay protection window. -pub(crate) struct CounterWindow(AtomicU32, [AtomicU32; COUNTER_MAX_DELTA as usize]); +pub(crate) struct CounterWindowAlt(RwLock<(u32, [u32; COUNTER_MAX_DELTA as usize])>); -impl CounterWindow { +impl CounterWindowAlt { #[inline(always)] pub fn new(initial: u32) -> Self { - Self(AtomicU32::new(initial), std::array::from_fn(|_| AtomicU32::new(initial))) + Self(RwLock::new((initial, std::array::from_fn(|_| initial)))) } #[inline(always)] pub fn message_received(&self, received_counter_value: u32) -> bool { - let prev_max = self.0.fetch_max(received_counter_value, Ordering::AcqRel); - if received_counter_value >= prev_max || prev_max.wrapping_sub(received_counter_value) <= COUNTER_MAX_DELTA { - // First, the most common case: counter is higher than the previous maximum OR is no older than MAX_DELTA. - // In that case we accept the packet if it is not a duplicate. Duplicate check is this swap/compare. - self.1[(received_counter_value % COUNTER_MAX_DELTA) as usize].swap(received_counter_value, Ordering::AcqRel) - != received_counter_value - } else if received_counter_value.wrapping_sub(prev_max) <= COUNTER_MAX_DELTA { - // If the received value is lower and wraps when the previous max is subtracted, this means the - // unsigned integer counter has wrapped. In that case we write the new lower-but-actually-higher "max" - // value and then check the deduplication window. - self.0.store(received_counter_value, Ordering::Release); - self.1[(received_counter_value % COUNTER_MAX_DELTA) as usize].swap(received_counter_value, Ordering::AcqRel) - != received_counter_value - } else { - // If the received value is more than MAX_DELTA in the past and wrapping has NOT occurred, this packet - // is too old and is rejected. - false + let idx = (received_counter_value % COUNTER_MAX_DELTA) as usize; + let data = self.0.read().unwrap(); + let max_counter_seen = data.0; + let lower_window = max_counter_seen.wrapping_sub(COUNTER_MAX_DELTA / 2); + let upper_window = max_counter_seen.wrapping_add(COUNTER_MAX_DELTA / 2); + if lower_window < upper_window { + if (lower_window <= received_counter_value) & (received_counter_value < upper_window) { + if data.1[idx] != received_counter_value { + return true; + } + } + } else if (lower_window <= received_counter_value) | (received_counter_value < upper_window) { + if data.1[idx] != received_counter_value { + return true; + } } + return false; + } + + #[inline(always)] + pub fn message_authenticated(&self, received_counter_value: u32) -> bool { + let idx = (received_counter_value % COUNTER_MAX_DELTA) as usize; + let mut data = self.0.write().unwrap(); + let max_counter_seen = data.0; + let lower_window = max_counter_seen.wrapping_sub(COUNTER_MAX_DELTA / 2); + let upper_window = max_counter_seen.wrapping_add(COUNTER_MAX_DELTA / 2); + if lower_window < upper_window { + if (lower_window <= received_counter_value) & (received_counter_value < upper_window) { + if data.1[idx] != received_counter_value { + data.1[idx] = received_counter_value; + data.0 = max_counter_seen.max(received_counter_value); + return true; + } + } + } else if (lower_window <= received_counter_value) | (received_counter_value < upper_window) { + if data.1[idx] != received_counter_value { + data.1[idx] = received_counter_value; + data.0 = (max_counter_seen as i32).max(received_counter_value as i32) as u32; + return true; + } + } + return false; + } +} + +pub(crate) struct CounterWindow(Mutex<(usize, [u64; COUNTER_MAX_DELTA as usize])>); + +impl CounterWindow { + #[inline(always)] + pub fn new(initial: u32) -> Self { + let initial_nonce = (initial as u64).wrapping_shl(32); + Self(Mutex::new((0, std::array::from_fn(|_| initial_nonce)))) + } + + #[inline(always)] + pub fn message_received(&self, received_counter_value: u32, received_fragment_no: u8) -> bool { + let fragment_nonce = (received_counter_value as u64).wrapping_shl(32) | (received_fragment_no as u64); + //everything past this point must be atomic, i.e. these instructions must be run mutually exclusive to completion; + //atomic instructions are only ever atomic within themselves; + //sequentially consistent atomics do not guarantee that the thread is not preempted between individual atomic instructions + let mut data = self.0.lock().unwrap(); + let mut is_in = false; + let mut is_gt_min = false; + for nonce in data.1 { + is_in |= nonce == fragment_nonce; + let udist = nonce.abs_diff(fragment_nonce); + let sdist = (nonce as i64).abs_diff(fragment_nonce as i64); + if udist < sdist { + is_gt_min |= nonce < fragment_nonce; + } else { + is_gt_min |= (nonce as i64) < (fragment_nonce as i64); + } + } + if !is_in & is_gt_min { + let idx = data.0; + data.1[idx] = fragment_nonce; + data.0 = (idx + 1) % (COUNTER_MAX_DELTA as usize); + return true; + } + return false; } } diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 1bd92b357..837c81cf8 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -113,6 +113,7 @@ pub struct Session { pub application_data: Application::Data, send_counter: Counter, // Outgoing packet counter and nonce state + receive_window: CounterWindow, // Receive window for anti-replay and deduplication psk: Secret<64>, // Arbitrary PSK provided by external code noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) @@ -136,7 +137,6 @@ struct SessionKey { secret_fingerprint: [u8; 16], // First 128 bits of a SHA384 computed from the secret creation_time: i64, // Time session key was established creation_counter: CounterValue, // Counter value at which session was established - receive_window: CounterWindow, // Receive window for anti-replay and deduplication lifetime: KeyLifetime, // Key expiration time and counter ratchet_key: Secret<64>, // Ratchet key for deriving the next session key receive_key: Secret, // Receive side AES-GCM key @@ -486,45 +486,50 @@ impl ReceiveContext { { if let Some(session) = app.lookup_session(local_session_id) { if verify_header_check_code(incoming_packet, &session.header_check_cipher) { - let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); - if fragment_count > 1 { - if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { - let mut defrag = session.defrag.lock().unwrap(); - let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); - if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { - drop(defrag); // release lock - return self.receive_complete( - app, - remote_address, - &mut send, - data_buf, - counter, - canonical_header.as_bytes(), - assembled_packet.as_ref(), - packet_type, - Some(session), - mtu, - current_time, - ); + if session.receive_window.message_received(counter, fragment_no) { + let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); + if fragment_count > 1 { + if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { + let mut defrag = session.defrag.lock().unwrap(); + let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); + if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { + drop(defrag); // release lock + return self.receive_complete( + app, + remote_address, + &mut send, + data_buf, + counter, + canonical_header.as_bytes(), + assembled_packet.as_ref(), + packet_type, + Some(session), + mtu, + current_time, + ); + } + } else { + unlikely_branch(); + return Err(Error::InvalidPacket); } } else { - unlikely_branch(); - return Err(Error::InvalidPacket); + return self.receive_complete( + app, + remote_address, + &mut send, + data_buf, + counter, + canonical_header.as_bytes(), + &[incoming_packet_buf], + packet_type, + Some(session), + mtu, + current_time, + ); } } else { - return self.receive_complete( - app, - remote_address, - &mut send, - data_buf, - counter, - canonical_header.as_bytes(), - &[incoming_packet_buf], - packet_type, - Some(session), - mtu, - current_time, - ); + unlikely_branch(); + return Ok(ReceiveResult::Ignored); } } else { unlikely_branch(); @@ -658,36 +663,31 @@ impl ReceiveContext { session_key.return_receive_cipher(c); if aead_authentication_ok { - if session_key.receive_window.message_received(counter) { - // Select this key as the new default if it's newer than the current key. - if p > 0 - && state.session_keys[state.cur_session_key_idx] - .as_ref() - .map_or(true, |old| old.creation_counter < session_key.creation_counter) - { - drop(state); - let mut state = session.state.write().unwrap(); - state.cur_session_key_idx = key_idx; - for i in 0..KEY_HISTORY_SIZE { - if i != key_idx { - if let Some(old_key) = state.session_keys[key_idx].as_ref() { - // Release pooled cipher memory from old keys. - old_key.receive_cipher_pool.lock().unwrap().clear(); - old_key.send_cipher_pool.lock().unwrap().clear(); - } + // Select this key as the new default if it's newer than the current key. + if p > 0 + && state.session_keys[state.cur_session_key_idx] + .as_ref() + .map_or(true, |old| old.creation_counter < session_key.creation_counter) + { + drop(state); + let mut state = session.state.write().unwrap(); + state.cur_session_key_idx = key_idx; + for i in 0..KEY_HISTORY_SIZE { + if i != key_idx { + if let Some(old_key) = state.session_keys[key_idx].as_ref() { + // Release pooled cipher memory from old keys. + old_key.receive_cipher_pool.lock().unwrap().clear(); + old_key.send_cipher_pool.lock().unwrap().clear(); } } } + } - if packet_type == PACKET_TYPE_DATA { - return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); - } else { - unlikely_branch(); - return Ok(ReceiveResult::Ok); - } + if packet_type == PACKET_TYPE_DATA { + return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); } else { unlikely_branch(); - return Ok(ReceiveResult::Ignored); + return Ok(ReceiveResult::Ok); } } }