From e3c200556408449a78bf1976f6ece71580b27146 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Wed, 21 Dec 2022 14:45:29 -0500 Subject: [PATCH] Incoming packet dedup and anti-replay in ZSSP. --- zssp/src/constants.rs | 3 ++ zssp/src/counter.rs | 36 +++++++++++++++++++- zssp/src/tests.rs | 14 +++++++- zssp/src/zssp.rs | 77 +++++++++++++++++++++++++++++-------------- 4 files changed, 103 insertions(+), 27 deletions(-) diff --git a/zssp/src/constants.rs b/zssp/src/constants.rs index 16f3c039b..523b6692e 100644 --- a/zssp/src/constants.rs +++ b/zssp/src/constants.rs @@ -77,6 +77,9 @@ pub(crate) const SESSION_ID_SIZE: usize = 6; /// Number of session keys to hold at a given time (current, previous, next). pub(crate) const KEY_HISTORY_SIZE: usize = 3; +/// Maximum difference between out-of-order incoming packet counters, and size of deduplication buffer. +pub(crate) const COUNTER_MAX_DELTA: u32 = 16; + // Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use pub(crate) const PACKET_TYPE_DATA: u8 = 0; pub(crate) const PACKET_TYPE_NOP: u8 = 1; diff --git a/zssp/src/counter.rs b/zssp/src/counter.rs index 3bee54dd8..9aa147670 100644 --- a/zssp/src/counter.rs +++ b/zssp/src/counter.rs @@ -1,7 +1,9 @@ -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use zerotier_crypto::random; +use crate::constants::COUNTER_MAX_DELTA; + /// Outgoing packet counter with strictly ordered atomic semantics. /// /// The counter used in packets is actually 32 bits, but using a 64-bit integer internally @@ -52,3 +54,35 @@ impl CounterValue { Self(self.0.checked_add(uses).unwrap()) } } + +/// Incoming packet deduplication and replay protection window. +pub(crate) struct CounterWindow(AtomicU32, [AtomicU32; COUNTER_MAX_DELTA as usize]); + +impl CounterWindow { + #[inline(always)] + pub fn new(initial: u32) -> Self { + Self(AtomicU32::new(initial), std::array::from_fn(|_| AtomicU32::new(initial))) + } + + #[inline(always)] + pub fn message_received(&self, received_counter_value: u32) -> bool { + let prev_max = self.0.fetch_max(received_counter_value, Ordering::AcqRel); + if received_counter_value >= prev_max || prev_max.wrapping_sub(received_counter_value) <= COUNTER_MAX_DELTA { + // First, the most common case: counter is higher than the previous maximum OR is no older than MAX_DELTA. + // In that case we accept the packet if it is not a duplicate. Duplicate check is this swap/compare. + self.1[(received_counter_value % COUNTER_MAX_DELTA) as usize].swap(received_counter_value, Ordering::AcqRel) + != received_counter_value + } else if received_counter_value.wrapping_sub(prev_max) <= COUNTER_MAX_DELTA { + // If the received value is lower and wraps when the previous max is subtracted, this means the + // unsigned integer counter has wrapped. In that case we write the new lower-but-actually-higher "max" + // value and then check the deduplication window. + self.0.store(received_counter_value, Ordering::Release); + self.1[(received_counter_value % COUNTER_MAX_DELTA) as usize].swap(received_counter_value, Ordering::AcqRel) + != received_counter_value + } else { + // If the received value is more than MAX_DELTA in the past and wrapping has NOT occurred, this packet + // is too old and is rejected. + false + } + } +} diff --git a/zssp/src/tests.rs b/zssp/src/tests.rs index ab3dba199..45de756df 100644 --- a/zssp/src/tests.rs +++ b/zssp/src/tests.rs @@ -1,3 +1,4 @@ +#[allow(unused_imports)] #[cfg(test)] mod tests { use std::collections::LinkedList; @@ -8,7 +9,7 @@ mod tests { use zerotier_crypto::secret::Secret; use zerotier_utils::hex; - #[allow(unused_imports)] + use crate::counter::CounterWindow; use crate::*; use constants::*; @@ -216,4 +217,15 @@ mod tests { } } } + + #[test] + fn counter_window() { + let w = CounterWindow::new(0xffffffff); + assert!(!w.message_received(0xffffffff)); + assert!(w.message_received(0)); + assert!(w.message_received(1)); + assert!(w.message_received(COUNTER_MAX_DELTA * 2)); + assert!(!w.message_received(0xffffffff)); + assert!(w.message_received(0xfffffffe)); + } } diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 5e7a797f9..5e37da8a9 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -20,7 +20,7 @@ use zerotier_utils::varint; use crate::applicationlayer::ApplicationLayer; use crate::constants::*; -use crate::counter::{Counter, CounterValue}; +use crate::counter::{Counter, CounterValue, CounterWindow}; use crate::sessionid::SessionId; pub enum Error { @@ -136,6 +136,7 @@ struct SessionKey { secret_fingerprint: [u8; 16], // First 128 bits of a SHA384 computed from the secret creation_time: i64, // Time session key was established creation_counter: CounterValue, // Counter value at which session was established + receive_window: CounterWindow, // Receive window for anti-replay and deduplication lifetime: KeyLifetime, // Key expiration time and counter ratchet_key: Secret<64>, // Ratchet key for deriving the next session key receive_key: Secret, // Receive side AES-GCM key @@ -497,6 +498,7 @@ impl ReceiveContext { remote_address, &mut send, data_buf, + counter, canonical_header.as_bytes(), assembled_packet.as_ref(), packet_type, @@ -515,6 +517,7 @@ impl ReceiveContext { remote_address, &mut send, data_buf, + counter, canonical_header.as_bytes(), &[incoming_packet_buf], packet_type, @@ -546,6 +549,7 @@ impl ReceiveContext { remote_address, &mut send, data_buf, + counter, canonical_header.as_bytes(), assembled_packet.as_ref(), packet_type, @@ -560,6 +564,7 @@ impl ReceiveContext { remote_address, &mut send, data_buf, + counter, canonical_header.as_bytes(), &[incoming_packet_buf], packet_type, @@ -587,6 +592,7 @@ impl ReceiveContext { remote_address: &Application::RemoteAddress, send: &mut SendFunction, data_buf: &'a mut [u8], + counter: u32, canonical_header_bytes: &[u8; 12], fragments: &[Application::IncomingPacketBuffer], packet_type: u8, @@ -652,31 +658,36 @@ impl ReceiveContext { session_key.return_receive_cipher(c); if aead_authentication_ok { - // Select this key as the new default if it's newer than the current key. - if p > 0 - && state.session_keys[state.cur_session_key_idx] - .as_ref() - .map_or(true, |old| old.creation_counter < session_key.creation_counter) - { - drop(state); - let mut state = session.state.write().unwrap(); - state.cur_session_key_idx = key_idx; - for i in 0..KEY_HISTORY_SIZE { - if i != key_idx { - if let Some(old_key) = state.session_keys[key_idx].as_ref() { - // Release pooled cipher memory from old keys. - old_key.receive_cipher_pool.lock().unwrap().clear(); - old_key.send_cipher_pool.lock().unwrap().clear(); + if session_key.receive_window.message_received(counter) { + // Select this key as the new default if it's newer than the current key. + if p > 0 + && state.session_keys[state.cur_session_key_idx] + .as_ref() + .map_or(true, |old| old.creation_counter < session_key.creation_counter) + { + drop(state); + let mut state = session.state.write().unwrap(); + state.cur_session_key_idx = key_idx; + for i in 0..KEY_HISTORY_SIZE { + if i != key_idx { + if let Some(old_key) = state.session_keys[key_idx].as_ref() { + // Release pooled cipher memory from old keys. + old_key.receive_cipher_pool.lock().unwrap().clear(); + old_key.send_cipher_pool.lock().unwrap().clear(); + } } } } - } - if packet_type == PACKET_TYPE_DATA { - return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); + if packet_type == PACKET_TYPE_DATA { + return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); + } else { + unlikely_branch(); + return Ok(ReceiveResult::Ok); + } } else { unlikely_branch(); - return Ok(ReceiveResult::Ok); + return Ok(ReceiveResult::Ignored); } } } @@ -942,6 +953,7 @@ impl ReceiveContext { // packet encoding for noise key counter offer // <- e, ee, se //////////////////////////////////////////////////////////////// + let mut reply_buf = [0_u8; KEX_BUF_LEN]; let reply_counter = session.send_counter.next(); let mut idx = HEADER_SIZE; @@ -1018,6 +1030,7 @@ impl ReceiveContext { Role::Bob, current_time, reply_counter, + counter, last_ratchet_count + 1, hybrid_kk.is_some(), ); @@ -1041,6 +1054,7 @@ impl ReceiveContext { PACKET_TYPE_KEY_COUNTER_OFFER => { // bob (remote) -> alice (local) + //////////////////////////////////////////////////////////////// // packet decoding for noise key counter offer // <- e, ee, se @@ -1134,11 +1148,12 @@ impl ReceiveContext { // Alice has now completed and validated the full hybrid exchange. - let counter = session.send_counter.next(); + let reply_counter = session.send_counter.next(); let session_key = SessionKey::new( session_key, Role::Alice, current_time, + reply_counter, counter, last_ratchet_count + 1, hybrid_kk.is_some(), @@ -1147,6 +1162,7 @@ impl ReceiveContext { //////////////////////////////////////////////////////////////// // packet encoding for post-noise session start ack //////////////////////////////////////////////////////////////// + let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; create_packet_header( &mut reply_buf, @@ -1154,11 +1170,13 @@ impl ReceiveContext { mtu, PACKET_TYPE_NOP, bob_session_id.into(), - counter, + reply_counter, )?; - let mut c = session_key.get_send_cipher(counter)?; - c.reset_init_gcm(CanonicalHeader::make(bob_session_id.into(), PACKET_TYPE_NOP, counter.to_u32()).as_bytes()); + let mut c = session_key.get_send_cipher(reply_counter)?; + c.reset_init_gcm( + CanonicalHeader::make(bob_session_id.into(), PACKET_TYPE_NOP, reply_counter.to_u32()).as_bytes(), + ); let gcm_tag = c.finish_encrypt(); safe_write_all(&mut reply_buf, HEADER_SIZE, &gcm_tag)?; session_key.return_send_cipher(c); @@ -1474,7 +1492,15 @@ fn parse_dec_key_offer_after_header( impl SessionKey { /// Create a new symmetric shared session key and set its key expiration times, etc. - fn new(key: Secret<64>, role: Role, current_time: i64, current_counter: CounterValue, ratchet_count: u64, jedi: bool) -> Self { + fn new( + key: Secret<64>, + role: Role, + current_time: i64, + current_counter: CounterValue, + remote_counter: u32, + ratchet_count: u64, + jedi: bool, + ) -> Self { let a2b: Secret = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n_clone(); let b2a: Secret = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n_clone(); let (receive_key, send_key) = match role { @@ -1485,6 +1511,7 @@ impl SessionKey { secret_fingerprint: public_fingerprint_of_secret(key.as_bytes())[..16].try_into().unwrap(), creation_time: current_time, creation_counter: current_counter, + receive_window: CounterWindow::new(remote_counter), lifetime: KeyLifetime::new(current_counter, current_time), ratchet_key: kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_RATCHETING), receive_key,