diff --git a/core-crypto/src/zssp.rs b/core-crypto/src/zssp.rs index 82861904c..09641583a 100644 --- a/core-crypto/src/zssp.rs +++ b/core-crypto/src/zssp.rs @@ -1,11 +1,12 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. +use std::collections::LinkedList; use std::io::{Read, Write}; use std::num::NonZeroU64; use std::ops::Deref; use std::sync::atomic::{AtomicU64, Ordering}; -use crate::aes::AesGcm; +use crate::aes::{Aes, AesGcm}; use crate::hash::{hmac_sha384, hmac_sha512, SHA384}; use crate::p384::{P384KeyPair, P384PublicKey, P384_PUBLIC_KEY_SIZE}; use crate::random; @@ -53,12 +54,18 @@ const E1_TYPE_KYBER1024: u8 = 1; const MAX_FRAGMENTS: usize = 48; // protocol max: 63 const KEY_EXCHANGE_MAX_FRAGMENTS: usize = 2; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc. -const HEADER_SIZE: usize = 12; +const HEADER_SIZE: usize = 16; +const HEADER_CHECK_SIZE: usize = 4; const AES_GCM_TAG_SIZE: usize = 16; +const AES_GCM_NONCE_SIZE: usize = 12; +const AES_GCM_NONCE_START: usize = 4; +const AES_GCM_NONCE_END: usize = 16; const HMAC_SIZE: usize = 48; const SESSION_ID_SIZE: usize = 6; +const KEY_HISTORY_SIZE_MAX: usize = 3; const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; +const KBKDF_KEY_USAGE_LABEL_HEADER_MAC: u8 = b'H'; const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A'; const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; @@ -243,6 +250,7 @@ pub struct Session { send_counter: Counter, psk: Secret<64>, // Arbitrary PSK provided by external code ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer + header_check_cipher: Aes, // Cipher used for fast 32-bit header MAC state: RwLock, // Mutable parts of state (other than defrag buffers) remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob) remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key @@ -251,13 +259,14 @@ pub struct Session { struct MutableState { 
remote_session_id: Option, - keys: [Option; 2], // current, next (promoted to current on successful decrypt) + keys: LinkedList, offer: Option, } /// State information to associate with receiving contexts such as sockets or remote paths/endpoints. pub struct ReceiveContext { initial_offer_defrag: Mutex, 1024, 128>>, + incoming_init_header_check_cipher: Aes, } impl Session { @@ -276,11 +285,13 @@ impl Session { ) -> Result { if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) { if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) { - let counter = Counter::new(); + let send_counter = Counter::new(); + let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()); let remote_s_public_hash = SHA384::hash(remote_s_public); + let outgoing_init_header_check_cipher = Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()); if let Ok(offer) = EphemeralOffer::create_alice_offer( &mut send, - counter.next(), + send_counter.next(), local_session_id, None, host.get_local_s_public(), @@ -288,6 +299,7 @@ impl Session { &remote_s_public_p384, &remote_s_public_hash, &ss, + &outgoing_init_header_check_cipher, mtu, current_time, jedi, @@ -295,12 +307,13 @@ impl Session { return Ok(Self { id: local_session_id, associated_object, - send_counter: counter, + send_counter, psk: psk.clone(), ss, + header_check_cipher, state: RwLock::new(MutableState { remote_session_id: None, - keys: [None, None], + keys: LinkedList::new(), offer: Some(offer), }), remote_s_public_hash, @@ -313,10 +326,59 @@ impl Session { return Err(Error::InvalidParameter); } + #[inline] + pub fn send(&self, mut send: SendFunction, mtu_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> { + debug_assert!(mtu_buffer.len() >= MIN_MTU); + let state = self.state.read(); + if let Some(remote_session_id) = state.remote_session_id { + if let Some(key) = state.keys.front() { + let mut 
packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; + let counter = self.send_counter.next(); + + send_with_fragmentation_init_header(mtu_buffer, packet_len, mtu_buffer.len(), PACKET_TYPE_DATA, remote_session_id.into(), counter); + let mut c = key.get_send_cipher(counter)?; + c.init(&mtu_buffer[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); + + if packet_len > mtu_buffer.len() { + let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap(); + let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE; + let last_fragment_data_mtu = mtu_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); + loop { + debug_assert!(data.len() > last_fragment_data_mtu); + c.crypt(&data[..fragment_data_mtu], &mut mtu_buffer[HEADER_SIZE..]); + let hc = header_check(mtu_buffer, &self.header_check_cipher); + mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes()); + send(mtu_buffer); + data = &data[fragment_data_mtu..]; + debug_assert!(header[7].wrapping_shr(2) < 63); + header[7] += 0x04; // increment fragment number + mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].copy_from_slice(&header); + if data.len() <= last_fragment_data_mtu { + break; + } + } + packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; + } + + let gcm_tag_idx = data.len() + HEADER_SIZE; + c.crypt(data, &mut mtu_buffer[HEADER_SIZE..gcm_tag_idx]); + let hc = header_check(mtu_buffer, &self.header_check_cipher); + mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes()); + mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish()); + send(&mut mtu_buffer[..packet_len]); + + key.return_send_cipher(c); + + return Ok(()); + } + } + return Err(Error::SessionNotEstablished); + } + #[inline] pub fn rekey_check(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) { let state = self.state.upgradable_read(); - if let Some(key) = state.keys[0].as_ref() { + if let Some(key) = state.keys.front() { if force ||
(key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) { if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) { if let Ok(offer) = EphemeralOffer::create_alice_offer( @@ -329,6 +391,7 @@ impl Session { &remote_s_public_p384, &self.remote_s_public_hash, &self.ss, + &self.header_check_cipher, mtu, current_time, jedi, @@ -343,9 +406,10 @@ impl Session { impl ReceiveContext { #[inline] - pub fn new() -> Self { + pub fn new(host: &H) -> Self { Self { initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), + incoming_init_header_check_cipher: Aes::new(kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()), } } @@ -366,17 +430,22 @@ impl ReceiveContext { return Err(Error::InvalidPacket); } - let header_0_8 = memory::u64_from_le_bytes(incoming_packet); // type, frag info, session ID - let counter = memory::u32_from_le_bytes(&incoming_packet[8..]); - let local_session_id = SessionId::new_from_u64(header_0_8.wrapping_shr(16)); - let packet_type = (header_0_8 as u8) & 15; - let fragment_count = ((header_0_8.wrapping_shr(4) as u8) & 63).wrapping_add(1); - let fragment_no = (header_0_8.wrapping_shr(10) as u8) & 63; + let header_0_8 = memory::u64_from_le_bytes(&incoming_packet[HEADER_CHECK_SIZE..12]); // session ID, type, frag info + let counter = memory::u32_from_le_bytes(&incoming_packet[12..16]); + let local_session_id = SessionId::new_from_u64(header_0_8 & SessionId::MAX_BIT_MASK); + let packet_type = (header_0_8.wrapping_shr(48) as u8) & 15; + let fragment_count = ((header_0_8.wrapping_shr(52) as u8) & 63).wrapping_add(1); + let fragment_no = (header_0_8.wrapping_shr(58) as u8) & 63; - if fragment_count > 1 { - if let Some(local_session_id) = local_session_id { - if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { - if let 
Some(session) = host.session_lookup(local_session_id) { + if let Some(local_session_id) = local_session_id { + if let Some(session) = host.session_lookup(local_session_id) { + if memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &session.header_check_cipher) { + unlikely_branch(); + return Err(Error::FailedAuthentication); + } + + if fragment_count > 1 { + if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { let mut defrag = session.defrag.lock(); let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { @@ -385,38 +454,30 @@ impl ReceiveContext { } } else { unlikely_branch(); - return Err(Error::UnknownLocalSessionId(local_session_id)); + return Err(Error::InvalidPacket); } } else { - unlikely_branch(); - return Err(Error::InvalidPacket); + return self.receive_complete(host, &mut send, data_buf, &[incoming_packet_buf], packet_type, Some(session), mtu, jedi, current_time); } } else { - if fragment_count <= (KEY_EXCHANGE_MAX_FRAGMENTS as u8) && fragment_no < fragment_count { - let mut defrag = self.initial_offer_defrag.lock(); - let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); - if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { - drop(defrag); // release lock - return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time); - } - } else { - unlikely_branch(); - return Err(Error::InvalidPacket); - } + unlikely_branch(); + return Err(Error::UnknownLocalSessionId(local_session_id)); } } else { - return self.receive_complete( - host, - &mut send, - data_buf, - &[incoming_packet_buf], - packet_type, - local_session_id.and_then(|lsid| host.session_lookup(lsid)), - mtu, - jedi, - current_time, - ); - } + unlikely_branch(); + + if 
memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &self.incoming_init_header_check_cipher) { + unlikely_branch(); + return Err(Error::FailedAuthentication); + } + + let mut defrag = self.initial_offer_defrag.lock(); + let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); + if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { + drop(defrag); // release lock + return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time); + } + }; return Ok(ReceiveResult::Ok); } @@ -440,60 +501,63 @@ impl ReceiveContext { if packet_type <= PACKET_TYPE_NOP { if let Some(session) = session { let state = session.state.read(); - for ki in 0..2 { - if let Some(key) = state.keys[ki].as_ref() { - let tail = fragments.last().unwrap().as_ref(); - if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { - unlikely_branch(); - return Err(Error::InvalidPacket); - } + let key_count = state.keys.len(); + for (key_index, key) in state.keys.iter().enumerate() { + let tail = fragments.last().unwrap().as_ref(); + if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { + unlikely_branch(); + return Err(Error::InvalidPacket); + } - let mut c = key.get_receive_cipher(); - c.init(&get_aes_gcm_nonce(fragments.first().unwrap().as_ref())); + let mut c = key.get_receive_cipher(); + c.init(&fragments.first().unwrap().as_ref()[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); - let mut data_len = 0; - - for f in fragments[..(fragments.len() - 1)].iter() { - let f = f.as_ref(); - debug_assert!(f.len() >= HEADER_SIZE); - let current_frag_data_start = data_len; - data_len += f.len() - HEADER_SIZE; - if data_len > data_buf.len() { - unlikely_branch(); - key.return_receive_cipher(c); - return Err(Error::DataBufferTooSmall); - } - c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); - } + let mut data_len = 0; + for f in 
fragments[..(fragments.len() - 1)].iter() { + let f = f.as_ref(); + debug_assert!(f.len() >= HEADER_SIZE); let current_frag_data_start = data_len; - data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); + data_len += f.len() - HEADER_SIZE; if data_len > data_buf.len() { unlikely_branch(); key.return_receive_cipher(c); return Err(Error::DataBufferTooSmall); } - c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]); + c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); + } - let tag = c.finish(); + let current_frag_data_start = data_len; + data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); + if data_len > data_buf.len() { + unlikely_branch(); key.return_receive_cipher(c); + return Err(Error::DataBufferTooSmall); + } + c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]); - if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) { - // If this succeeded with the "next" key, promote it to current. - if ki == 1 { - unlikely_branch(); - drop(state); - let mut state = session.state.write(); - state.keys[0] = state.keys[1].take(); - } + let tag = c.finish(); + key.return_receive_cipher(c); - if packet_type == PACKET_TYPE_DATA { - return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); - } else { - unlikely_branch(); - return Ok(ReceiveResult::Ok); + if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) { + // Drop obsolete keys if we had to iterate past the first key to get here. 
+ if key_index > 0 { + unlikely_branch(); + drop(state); + let mut state = session.state.write(); + if state.keys.len() == key_count { + for _ in 0..key_index { + let _ = state.keys.pop_front(); + } } } + + if packet_type == PACKET_TYPE_DATA { + return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); + } else { + unlikely_branch(); + return Ok(ReceiveResult::Ok); + } } } return Err(Error::FailedAuthentication); @@ -541,7 +605,7 @@ impl ReceiveContext { let hmac1_end = incoming_packet_len - HMAC_SIZE; // Check that the sender knows this host's identity before doing anything else. - if !hmac_sha384(host.get_local_s_public_hash(), &incoming_packet[..hmac1_end]).eq(&incoming_packet[hmac1_end..]) { + if !hmac_sha384(host.get_local_s_public_hash(), &incoming_packet[HEADER_CHECK_SIZE..hmac1_end]).eq(&incoming_packet[hmac1_end..]) { return Err(Error::FailedAuthentication); } @@ -561,7 +625,7 @@ impl ReceiveContext { let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes())); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), false); - c.init(&get_aes_gcm_nonce(incoming_packet)); + c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); let c = c.finish(); if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) { @@ -582,7 +646,12 @@ impl ReceiveContext { let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes())); - if !hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[..aes_gcm_tag_end]).eq(&incoming_packet[aes_gcm_tag_end..hmac1_end]) { + if !hmac_sha384( + kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), + &original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end], + ) + .eq(&incoming_packet[aes_gcm_tag_end..hmac1_end]) + { return Err(Error::FailedAuthentication); } 
@@ -596,15 +665,17 @@ impl ReceiveContext { None } else { if let Some((new_session_id, psk, associated_object)) = host.accept_new_session(alice_s_public, alice_metadata) { + let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()); Some(Session:: { id: new_session_id, associated_object, send_counter: Counter::new(), psk, ss, + header_check_cipher, state: RwLock::new(MutableState { remote_session_id: Some(alice_session_id), - keys: [None, None], + keys: LinkedList::new(), offer: None, }), remote_s_public_hash: SHA384::hash(&alice_s_public), @@ -656,13 +727,14 @@ impl ReceiveContext { } else { rp.write_all(&[E1_TYPE_NONE])?; } + (MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len() }; - let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter); - reply_buf[..HEADER_SIZE].copy_from_slice(&header); + + send_with_fragmentation_init_header(&mut reply_buf, reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true); - c.init(&get_aes_gcm_nonce(&header)); + c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut reply_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..reply_len]); let c = c.finish(); reply_buf[reply_len..(reply_len + AES_GCM_TAG_SIZE)].copy_from_slice(&c); @@ -674,18 +746,19 @@ impl ReceiveContext { // key derivation step. 
let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes())); - let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &reply_buf[..reply_len]); + let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &reply_buf[HEADER_CHECK_SIZE..reply_len]); reply_buf[reply_len..(reply_len + HMAC_SIZE)].copy_from_slice(&hmac); reply_len += HMAC_SIZE; let mut state = session.state.write(); let _ = state.remote_session_id.replace(alice_session_id); - state.keys[1].replace(SessionKey::new(key, Role::Bob, current_time, reply_counter, jedi)); + add_key(&mut state.keys, SessionKey::new(key, Role::Bob, current_time, reply_counter, jedi)); drop(state); // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. - send_with_fragmentation(send, &mut reply_buf[..reply_len], mtu, &mut header); + send_with_fragmentation(send, &mut reply_buf[..reply_len], mtu, &session.header_check_cipher); + if new_session.is_some() { return Ok(ReceiveResult::OkNewSession(new_session.unwrap())); } else { @@ -716,7 +789,7 @@ impl ReceiveContext { )); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), false); - c.init(&get_aes_gcm_nonce(incoming_packet)); + c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); let c = c.finish(); if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) { @@ -739,8 +812,11 @@ impl ReceiveContext { let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes())); - if !hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[..aes_gcm_tag_end]) - .eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()]) + if !hmac_sha384( + kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), + &original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end], + 
) + .eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()]) { return Err(Error::FailedAuthentication); } @@ -751,20 +827,22 @@ impl ReceiveContext { let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi); let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; - let header = send_with_fragmentation_init_header(HEADER_SIZE + AES_GCM_TAG_SIZE, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter); - reply_buf[..HEADER_SIZE].copy_from_slice(&header); + send_with_fragmentation_init_header(&mut reply_buf, HEADER_SIZE + AES_GCM_TAG_SIZE, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter); let mut c = key.get_send_cipher(counter)?; - c.init(&get_aes_gcm_nonce(&reply_buf)); + c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish()); key.return_send_cipher(c); + let hc = header_check(&reply_buf, &session.header_check_cipher); + reply_buf[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes()); + send(&mut reply_buf); let mut state = RwLockUpgradableReadGuard::upgrade(state); let _ = state.remote_session_id.replace(bob_session_id); let _ = state.offer.take(); - let _ = state.keys[0].insert(key); + add_key(&mut state.keys, key); return Ok(ReceiveResult::Ok); } @@ -861,6 +939,7 @@ impl EphemeralOffer { bob_s_public_p384: &P384PublicKey, bob_s_public_hash: &[u8], ss: &Secret<48>, + header_check_cipher: &Aes, mtu: usize, current_time: i64, jedi: bool, @@ -900,8 +979,7 @@ impl EphemeralOffer { PACKET_BUF_SIZE - p.len() }; - let mut header = send_with_fragmentation_init_header(packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter); - packet_buf[..HEADER_SIZE].copy_from_slice(&header); + send_with_fragmentation_init_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter); let key = Secret(hmac_sha512( &hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()), @@ -910,7 
+988,7 @@ impl EphemeralOffer { let gcm_tag = { let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true); - c.init(&get_aes_gcm_nonce(&packet_buf)); + c.init(&packet_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]); c.finish() }; @@ -919,15 +997,15 @@ impl EphemeralOffer { let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes())); - let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[..packet_len]); + let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[HEADER_CHECK_SIZE..packet_len]); packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); packet_len += HMAC_SIZE; - let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[..packet_len]); + let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[HEADER_CHECK_SIZE..packet_len]); packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); packet_len += HMAC_SIZE; - send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, &mut header); + send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher); Ok(EphemeralOffer { creation_time: current_time, @@ -939,7 +1017,7 @@ impl EphemeralOffer { } #[inline(always)] -fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) -> [u8; 12] { +fn send_with_fragmentation_init_header(header: &mut [u8], packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) { let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize; debug_assert!(mtu >= MIN_MTU); debug_assert!(packet_len >= HEADER_SIZE); @@ -947,27 +1025,26 @@ fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_typ debug_assert!(fragment_count > 0); 
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits - - // Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter - ((((((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u64) | recipient_session_id.wrapping_shl(16)) as u128) | (counter.to_u32() as u128).wrapping_shl(64)).to_le_bytes() - [..HEADER_SIZE] - .try_into() - .unwrap() + header[HEADER_CHECK_SIZE..12].copy_from_slice(&(recipient_session_id | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)).to_le_bytes()); + header[12..HEADER_SIZE].copy_from_slice(&counter.to_u32().to_le_bytes()); } -#[inline(always)] -fn send_with_fragmentation(send: &mut SendFunction, packet: &mut [u8], mtu: usize, header: &mut [u8; HEADER_SIZE]) { +/// Send a packet in fragments (used for everything but DATA which has a hand-rolled version for performance). +fn send_with_fragmentation(send: &mut SendFunction, packet: &mut [u8], mtu: usize, header_check_cipher: &Aes) { let packet_len = packet.len(); let mut fragment_start = 0; let mut fragment_end = packet_len.min(mtu); + let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = packet[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap(); loop { + let hc = header_check(&packet[fragment_start..], header_check_cipher); + packet[fragment_start..(fragment_start + HEADER_CHECK_SIZE)].copy_from_slice(&hc.to_ne_bytes()); send(&mut packet[fragment_start..fragment_end]); if fragment_end < packet_len { + debug_assert!(header[7].wrapping_shr(2) < 63); + header[7] += 0x04; // increment fragment number fragment_start = fragment_end - HEADER_SIZE; fragment_end = (fragment_start + mtu).min(packet_len); - debug_assert!(header[1].wrapping_shr(2) < 63); - header[1] += 0x04; // increment fragment number in least significant 6 bits of byte 1 - packet[fragment_start..(fragment_start + HEADER_SIZE)].copy_from_slice(header); + 
packet[(fragment_start + HEADER_CHECK_SIZE)..(fragment_start + HEADER_SIZE)].copy_from_slice(&header); } else { debug_assert_eq!(fragment_end, packet_len); break; @@ -975,6 +1052,31 @@ fn send_with_fragmentation(send: &mut SendFuncti } } +/// Compute the 32-bit header check code for a packet on receipt or right before send. +fn header_check(packet: &[u8], header_check_cipher: &Aes) -> u32 { + debug_assert!(packet.len() >= HEADER_SIZE); + let mut header_check = 0u128.to_ne_bytes(); + if packet.len() >= (16 + HEADER_CHECK_SIZE) { + header_check_cipher.encrypt_block(&packet[HEADER_CHECK_SIZE..(16 + HEADER_CHECK_SIZE)], &mut header_check); + } else { + unlikely_branch(); + header_check[..(packet.len() - HEADER_CHECK_SIZE)].copy_from_slice(&packet[HEADER_CHECK_SIZE..]); + header_check_cipher.encrypt_block_in_place(&mut header_check); + } + memory::u32_from_ne_bytes(&header_check) +} + +/// Add a new session key to the key list, retiring older non-active keys if necessary. +fn add_key(keys: &mut LinkedList, key: SessionKey) { + debug_assert!(KEY_HISTORY_SIZE_MAX >= 2); + while keys.len() >= KEY_HISTORY_SIZE_MAX { + let current = keys.pop_front().unwrap(); + let _ = keys.pop_front(); + keys.push_front(current); + } + keys.push_back(key); +} + fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Result<(SessionId, &[u8], &[u8], &[u8]), Error> { let mut p = &incoming_packet[..]; let alice_session_id = SessionId::new_from_reader(&mut p)?; @@ -1083,13 +1185,6 @@ fn kbkdf512(key: &[u8], label: u8) -> Secret<64> { Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00])) } -#[inline(always)] -fn get_aes_gcm_nonce(packet: &[u8]) -> [u8; 16] { - let mut tmp = 0u128.to_ne_bytes(); - tmp[..HEADER_SIZE].copy_from_slice(&packet[..HEADER_SIZE]); - tmp -} - #[cfg(test)] mod tests { use parking_lot::Mutex; @@ -1175,7 +1270,8 @@ mod tests { random::fill_bytes_secure(&mut psk.0); let alice_host = Box::new(TestHost::new(psk.clone(), "alice", 
"bob")); let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice")); - let rc: Box>> = Box::new(ReceiveContext::new()); + let alice_rc: Box>> = Box::new(ReceiveContext::new(&alice_host)); + let bob_rc: Box>> = Box::new(ReceiveContext::new(&bob_host)); let mut data_buf = [0_u8; 4096]; //println!("zssp: size of session (bytes): {}", std::mem::size_of::>>()); @@ -1207,6 +1303,12 @@ mod tests { } }; + let rc = if std::ptr::eq(host, &alice_host) { + &alice_rc + } else { + &bob_rc + }; + loop { if let Some(qi) = host.queue.lock().pop_back() { let qi_len = qi.len(); @@ -1232,7 +1334,8 @@ mod tests { } } } else { - println!("zssp: {} => {}: error: {}", host.other_name, host.this_name, r.err().unwrap().to_string()); + println!("zssp: {} => {} ({}): error: {}", host.other_name, host.this_name, qi_len, r.err().unwrap().to_string()); + panic!(); } } else { break; diff --git a/utils/src/memory.rs b/utils/src/memory.rs index 2404334f6..5f8e45cd1 100644 --- a/utils/src/memory.rs +++ b/utils/src/memory.rs @@ -39,6 +39,42 @@ mod fast_int_memory_access { unsafe { *b.as_ptr().cast() } } + #[inline(always)] + pub fn u64_from_ne_bytes(b: &[u8]) -> u64 { + assert!(b.len() >= 8); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn u32_from_ne_bytes(b: &[u8]) -> u32 { + assert!(b.len() >= 4); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn u16_from_ne_bytes(b: &[u8]) -> u16 { + assert!(b.len() >= 2); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn i64_from_ne_bytes(b: &[u8]) -> i64 { + assert!(b.len() >= 8); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn i32_from_ne_bytes(b: &[u8]) -> i32 { + assert!(b.len() >= 4); + unsafe { *b.as_ptr().cast() } + } + + #[inline(always)] + pub fn i16_from_ne_bytes(b: &[u8]) -> i16 { + assert!(b.len() >= 2); + unsafe { *b.as_ptr().cast() } + } + #[inline(always)] pub fn u64_from_be_bytes(b: &[u8]) -> u64 { assert!(b.len() >= 8); @@ -109,6 +145,36 @@ mod 
fast_int_memory_access { i16::from_le_bytes(b[..2].try_into().unwrap()) } + #[inline(always)] + pub fn u64_from_ne_bytes(b: &[u8]) -> u64 { + u64::from_ne_bytes(b[..8].try_into().unwrap()) + } + + #[inline(always)] + pub fn u32_from_ne_bytes(b: &[u8]) -> u32 { + u32::from_ne_bytes(b[..4].try_into().unwrap()) + } + + #[inline(always)] + pub fn u16_from_ne_bytes(b: &[u8]) -> u16 { + u16::from_ne_bytes(b[..2].try_into().unwrap()) + } + + #[inline(always)] + pub fn i64_from_ne_bytes(b: &[u8]) -> i64 { + i64::from_ne_bytes(b[..8].try_into().unwrap()) + } + + #[inline(always)] + pub fn i32_from_ne_bytes(b: &[u8]) -> i32 { + i32::from_ne_bytes(b[..4].try_into().unwrap()) + } + + #[inline(always)] + pub fn i16_from_ne_bytes(b: &[u8]) -> i16 { + i16::from_ne_bytes(b[..2].try_into().unwrap()) + } + #[inline(always)] pub fn u64_from_be_bytes(b: &[u8]) -> u64 { u64::from_be_bytes(b[..8].try_into().unwrap()) diff --git a/utils/src/ringbuffermap.rs b/utils/src/ringbuffermap.rs index 989bbca47..9da5c3a31 100644 --- a/utils/src/ringbuffermap.rs +++ b/utils/src/ringbuffermap.rs @@ -4,6 +4,8 @@ use std::hash::{Hash, Hasher}; use std::mem::MaybeUninit; +const EMPTY: u16 = 0xffff; + #[inline(always)] fn xorshift64(mut x: u64) -> u64 { x ^= x.wrapping_shl(13); @@ -78,9 +80,9 @@ impl Hasher for XorShiftHasher { struct Entry { key: MaybeUninit, value: MaybeUninit, - bucket: i32, // which bucket is this in? -1 for none - next: i32, // next item in bucket's linked list, -1 for none - prev: i32, // previous entry to permit deletion of old entries from bucket lists + bucket: u16, // which bucket is this in? EMPTY for none + next: u16, // next item in bucket's linked list, EMPTY for none + prev: u16, // previous entry to permit deletion of old entries from bucket lists } /// A hybrid between a circular buffer and a map. @@ -90,35 +92,35 @@ struct Entry { /// with a HashMap but that would be less efficient. 
This requires no memory allocations unless /// the K or V types allocate memory and occupies a fixed amount of memory. /// -/// This is pretty basic and doesn't have a remove function. Old entries just roll off. This -/// only contains what is needed elsewhere in the project. +/// There is no explicit remove since that would require more complex logic to maintain FIFO +/// ordering for replacement of entries. Old entries just roll off the end. +/// +/// This is used for things like defragmenting incoming packets to support multiple fragmented +/// packets in flight. Having no allocations is good to reduce the potential for memory +/// exhaustion attacks. /// /// The C template parameter is the total capacity while the B parameter is the number of -/// buckets in the hash table. +/// buckets in the hash table. The maximum for both these parameters is 65535. This could be +/// increased by making the index variables larger (e.g. u32 instead of u16). pub struct RingBufferMap { - entries: [Entry; C], - buckets: [i32; B], - entry_ptr: u32, salt: u32, + entries: [Entry; C], + buckets: [u16; B], + entry_ptr: u16, } impl RingBufferMap { + /// Create a new map with the supplied random salt to perturb the hashing function. #[inline] pub fn new(salt: u32) -> Self { - Self { - entries: { - let mut entries: [Entry; C] = unsafe { MaybeUninit::uninit().assume_init() }; - for e in entries.iter_mut() { - e.bucket = -1; - e.next = -1; - e.prev = -1; - } - entries - }, - buckets: [-1; B], - entry_ptr: 0, - salt, - } + debug_assert!(C <= EMPTY as usize); + debug_assert!(B <= EMPTY as usize); + let mut tmp: Self = unsafe { MaybeUninit::uninit().assume_init() }; + // EMPTY is the maximum value of the indices, which is all 0xff, so this sets all indices to EMPTY. 
+ unsafe { std::ptr::write_bytes(&mut tmp, 0xff, 1) }; + tmp.salt = salt; + tmp.entry_ptr = 0; + tmp } #[inline] @@ -126,9 +128,9 @@ impl RingBu let mut h = XorShiftHasher::new(self.salt); key.hash(&mut h); let mut e = self.buckets[(h.finish() as usize) % B]; - while e >= 0 { + while e != EMPTY { let ee = &self.entries[e as usize]; - debug_assert!(ee.bucket >= 0); + debug_assert!(ee.bucket != EMPTY); if unsafe { ee.key.assume_init_ref().eq(key) } { return Some(unsafe { &ee.value.assume_init_ref() }); } @@ -137,7 +139,7 @@ impl RingBu return None; } - /// Get an entry, creating if not present. + /// Get an entry, creating if not present, and return a mutable reference to it. #[inline] pub fn get_or_create_mut V>(&mut self, key: &K, create: CF) -> &mut V { let mut h = XorShiftHasher::new(self.salt); @@ -145,10 +147,10 @@ impl RingBu let bucket = (h.finish() as usize) % B; let mut e = self.buckets[bucket]; - while e >= 0 { + while e != EMPTY { unsafe { let e_ptr = &mut *self.entries.as_mut_ptr().add(e as usize); - debug_assert!(e_ptr.bucket >= 0); + debug_assert!(e_ptr.bucket != EMPTY); if e_ptr.key.assume_init_ref().eq(key) { return e_ptr.value.assume_init_mut(); } @@ -167,9 +169,9 @@ impl RingBu let bucket = (h.finish() as usize) % B; let mut e = self.buckets[bucket]; - while e >= 0 { + while e != EMPTY { let e_ptr = &mut self.entries[e as usize]; - debug_assert!(e_ptr.bucket >= 0); + debug_assert!(e_ptr.bucket != EMPTY); if unsafe { e_ptr.key.assume_init_ref().eq(&key) } { unsafe { *e_ptr.value.assume_init_mut() = value }; return; @@ -177,7 +179,7 @@ impl RingBu e = e_ptr.next; } - self.internal_add(bucket, key, value); + let _ = self.internal_add(bucket, key, value); } #[inline] @@ -186,8 +188,8 @@ impl RingBu self.entry_ptr = self.entry_ptr.wrapping_add(1); let e_ptr = unsafe { &mut *self.entries.as_mut_ptr().add(e) }; - if e_ptr.bucket >= 0 { - if e_ptr.prev >= 0 { + if e_ptr.bucket != EMPTY { + if e_ptr.prev != EMPTY { self.entries[e_ptr.prev as usize].next = 
e_ptr.next; } else { self.buckets[e_ptr.bucket as usize] = e_ptr.next; @@ -200,13 +202,13 @@ impl RingBu e_ptr.key.write(key); e_ptr.value.write(value); - e_ptr.bucket = bucket as i32; + e_ptr.bucket = bucket as u16; e_ptr.next = self.buckets[bucket]; - if e_ptr.next >= 0 { - self.entries[e_ptr.next as usize].prev = e as i32; + if e_ptr.next != EMPTY { + self.entries[e_ptr.next as usize].prev = e as u16; } - self.buckets[bucket] = e as i32; - e_ptr.prev = -1; + self.buckets[bucket] = e as u16; + e_ptr.prev = EMPTY; unsafe { e_ptr.value.assume_init_mut() } } } @@ -215,7 +217,7 @@ impl Drop f #[inline] fn drop(&mut self) { for e in self.entries.iter_mut() { - if e.bucket >= 0 { + if e.bucket != EMPTY { unsafe { e.key.assume_init_drop(); e.value.assume_init_drop();