diff --git a/core-crypto/src/zssp.rs b/core-crypto/src/zssp.rs index 48c2b3ab2..f850c7f16 100644 --- a/core-crypto/src/zssp.rs +++ b/core-crypto/src/zssp.rs @@ -3,7 +3,6 @@ // ZSSP: ZeroTier Secure Session Protocol // FIPS compliant Noise_IK with Jedi powers and built-in attack-resistant large payload (fragmentation) support. -use std::collections::LinkedList; use std::io::{Read, Write}; use std::num::NonZeroU64; use std::ops::Deref; @@ -45,19 +44,24 @@ pub const SERVICE_INTERVAL: u64 = 10000; const JEDI: bool = true; /// Start attempting to rekey after a key has been used to send packets this many times. +/// +/// This is 1/4 the NIST recommended maximum and 1/8 the absolute limit where u32 wraps. const REKEY_AFTER_USES: u64 = 536870912; /// Maximum random jitter to add to rekey-after usage count. const REKEY_AFTER_USES_MAX_JITTER: u32 = 1048576; /// Hard expiration after this many uses. +/// +/// Use of the key beyond this point is prohibited. This is the point where u32 wraps minus +/// a little bit of margin. We should never get here under ordinary circumstances. const EXPIRE_AFTER_USES: u64 = (u32::MAX - 1024) as u64; /// Start attempting to rekey after a key has been in use for this many milliseconds. const REKEY_AFTER_TIME_MS: i64 = 1000 * 60 * 60; // 1 hour /// Maximum random jitter to add to rekey-after time. -const REKEY_AFTER_TIME_MS_MAX_JITTER: u32 = 1000 * 60 * 5; +const REKEY_AFTER_TIME_MS_MAX_JITTER: u32 = 1000 * 60 * 10; /// Rate limit for sending new offers to attempt to re-key. const OFFER_RATE_LIMIT_MS: i64 = 2000; @@ -96,7 +100,7 @@ const HMAC_SIZE: usize = 48; const SESSION_ID_SIZE: usize = 6; /// Maximum number of present and future keys to hold at any given time. -const KEY_HISTORY_SIZE_MAX: usize = 3; +const KEY_HISTORY_SIZE: usize = 3; // Key usage labels for sub-key derivation using kbkdf (HMAC). const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; @@ -294,7 +298,8 @@ pub struct Session { struct SessionMutableState { remote_session_id: Option, - keys: LinkedList, + keys: [Option; KEY_HISTORY_SIZE], + key_ptr: usize, offer: Option>, } @@ -358,7 +363,8 @@ impl Session { header_check_cipher, state: RwLock::new(SessionMutableState { remote_session_id: None, - keys: LinkedList::new(), + keys: [None, None, None], + key_ptr: 0, offer: Some(offer), }), remote_s_public_hash, @@ -379,7 +385,7 @@ impl Session { debug_assert!(mtu_buffer.len() >= MIN_MTU); let state = self.state.read(); if let Some(remote_session_id) = state.remote_session_id { - if let Some(key) = state.keys.front() { + if let Some(key) = state.keys[state.key_ptr].as_ref() { let mut packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; let counter = self.send_counter.next(); @@ -433,17 +439,17 @@ impl Session { /// Check whether this session is established. pub fn established(&self) -> bool { let state = self.state.read(); - state.remote_session_id.is_some() && !state.keys.is_empty() + state.remote_session_id.is_some() && state.keys[state.key_ptr].is_some() } /// Get information about this session's security state. /// - /// This returns a tuple of: the time at which the current key was established, the length of its ratchet chain, + /// This returns a tuple of: the key fingerprint, the time it was established, the length of its ratchet chain, /// and whether Kyber1024 was used. None is returned if the session isn't established. - pub fn security_info(&self) -> Option<(i64, u64, bool)> { + pub fn security_info(&self) -> Option<([u8; 16], i64, u64, bool)> { let state = self.state.read(); - if let Some(key) = state.keys.front() { - Some((key.establish_time, key.ratchet_count, key.jedi)) + if let Some(key) = state.keys[state.key_ptr].as_ref() { + Some((key.fingerprint, key.establish_time, key.ratchet_count, key.jedi)) } else { None } @@ -454,9 +460,13 @@ impl Session { /// * `offer_metadata' - Any meta-data to include with initial key offers sent. /// * `mtu` - Physical MTU for sent packets /// * `current_time` - Current monotonic time in milliseconds - pub fn service(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64) { + /// * `force_rekey` - Re-key the session now regardless of key aging (still subject to rate limiting) + pub fn service(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force_rekey: bool) { let state = self.state.upgradable_read(); - if state.keys.front().map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time)) + if (force_rekey + || state.keys[state.key_ptr] + .as_ref() + .map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time))) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS) { if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) { @@ -471,7 +481,7 @@ impl Session { &remote_s_public_p384, &self.remote_s_public_hash, &self.ss, - state.keys.front(), + state.keys[state.key_ptr].as_ref(), if state.remote_session_id.is_some() { &self.header_check_cipher } else { @@ -612,61 +622,56 @@ impl ReceiveContext { if packet_type <= PACKET_TYPE_NOP { if let Some(session) = session { let state = session.state.read(); - let key_count = state.keys.len(); - for (key_index, key) in state.keys.iter().enumerate() { - let tail = fragments.last().unwrap().as_ref(); - if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { - unlikely_branch(); - return Err(Error::InvalidPacket); - } + for p in 0..KEY_HISTORY_SIZE { + let key_ptr = (state.key_ptr + p) % KEY_HISTORY_SIZE; + if let Some(key) = state.keys[key_ptr].as_ref() { + let tail = fragments.last().unwrap().as_ref(); + if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { + unlikely_branch(); + return Err(Error::InvalidPacket); + } - let mut c = key.get_receive_cipher(); - c.init(pseudoheader); + let mut c = key.get_receive_cipher(); + c.init(pseudoheader); - let mut data_len = 0; + let mut data_len = 0; + + for f in fragments[..(fragments.len() - 1)].iter() { + let f = f.as_ref(); + debug_assert!(f.len() >= HEADER_SIZE); + let current_frag_data_start = data_len; + data_len += f.len() - HEADER_SIZE; + if data_len > data_buf.len() { + unlikely_branch(); + key.return_receive_cipher(c); + return Err(Error::DataBufferTooSmall); + } + c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); + } - for f in fragments[..(fragments.len() - 1)].iter() { - let f = f.as_ref(); - debug_assert!(f.len() >= HEADER_SIZE); let current_frag_data_start = data_len; - data_len += f.len() - HEADER_SIZE; + data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); if data_len > data_buf.len() { unlikely_branch(); key.return_receive_cipher(c); return Err(Error::DataBufferTooSmall); } - c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); - } + c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]); - let current_frag_data_start = data_len; - data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); - if data_len > data_buf.len() { - unlikely_branch(); + let ok = c.finish_decrypt(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]); key.return_receive_cipher(c); - return Err(Error::DataBufferTooSmall); - } - c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]); - - let ok = c.finish_decrypt(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]); - key.return_receive_cipher(c); - if ok { - // Drop obsolete keys if we had to iterate past the first key to get here. - if key_index > 0 { - unlikely_branch(); - drop(state); - let mut state = session.state.write(); - if state.keys.len() == key_count { - for _ in 0..key_index { - let _ = state.keys.pop_front(); - } + if ok { + // Select this key as the new default if it's newer than the current key. + if p > 0 && state.keys[state.key_ptr].as_ref().map_or(true, |old| old.establish_counter < key.establish_counter) { + drop(state); + session.state.write().key_ptr = key_ptr; + } + if packet_type == PACKET_TYPE_DATA { + return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); + } else { + unlikely_branch(); + return Ok(ReceiveResult::Ok); } - } - - if packet_type == PACKET_TYPE_DATA { - return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); - } else { - unlikely_branch(); - return Ok(ReceiveResult::Ok); } } } @@ -756,9 +761,11 @@ impl ReceiveContext { let mut ratchet_count = 0; let state = session.state.read(); for k in state.keys.iter() { - if SHA384::hash(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_id) { - ratchet_key = Some(k.ratchet_key.clone()); - ratchet_count = k.ratchet_count; + if let Some(k) = k.as_ref() { + if SHA384::hash(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_id) { + ratchet_key = Some(k.ratchet_key.clone()); + ratchet_count = k.ratchet_count; + } } } (ratchet_key, ratchet_count) @@ -815,7 +822,8 @@ impl ReceiveContext { header_check_cipher, state: RwLock::new(SessionMutableState { remote_session_id: Some(alice_session_id), - keys: LinkedList::new(), + keys: [None, None, None], + key_ptr: 0, offer: None, }), remote_s_public_hash: SHA384::hash(&alice_s_public), @@ -914,9 +922,12 @@ impl ReceiveContext { reply_buf[reply_len..(reply_len + HMAC_SIZE)].copy_from_slice(&hmac); reply_len += HMAC_SIZE; + let key = SessionKey::new(key, Role::Bob, current_time, reply_counter, ratchet_count + 1, e1e1.is_some()); + let mut state = session.state.write(); let _ = state.remote_session_id.replace(alice_session_id); - add_session_key(&mut state.keys, SessionKey::new(key, Role::Bob, current_time, reply_counter, ratchet_count + 1, e1e1.is_some())); + let next_key_ptr = (state.key_ptr + 1) % KEY_HISTORY_SIZE; + let _ = state.keys[next_key_ptr].replace(key); drop(state); // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. @@ -1015,8 +1026,9 @@ impl ReceiveContext { let mut state = RwLockUpgradableReadGuard::upgrade(state); let _ = state.remote_session_id.replace(bob_session_id); + let next_key_ptr = (state.key_ptr + 1) % KEY_HISTORY_SIZE; + let _ = state.keys[next_key_ptr].replace(key); let _ = state.offer.take(); - add_session_key(&mut state.keys, key); return Ok(ReceiveResult::Ok); } @@ -1281,23 +1293,6 @@ fn dearmor_header(packet: &[u8], header_check_cipher: &Aes) -> Option<(u8, u8, u } } -fn add_session_key(keys: &mut LinkedList, key: SessionKey) { - // Sanity check to make sure duplicates can't get in here. Should be impossible. - for k in keys.iter() { - if k.receive_key.eq(&key.receive_key) { - return; - } - } - - debug_assert!(KEY_HISTORY_SIZE_MAX >= 2); - while keys.len() >= KEY_HISTORY_SIZE_MAX { - let current = keys.pop_front().unwrap(); - let _ = keys.pop_front(); - keys.push_front(current); - } - keys.push_back(key); -} - fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Result<([u8; 16], SessionId, &[u8], &[u8], &[u8], Option<[u8; 16]>), Error> { let mut p = &incoming_packet[..]; let mut offer_id = [0_u8; 16]; @@ -1328,12 +1323,16 @@ fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Resu if p.len() < (pqc_kyber::KYBER_PUBLICKEYBYTES + 1) { return Err(Error::InvalidPacket); } - &p[1..(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)] + let e1p = &p[1..(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)]; + p = &p[(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)..]; + e1p } else { if p.len() < (pqc_kyber::KYBER_CIPHERTEXTBYTES + 1) { return Err(Error::InvalidPacket); } - &p[1..(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)] + let e1p = &p[1..(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)]; + p = &p[(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)..]; + e1p } } _ => &[], @@ -1387,7 +1386,9 @@ impl KeyLifetime { #[allow(unused)] struct SessionKey { + fingerprint: [u8; 16], establish_time: i64, + establish_counter: u64, lifetime: KeyLifetime, ratchet_key: Secret<64>, receive_key: Secret<32>, @@ -1409,7 +1410,9 @@ impl SessionKey { Role::Bob => (a2b, b2a), }; Self { + fingerprint: SHA384::hash(key.as_bytes())[..16].try_into().unwrap(), establish_time: current_time, + establish_counter: current_counter.0, lifetime: KeyLifetime::new(current_counter, current_time), ratchet_key: kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_RATCHETING), receive_key, @@ -1465,6 +1468,7 @@ mod tests { use parking_lot::Mutex; use std::collections::LinkedList; use std::sync::Arc; + use zerotier_utils::hex; #[allow(unused_imports)] use super::*; @@ -1475,9 +1479,10 @@ mod tests { psk: Secret<64>, session: Mutex>>>>, session_id_counter: Mutex, - pub queue: Mutex>>, - pub this_name: &'static str, - pub other_name: &'static str, + queue: Mutex>>, + key_id: Mutex<[u8; 16]>, + this_name: &'static str, + other_name: &'static str, } impl TestHost { @@ -1491,6 +1496,7 @@ mod tests { session: Mutex::new(None), session_id_counter: Mutex::new(random::next_u64_secure().wrapping_shr(16) | 1), queue: Mutex::new(LinkedList::new()), + key_id: Mutex::new([0; 16]), this_name, other_name, } @@ -1568,7 +1574,7 @@ mod tests { )); let mut ts = 0; - for _ in 0..3 { + for test_loop in 0..128 { for host in [&alice_host, &bob_host] { let send_to_other = |data: &mut [u8]| { if std::ptr::eq(host, &alice_host) { @@ -1621,11 +1627,28 @@ mod tests { data_buf.fill(0x12); if let Some(session) = host.session.lock().as_ref().cloned() { if session.established() { - for _ in 0..16 { + { + let mut key_id = host.key_id.lock(); + let security_info = session.security_info().unwrap(); + if !security_info.0.eq(key_id.as_ref()) { + *key_id = security_info.0; + println!( + "zssp: new key at {}: fingerprint {} ratchet {} kyber {}", + host.this_name, + hex::to_string(key_id.as_ref()), + security_info.2, + security_info.3 + ); + } + } + for _ in 0..32 { assert!(session .send(send_to_other, &mut mtu_buffer, &data_buf[..((random::xorshift64_random() as usize) % data_buf.len())]) .is_ok()); } + if (test_loop % 8) == 0 && test_loop >= 8 && host.this_name.eq("alice") { + session.service(host, send_to_other, &[], mtu_buffer.len(), test_loop as i64, true); + } } } }