diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index cc53c8cad..78144f429 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -3,6 +3,7 @@ // ZSSP: ZeroTier Secure Session Protocol // FIPS compliant Noise_IK with Jedi powers and built-in attack-resistant large payload (fragmentation) support. +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Mutex, RwLock}; use zerotier_crypto::aes::{Aes, AesGcm}; @@ -113,6 +114,7 @@ pub struct Session { /// An arbitrary application defined object associated with each session pub application_data: Application::Data, + ratchet_counts: [AtomicU64; 2], // Number of preceding session keys in ratchet header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) receive_windows: [CounterWindow; 2], // Receive window for anti-replay and deduplication state: RwLock, // Mutable parts of state (other than defrag buffers) @@ -144,7 +146,6 @@ struct SessionKey { send_key: Secret, // Send side AES-GCM key receive_cipher_pool: Mutex>>, // Pool of reusable sending ciphers send_cipher_pool: Mutex>>, // Pool of reusable receiving ciphers - ratchet_count: u64, // Number of preceding session keys in ratchet jedi: bool, // True if Kyber1024 was used (both sides enabled) } @@ -246,6 +247,7 @@ impl Session { &bob_s_public_blob_hash, &noise_ss, None, + 1, None, mtu, current_time, @@ -256,6 +258,7 @@ impl Session { return Ok(Self { id: local_session_id, application_data, + ratchet_counts: [AtomicU64::new(1), AtomicU64::new(0)], header_check_cipher, receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], state: RwLock::new(SessionMutableState { @@ -293,12 +296,14 @@ impl Session { debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); let state = self.state.read().unwrap(); if let Some(remote_session_id) = state.remote_session_id { - if let Some(session_key) = state.session_keys[state.cur_session_key_id as usize].as_ref() { + let key_id = state.cur_session_key_id; + if let Some(session_key) = 
state.session_keys[key_id as usize].as_ref() { // Total size of the armored packet we are going to send (may end up being fragmented) let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; - + //key ratchet count to be used for salting + let ratchet_count = self.ratchet_counts[key_id as usize].load(Ordering::Relaxed); // This outgoing packet's nonce counter value. - let counter = state.send_counters[state.cur_session_key_id as usize].next(); + let counter = state.send_counters[key_id as usize].next(); //////////////////////////////////////////////////////////////// // packet encoding for post-noise transport @@ -312,7 +317,7 @@ impl Session { PACKET_TYPE_DATA, remote_session_id.into(), counter, - state.cur_session_key_id + key_id )?; // Get an initialized AES-GCM cipher and re-initialize with a 96-bit IV built from remote session ID, @@ -331,7 +336,7 @@ impl Session { let fragment_size = fragment_data_size + HEADER_SIZE; c.crypt(&data[..fragment_data_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]); data = &data[fragment_data_size..]; - set_header_check_code(mtu_sized_buffer, &self.header_check_cipher); + set_header_check_code(mtu_sized_buffer, ratchet_count, &self.header_check_cipher); send(&mut mtu_sized_buffer[..fragment_size]); debug_assert!(header[15].wrapping_shr(2) < 63); @@ -352,7 +357,7 @@ impl Session { c.crypt(data, &mut mtu_sized_buffer[HEADER_SIZE..payload_end]); let gcm_tag = c.finish_encrypt(); mtu_sized_buffer[payload_end..last_fragment_size].copy_from_slice(&gcm_tag); - set_header_check_code(mtu_sized_buffer, &self.header_check_cipher); + set_header_check_code(mtu_sized_buffer, ratchet_count, &self.header_check_cipher); send(&mut mtu_sized_buffer[..last_fragment_size]); // Check reusable AES-GCM instance back into pool. @@ -380,8 +385,9 @@ impl Session { /// and whether Kyber1024 was used. None is returned if the session isn't established. 
pub fn status(&self) -> Option<([u8; 16], i64, u64, bool)> { let state = self.state.read().unwrap(); - if let Some(key) = state.session_keys[state.cur_session_key_id as usize].as_ref() { - Some((key.secret_fingerprint, key.creation_time, key.ratchet_count, key.jedi)) + let key_id = state.cur_session_key_id; + if let Some(key) = state.session_keys[key_id as usize].as_ref() { + Some((key.secret_fingerprint, key.creation_time, self.ratchet_counts[key_id as usize].load(Ordering::Relaxed), key.jedi)) } else { None } @@ -419,6 +425,7 @@ impl Session { //mark the previous key as no longer being supported because it is about to be overwritten self.receive_windows[(!current_key_id) as usize].invalidate(); let mut offer = None; + //TODO: what happens here if the session is in a limbo state due to dropped packets? if send_ephemeral_offer( &mut send, state.send_counters[current_key_id as usize].next(), @@ -432,6 +439,7 @@ impl Session { &self.remote_s_public_blob_hash, &self.noise_ss, state.session_keys[current_key_id as usize].as_ref(), + self.ratchet_counts[current_key_id as usize].load(Ordering::Relaxed), if state.remote_session_id.is_some() { Some(&self.header_check_cipher) } else { @@ -497,7 +505,10 @@ impl ReceiveContext { if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64) { if let Some(session) = app.lookup_session(local_session_id) { - if verify_header_check_code(incoming_packet, &session.header_check_cipher) { + //this is the only time ratchet_counts is ever accessed outside of a lock + //as such this read can be wrong, but that is incredibly unlikely since we are tracking the last two ratchet counts, and if it's wrong it just means we drop a packet that would have been dropped anyways for being too old or too new + let ratchet_count = session.ratchet_counts[key_id as usize].load(Ordering::SeqCst); + if verify_header_check_code(incoming_packet, ratchet_count, 
&session.header_check_cipher) { if session.receive_windows[key_id as usize].message_received(counter) { let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); if fragment_count > 1 { @@ -556,8 +567,8 @@ impl ReceiveContext { } } else { unlikely_branch(); // we want data receive to be the priority branch, this is only occasionally used - - if verify_header_check_code(incoming_packet, &self.incoming_init_header_check_cipher) { + //salt with a known value so new sessions can be established + if verify_header_check_code(incoming_packet, 1u64, &self.incoming_init_header_check_cipher) { let canonical_header = CanonicalHeader::make(SessionId::NIL, packet_type, counter); if fragment_count > 1 { let mut defrag = self.initial_offer_defrag.lock().unwrap(); @@ -839,7 +850,7 @@ impl ReceiveContext { // Perform checks and match ratchet key if there's an existing session, or gate (via host) and // then create new sessions. - let (new_session, reply_counter, new_key_id, ratchet_key, last_ratchet_count) = if let Some(session) = session.as_ref() { + let (new_session, reply_counter, new_key_id, ratchet_key) = if let Some(session) = session.as_ref() { // Existing session identity must match the one in this offer. if !secure_eq(&session.remote_s_public_blob_hash, &SHA384::hash(&alice_s_public_blob)) { return Err(Error::FailedAuthentication); @@ -848,7 +859,6 @@ impl ReceiveContext { // Match ratchet key fingerprint and fail if no match, which likely indicates an old offer packet. 
let alice_ratchet_key_fingerprint = alice_ratchet_key_fingerprint.unwrap(); let mut ratchet_key = None; - let mut last_ratchet_count = 0; let state = session.state.read().unwrap(); //key_id here is the key id of the key being rekeyed and replaced //it must be equal to the current session key, and not the previous session key @@ -858,14 +868,13 @@ impl ReceiveContext { if let Some(k) = state.session_keys[key_id as usize].as_ref() { if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { ratchet_key = Some(k.ratchet_key.clone()); - last_ratchet_count = k.ratchet_count; } } if ratchet_key.is_none() { return Ok(ReceiveResult::Ignored); // old packet? } - (None, state.send_counters[state.cur_session_key_id as usize].next(), !key_id, ratchet_key, last_ratchet_count) + (None, state.send_counters[state.cur_session_key_id as usize].next(), !key_id, ratchet_key) } else { if key_id != false { return Ok(ReceiveResult::Ignored); @@ -882,6 +891,7 @@ impl ReceiveContext { Some(Session:: { id: new_session_id, application_data: associated_object, + ratchet_counts: [AtomicU64::new(1), AtomicU64::new(0)], header_check_cipher, receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], state: RwLock::new(SessionMutableState { @@ -900,8 +910,7 @@ impl ReceiveContext { }), reply_counter, false, - None, - 0, + None ) } else { return Err(Error::NewSessionRejected); @@ -1037,32 +1046,34 @@ impl ReceiveContext { session_key, Role::Bob, current_time, - last_ratchet_count + 1, hybrid_kk.is_some(), ); - //TODO: check for correct orderings + let ratchet_count; let mut state = session.state.write().unwrap(); let _ = state.session_keys[new_key_id as usize].replace(session_key); if existing_session.is_some() { + debug_assert!(new_key_id != key_id); // receive_windows only has race conditions with the counter of the remote party. 
It is theoretically possible that the local host receives counters under new_key_id while the receive_window is still in the process of resetting, but this is very unlikely. If it does happen, two things could happen: // 1) The received counter is less than what is currently stored in the window, so a valid packet is rejected // 2) The received counter is greater than what is currently stored in the window, so a valid packet is accepted *but* its counter is deleted from the window so it can be replayed - // 1 is completely acceptable behavior; 2 is unacceptable, but extremely extremely unlikely. Since it is utterly impractical for an adversary to trigger 2 intentionally, and preventing 2 is expensive, we do not currently plan to prevent it. - // if receive_window is ever reimplemented, double check it maintains the above properties. + // To prevent these race conditions, we only update the ratchet_count for salting the check code after the window has reset. So if a counter passes the initial check code, it means either that the thread sees that the ratchet count has been updated, and therefore also sees that receive_window has been reset (due to memory orderings), or that a rare accidental check forge has occurred. 
session.receive_windows[new_key_id as usize].reset_for_initial_offer(); + ratchet_count = session.ratchet_counts[new_key_id as usize].fetch_add(2, Ordering::SeqCst) + 1; let _ = state.remote_session_id.replace(alice_session_id); - // if this wasn't done inside a lock, a theoretical race condition exists where a thread uses the new key id before the counter is reset, or worse: a thread has held onto the previous key_id == new_key_id, and attempts to use the reset counter + // if the following wasn't done inside a lock, a theoretical race condition exists where a thread uses the new key id before the counter is reset, or worse: a thread has held onto the previous key_id == new_key_id, and attempts to use the reset counter // for this reason do not access send_counters without holding the read lock state.cur_session_key_id = new_key_id; state.send_counters[new_key_id as usize].reset_for_initial_offer(); + } else { + ratchet_count = 1; } drop(state); // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. - send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher); + send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, ratchet_count, &session.header_check_cipher); if let Some(new_session) = new_session { return Ok(ReceiveResult::OkNewSession(new_session)); @@ -1176,21 +1187,20 @@ impl ReceiveContext { session_key, Role::Alice, current_time, - last_ratchet_count + 1, hybrid_kk.is_some(), ); let new_key_id = offer.key_id; - let is_new_session = offer.ratchet_count == 0; drop(state); //TODO: check for correct orderings let mut state = session.state.write().unwrap(); let _ = state.remote_session_id.replace(bob_session_id); let _ = state.session_keys[new_key_id as usize].replace(session_key); - if !is_new_session { + if last_ratchet_count > 0 { //when an brand new key offer is sent, it is sent using the new_key_id==false counter, we cannot reset it in that case. 
//NOTE: the following code should be properly threadsafe, see the large comment above at the end of KEY_OFFER decoding for more info session.receive_windows[new_key_id as usize].reset_for_initial_offer(); + let _ = session.ratchet_counts[new_key_id as usize].fetch_add(2, Ordering::SeqCst).wrapping_add(2); state.cur_session_key_id = new_key_id; state.send_counters[new_key_id as usize].reset_for_initial_offer(); } @@ -1229,6 +1239,7 @@ fn send_ephemeral_offer( bob_s_public_blob_hash: &[u8], noise_ss: &Secret<48>, current_key: Option<&SessionKey>, + ratchet_count: u64, header_check_cipher: Option<&Aes>, // None to use one based on the recipient's public key for initial contact mtu: usize, current_time: i64, @@ -1248,10 +1259,10 @@ fn send_ephemeral_offer( }; // Get ratchet key for current key if one exists. - let (ratchet_key, ratchet_count) = if let Some(current_key) = current_key { - (Some(current_key.ratchet_key.clone()), current_key.ratchet_count) + let ratchet_key = if let Some(current_key) = current_key { + Some(current_key.ratchet_key.clone()) } else { - (None, 0) + None }; // Random ephemeral offer ID @@ -1348,12 +1359,13 @@ fn send_ephemeral_offer( let packet_end = idx; if let Some(header_check_cipher) = header_check_cipher { - send_with_fragmentation(send, &mut packet_buf[..packet_end], mtu, header_check_cipher); + send_with_fragmentation(send, &mut packet_buf[..packet_end], mtu, ratchet_count, header_check_cipher); } else { send_with_fragmentation( send, &mut packet_buf[..packet_end], mtu, + ratchet_count, &Aes::new(kbkdf512(&bob_s_public_blob_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::()), ); } @@ -1418,6 +1430,7 @@ fn send_with_fragmentation( send: &mut SendFunction, packet: &mut [u8], mtu: usize, + ratchet_count: u64, header_check_cipher: &Aes, ) { let packet_len = packet.len(); @@ -1426,7 +1439,7 @@ fn send_with_fragmentation( let mut header: [u8; 16] = packet[..HEADER_SIZE].try_into().unwrap(); loop { let fragment = &mut 
packet[fragment_start..fragment_end]; - set_header_check_code(fragment, header_check_cipher); + set_header_check_code(fragment, ratchet_count, header_check_cipher); send(fragment); if fragment_end < packet_len { debug_assert!(header[15].wrapping_shr(2) < 63); @@ -1442,19 +1455,28 @@ fn send_with_fragmentation( } /// Set 32-bit header check code, used to make fragmentation mechanism robust. -fn set_header_check_code(packet: &mut [u8], header_check_cipher: &Aes) { +fn set_header_check_code(packet: &mut [u8], ratchet_count: u64, header_check_cipher: &Aes) { debug_assert!(packet.len() >= MIN_PACKET_SIZE); - let mut check_code = 0u128.to_ne_bytes(); - header_check_cipher.encrypt_block(&packet[4..20], &mut check_code); - packet[..4].copy_from_slice(&check_code[..4]); + + let mut header_mac = 0u128.to_le_bytes(); + memory::store_raw((ratchet_count as u16).to_le_bytes(), &mut header_mac[0..2]); + header_mac[2..16].copy_from_slice(&packet[4..18]); + header_check_cipher.encrypt_block_in_place(&mut header_mac); + packet[..4].copy_from_slice(&header_mac[..4]); } /// Verify 32-bit header check code. 
-fn verify_header_check_code(packet: &[u8], header_check_cipher: &Aes) -> bool { +/// This is not nearly enough entropy to be cryptographically secure, it only is meant for making DOS attacks very hard +fn verify_header_check_code(packet: &[u8], ratchet_count: u64, header_check_cipher: &Aes) -> bool { debug_assert!(packet.len() >= MIN_PACKET_SIZE); - let mut header_mac = 0u128.to_ne_bytes(); - header_check_cipher.encrypt_block(&packet[4..20], &mut header_mac); - memory::load_raw::(&packet[..4]) == memory::load_raw::(&header_mac) + //2 bytes is the ratchet key + //12 bytes is the header we want to verify + //2 bytes is random salt from the encrypted message + let mut header_mac = 0u128.to_le_bytes(); + memory::store_raw((ratchet_count as u16).to_le_bytes(), &mut header_mac[0..2]); + header_mac[2..16].copy_from_slice(&packet[4..18]); + header_check_cipher.encrypt_block_in_place(&mut header_mac); + memory::load_raw::(&packet[..4]) == memory::load_raw::(&header_mac[..4]) } /// Parse KEY_OFFER and KEY_COUNTER_OFFER starting after the unencrypted public key part. @@ -1511,7 +1533,6 @@ impl SessionKey { key: Secret<64>, role: Role, current_time: i64, - ratchet_count: u64, jedi: bool, ) -> Self { let a2b: Secret = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n_clone(); @@ -1529,7 +1550,6 @@ impl SessionKey { send_key, receive_cipher_pool: Mutex::new(Vec::with_capacity(2)), send_cipher_pool: Mutex::new(Vec::with_capacity(2)), - ratchet_count, jedi, } }