From d5943f246ab5b2335366961bb09d48ee903b04f6 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Tue, 6 Sep 2022 18:31:19 -0400 Subject: [PATCH] More cleanup in session. --- core-crypto/src/zssp.rs | 97 +++++++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 47 deletions(-) diff --git a/core-crypto/src/zssp.rs b/core-crypto/src/zssp.rs index 2253d065a..82861904c 100644 --- a/core-crypto/src/zssp.rs +++ b/core-crypto/src/zssp.rs @@ -261,9 +261,10 @@ pub struct ReceiveContext { } impl Session { + #[inline] pub fn new( host: &H, - send: SendFunction, + mut send: SendFunction, local_session_id: SessionId, remote_s_public: &[u8], offer_metadata: &[u8], @@ -278,7 +279,7 @@ impl Session { let counter = Counter::new(); let remote_s_public_hash = SHA384::hash(remote_s_public); if let Ok(offer) = EphemeralOffer::create_alice_offer( - send, + &mut send, counter.next(), local_session_id, None, @@ -312,13 +313,14 @@ impl Session { return Err(Error::InvalidParameter); } - pub fn rekey_check(&self, host: &H, send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) { + #[inline] + pub fn rekey_check(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) { let state = self.state.upgradable_read(); if let Some(key) = state.keys[0].as_ref() { if force || (key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) { if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) { if let Ok(offer) = EphemeralOffer::create_alice_offer( - send, + &mut send, self.send_counter.next(), self.id, state.remote_session_id, @@ -340,16 +342,18 @@ impl Session { } impl ReceiveContext { + #[inline] pub fn new() -> Self { Self { initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), } } + #[inline] pub fn receive<'a, SendFunction: FnMut(&mut [u8])>( &self, host: &H, - send: SendFunction, + mut send: SendFunction, data_buf: &'a mut [u8], incoming_packet_buf: H::IncomingPacketBuffer, mtu: usize, @@ -377,7 +381,7 @@ impl ReceiveContext { let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { drop(defrag); // release lock - return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time); + return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time); } } else { unlikely_branch(); @@ -393,7 +397,7 @@ impl ReceiveContext { let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { drop(defrag); // release lock - return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time); + return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time); } } else { unlikely_branch(); @@ -403,7 +407,7 @@ impl ReceiveContext { } else { return self.receive_complete( host, - send, + &mut send, data_buf, &[incoming_packet_buf], packet_type, @@ -420,7 +424,7 @@ impl ReceiveContext { fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>( &self, host: &H, - mut send: SendFunction, + send: &mut SendFunction, data_buf: &'a mut [u8], fragments: &[H::IncomingPacketBuffer], packet_type: u8, @@ -475,8 +479,8 @@ impl ReceiveContext { key.return_receive_cipher(c); if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) { + // If this succeeded with the "next" key, promote it to current. if ki == 1 { - // Promote next key to current key on success. unlikely_branch(); drop(state); let mut state = session.state.write(); @@ -486,6 +490,7 @@ impl ReceiveContext { if packet_type == PACKET_TYPE_DATA { return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); } else { + unlikely_branch(); return Ok(ReceiveResult::Ok); } } @@ -517,7 +522,7 @@ impl ReceiveContext { let original_ciphertext = incoming_packet_buf.clone(); let incoming_packet = &mut incoming_packet_buf[..incoming_packet_len]; - if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) { + if incoming_packet_len <= HEADER_SIZE { return Err(Error::InvalidPacket); } if incoming_packet[HEADER_SIZE] != SESSION_PROTOCOL_VERSION { @@ -528,6 +533,9 @@ impl ReceiveContext { PACKET_TYPE_KEY_OFFER => { // alice (remote) -> bob (local) + if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE) { + return Err(Error::InvalidPacket); + } let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE); let aes_gcm_tag_end = incoming_packet_len - (HMAC_SIZE + HMAC_SIZE); let hmac1_end = incoming_packet_len - HMAC_SIZE; @@ -688,6 +696,9 @@ impl ReceiveContext { PACKET_TYPE_KEY_COUNTER_OFFER => { // bob (remote) -> alice (local) + if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) { + return Err(Error::InvalidPacket); + } let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE); let aes_gcm_tag_end = incoming_packet_len - HMAC_SIZE; @@ -712,7 +723,7 @@ impl ReceiveContext { return Err(Error::FailedAuthentication); } - // Alice has now completed Noise_IK for P-384 and verified with GCM auth, now for the hybrid add-on. + // Alice has now completed Noise_IK with NIST P-384 and verified with GCM auth, but now for hybrid... let (bob_session_id, _, _, bob_e1_public) = parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?; @@ -734,36 +745,26 @@ impl ReceiveContext { return Err(Error::FailedAuthentication); } - // Alice has now completed and validated the full hybrid exchange. If this is the first exchange send - // a NOP back to Bob to acknowledge that the session is open and can now be used. Otherwise just queue - // this up as the next key to be promoted to current when Bob uses it. + // Alice has now completed and validated the full hybrid exchange. + + let counter = session.send_counter.next(); + let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi); + + let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; + let header = send_with_fragmentation_init_header(HEADER_SIZE + AES_GCM_TAG_SIZE, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter); + reply_buf[..HEADER_SIZE].copy_from_slice(&header); + + let mut c = key.get_send_cipher(counter)?; + c.init(&get_aes_gcm_nonce(&reply_buf)); + reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish()); + key.return_send_cipher(c); + + send(&mut reply_buf); let mut state = RwLockUpgradableReadGuard::upgrade(state); - let _ = state.offer.take(); let _ = state.remote_session_id.replace(bob_session_id); - if state.keys[0].is_some() { - let _ = state.keys[1].replace(SessionKey::new(key, Role::Alice, current_time, session.send_counter.current(), jedi)); - } else { - let counter = session.send_counter.next(); - let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi); - - let mut reply_buf = [0_u8; MIN_MTU]; - let dummy_data_len = (random::next_u32_secure() % (mtu - (HEADER_SIZE + AES_GCM_TAG_SIZE)) as u32) as usize; - let reply_len = dummy_data_len + HEADER_SIZE + AES_GCM_TAG_SIZE; - let header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter); - reply_buf[..HEADER_SIZE].copy_from_slice(&header); - - let mut c = key.get_send_cipher(counter)?; - c.init(&get_aes_gcm_nonce(&reply_buf)); - c.crypt_in_place(&mut reply_buf[HEADER_SIZE..(HEADER_SIZE + dummy_data_len)]); - reply_buf[(HEADER_SIZE + dummy_data_len)..reply_len].copy_from_slice(&c.finish()); - key.return_send_cipher(c); - - send(&mut reply_buf[..reply_len]); - - let _ = state.keys[0].replace(key); - let _ = state.keys[1].take(); - } + let _ = state.offer.take(); + let _ = state.keys[0].insert(key); return Ok(ReceiveResult::Ok); } @@ -783,6 +784,7 @@ impl ReceiveContext { struct Counter(AtomicU64); impl Counter { + #[inline(always)] fn new() -> Self { Self(AtomicU64::new(random::next_u32_secure() as u64)) } @@ -850,7 +852,7 @@ struct EphemeralOffer { impl EphemeralOffer { fn create_alice_offer( - send: SendFunction, + send: &mut SendFunction, counter: CounterValue, alice_session_id: SessionId, bob_session_id: Option, @@ -947,14 +949,14 @@ fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_typ debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits // Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter - let mut header = ((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u128; - header |= recipient_session_id.wrapping_shl(16) as u128; - header |= (counter.to_u32() as u128).wrapping_shl(64); - header.to_le_bytes()[..HEADER_SIZE].try_into().unwrap() + ((((((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u64) | recipient_session_id.wrapping_shl(16)) as u128) | (counter.to_u32() as u128).wrapping_shl(64)).to_le_bytes() + [..HEADER_SIZE] + .try_into() + .unwrap() } #[inline(always)] -fn send_with_fragmentation(mut send: SendFunction, packet: &mut [u8], mtu: usize, header: &mut [u8; HEADER_SIZE]) { +fn send_with_fragmentation(send: &mut SendFunction, packet: &mut [u8], mtu: usize, header: &mut [u8; HEADER_SIZE]) { let packet_len = packet.len(); let mut fragment_start = 0; let mut fragment_end = packet_len.min(mtu); @@ -1168,6 +1170,7 @@ mod tests { #[allow(unused_variables)] #[test] fn establish_session() { + let jedi = true; let mut psk: Secret<64> = Secret::default(); random::fill_bytes_secure(&mut psk.0); let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob")); @@ -1188,7 +1191,7 @@ mod tests { 1, 1280, 1, - true, + jedi, ) .unwrap(), )); @@ -1208,7 +1211,7 @@ mod tests { if let Some(qi) = host.queue.lock().pop_back() { let qi_len = qi.len(); ts += 1; - let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, true, ts); + let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, jedi, ts); if r.is_ok() { let r = r.unwrap(); match r {