From f83bf41427287f494ffefc72a0b363d2f4217694 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Fri, 10 Mar 2023 16:58:38 -0500 Subject: [PATCH] A bunch of ZSSP cleanup and optimization. Runs a bit faster now. --- zssp/src/sessionid.rs | 27 ++- zssp/src/zssp.rs | 547 +++++++++++++++++++++--------------------- 2 files changed, 292 insertions(+), 282 deletions(-) diff --git a/zssp/src/sessionid.rs b/zssp/src/sessionid.rs index be272034c..e1fb99b64 100644 --- a/zssp/src/sessionid.rs +++ b/zssp/src/sessionid.rs @@ -10,7 +10,6 @@ use std::fmt::Display; use std::num::NonZeroU64; use zerotier_crypto::random; -use zerotier_utils::memory::{array_range, as_byte_array}; /// 48-bit session ID (most significant 16 bits of u64 are unused) #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -25,6 +24,7 @@ impl SessionId { pub const MAX: u64 = 0xffffffffffff; /// Create a new session ID, panicing if 'i' is zero or exceeds MAX. + #[inline(always)] pub fn new(i: u64) -> SessionId { assert!(i <= Self::MAX); Self(NonZeroU64::new(i.to_le()).unwrap()) @@ -35,22 +35,23 @@ impl SessionId { Self(NonZeroU64::new(((random::xorshift64_random() % (Self::MAX - 1)) + 1).to_le()).unwrap()) } - pub(crate) fn new_from_bytes(b: &[u8; Self::SIZE]) -> Option { - let mut tmp = [0u8; 8]; + #[inline(always)] + pub fn to_bytes(&self) -> [u8; Self::SIZE] { + self.0.get().to_ne_bytes()[..Self::SIZE].try_into().unwrap() + } + + #[inline(always)] + pub fn new_from_bytes(b: &[u8]) -> Option { + let mut tmp = 0u64.to_ne_bytes(); tmp[..SESSION_ID_SIZE_BYTES].copy_from_slice(b); - Self::new_from_u64_le(u64::from_ne_bytes(tmp)) + NonZeroU64::new(u64::from_ne_bytes(tmp)).map(|i| Self(i)) } - /// Create from a u64 that is already in little-endian byte order. #[inline(always)] - pub(crate) fn new_from_u64_le(i: u64) -> Option { - NonZeroU64::new(i & Self::MAX.to_le()).map(|i| Self(i)) - } - - /// Get this session ID as a little-endian byte array. - #[inline(always)] - pub(crate) fn as_bytes(&self) -> &[u8; Self::SIZE] { - array_range::(as_byte_array(&self.0)) + pub fn new_from_array(b: &[u8; Self::SIZE]) -> Option { + let mut tmp = 0u64.to_ne_bytes(); + tmp[..SESSION_ID_SIZE_BYTES].copy_from_slice(b); + NonZeroU64::new(u64::from_ne_bytes(tmp)).map(|i| Self(i)) } } diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 0d091a64b..ee7842e35 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -56,7 +56,7 @@ struct SessionsById { active: HashMap>>, // Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout. - incoming: HashMap>, + incoming: HashMap>>, } /// Result generated by the context packet receive function, with possible payloads. @@ -97,10 +97,10 @@ struct State { remote_session_id: Option, keys: [Option; 2], current_key: usize, - current_offer: Offer, + outgoing_offer: Offer, } -struct IncomingIncompleteSession { +struct IncomingIncompleteSession { timestamp: i64, alice_session_id: SessionId, bob_session_id: SessionId, @@ -109,6 +109,7 @@ struct IncomingIncompleteSession { hk: Secret, header_protection_key: Secret, bob_noise_e_secret: P384KeyPair, + defrag: [Mutex>; MAX_NOISE_HANDSHAKE_FRAGMENTS], } struct OutgoingSessionOffer { @@ -184,7 +185,7 @@ impl Context { for (id, s) in sessions.active.iter() { if let Some(session) = s.upgrade() { let state = session.state.read().unwrap(); - if match &state.current_offer { + if match &state.outgoing_offer { Offer::None => true, Offer::NoiseXKInit(offer) => { // If there's an outstanding attempt to open a session, retransmit this periodically @@ -324,7 +325,7 @@ impl Context { remote_session_id: None, keys: [None, None], current_key: 0, - current_offer: Offer::NoiseXKInit(Box::new(OutgoingSessionOffer { + outgoing_offer: Offer::NoiseXKInit(Box::new(OutgoingSessionOffer { last_retry_time: AtomicI64::new(current_time), psk, noise_h: mix_hash(&mix_hash(&INITIAL_H, remote_s_public_blob), &alice_noise_e), @@ -345,7 +346,7 @@ impl Context { { let mut state = session.state.write().unwrap(); - let offer = if let Offer::NoiseXKInit(offer) = &mut state.current_offer { + let offer = if let Offer::NoiseXKInit(offer) = &mut state.outgoing_offer { offer } else { panic!(); // should be impossible as this is what we initialized with @@ -357,7 +358,7 @@ impl Context { let init: &mut AliceNoiseXKInit = byte_array_as_proto_buffer_mut(init_packet).unwrap(); init.session_protocol_version = SESSION_PROTOCOL_VERSION; init.alice_noise_e = alice_noise_e; - init.alice_session_id = *local_session_id.as_bytes(); + init.alice_session_id = local_session_id.to_bytes(); init.alice_hk_public = alice_hk_secret.public; init.header_protection_key = header_protection_key.0; } @@ -417,6 +418,7 @@ impl Context { /// * `data_buf` - Buffer to receive decrypted and authenticated object data (an error is returned if too small) /// * `incoming_packet_buf` - Buffer containing incoming wire packet (receive() takes ownership) /// * `current_time` - Current monotonic time in milliseconds + #[inline] pub fn receive< 'b, SendFunction: FnMut(Option<&Arc>>, &mut [u8]), @@ -430,112 +432,83 @@ impl Context { mut send: SendFunction, source: &Application::PhysicalPath, data_buf: &'b mut [u8], - mut incoming_packet_buf: Application::IncomingPacketBuffer, + mut incoming_physical_packet_buf: Application::IncomingPacketBuffer, current_time: i64, ) -> Result, Error> { - let incoming_packet: &mut [u8] = incoming_packet_buf.as_mut(); - if incoming_packet.len() < MIN_PACKET_SIZE { + let incoming_physical_packet: &mut [u8] = incoming_physical_packet_buf.as_mut(); + if incoming_physical_packet.len() < MIN_PACKET_SIZE { return Err(Error::InvalidPacket); } - let mut incoming = None; - if let Some(local_session_id) = SessionId::new_from_u64_le(u64::from_le_bytes(incoming_packet[0..8].try_into().unwrap())) { - if let Some(session) = self.sessions.read().unwrap().active.get(&local_session_id).and_then(|s| s.upgrade()) { + if let Some(local_session_id) = SessionId::new_from_bytes(&incoming_physical_packet[0..SessionId::SIZE]) { + let sessions = self.sessions.read().unwrap(); + if let Some(session) = sessions.active.get(&local_session_id).and_then(|s| s.upgrade()) { + drop(sessions); debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id)); session .header_protection_cipher - .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet); + .decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); if session.check_receive_window(incoming_counter) { - if fragment_count > 1 { - let mut fragged = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO].lock().unwrap(); - if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) { - drop(fragged); - return self.process_complete_incoming_packet( - app, - &mut send, - &mut check_allow_incoming_session, - &mut check_accept_session, - data_buf, - incoming_counter, - assembled_packet.as_ref(), - packet_type, - Some(session), - None, - key_index, - current_time, - ); + let (assembled_packet, incoming_packet_buf_arr); + let incoming_packet = if fragment_count > 1 { + assembled_packet = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO] + .lock() + .unwrap() + .assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + if let Some(assembled_packet) = assembled_packet.as_ref() { + assembled_packet.as_ref() } else { - drop(fragged); return Ok(ReceiveResult::Ok(Some(session))); } } else { - return self.process_complete_incoming_packet( - app, - &mut send, - &mut check_allow_incoming_session, - &mut check_accept_session, - data_buf, - incoming_counter, - &[incoming_packet_buf], - packet_type, - Some(session), - None, - key_index, - current_time, - ); - } + incoming_packet_buf_arr = [incoming_physical_packet_buf]; + &incoming_packet_buf_arr + }; + + return self.process_complete_incoming_packet( + app, + &mut send, + &mut check_allow_incoming_session, + &mut check_accept_session, + data_buf, + incoming_counter, + incoming_packet, + packet_type, + Some(session), + None, + key_index, + current_time, + ); } else { return Err(Error::OutOfSequence); } - } else { - if let Some(i) = self.sessions.read().unwrap().incoming.get(&local_session_id).cloned() { - Aes::new(&i.header_protection_key) - .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - incoming = Some(i); - } else { - return Err(Error::UnknownLocalSessionId); - } - } - } + } else if let Some(incoming) = sessions.incoming.get(&local_session_id).cloned() { + drop(sessions); + debug_assert!(!self.sessions.read().unwrap().active.contains_key(&local_session_id)); - // If we make it here the packet is not associated with a session or is associated with an - // incoming session (Noise_XK mid-negotiation). + Aes::new(&incoming.header_protection_key) + .decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); - let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet); - if fragment_count > 1 { - let f = { - let mut defrag = self.defrag.lock().unwrap(); - let f = defrag - .entry((source.clone(), incoming_counter)) - .or_insert_with(|| Arc::new((Mutex::new(Fragged::new()), current_time))) - .clone(); - - // Anti-DOS overflow purge of the incoming defragmentation queue for packets not associated with known sessions. - if defrag.len() >= self.max_incomplete_session_queue_size { - // First, drop all entries that are timed out or whose physical source duplicates another entry. - let mut sources = HashSet::with_capacity(defrag.len()); - let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; - defrag.retain(|k, fragged| (fragged.1 > negotiation_timeout_cutoff && sources.insert(k.0.clone())) || Arc::ptr_eq(fragged, &f)); - - // Then, if we are still at or over the limit, drop 10% of remaining entries at random. - if defrag.len() >= self.max_incomplete_session_queue_size { - let mut rn = random::next_u32_secure(); - defrag.retain(|_, fragged| { - rn = prng32(rn); - rn > (u32::MAX / 10) || Arc::ptr_eq(fragged, &f) - }); + let (assembled_packet, incoming_packet_buf_arr); + let incoming_packet = if fragment_count > 1 { + assembled_packet = incoming.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO] + .lock() + .unwrap() + .assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + if let Some(assembled_packet) = assembled_packet.as_ref() { + assembled_packet.as_ref() + } else { + return Ok(ReceiveResult::Ok(None)); } - } + } else { + incoming_packet_buf_arr = [incoming_physical_packet_buf]; + &incoming_packet_buf_arr + }; - f - }; - let mut fragged = f.0.lock().unwrap(); - - if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) { - self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter)); return self.process_complete_incoming_packet( app, &mut send, @@ -543,15 +516,63 @@ impl Context { &mut check_accept_session, data_buf, incoming_counter, - assembled_packet.as_ref(), + incoming_packet, packet_type, None, - incoming, + Some(incoming), key_index, current_time, ); + } else { + return Err(Error::UnknownLocalSessionId); } } else { + let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); + + let (assembled_packet, incoming_packet_buf_arr); + let incoming_packet = if fragment_count > 1 { + assembled_packet = { + let mut defrag = self.defrag.lock().unwrap(); + let f = defrag + .entry((source.clone(), incoming_counter)) + .or_insert_with(|| Arc::new((Mutex::new(Fragged::new()), current_time))) + .clone(); + + // Anti-DOS overflow purge of the incoming defragmentation queue for packets not associated with known sessions. + if defrag.len() >= self.max_incomplete_session_queue_size { + // First, drop all entries that are timed out or whose physical source duplicates another entry. + let mut sources = HashSet::with_capacity(defrag.len()); + let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; + defrag + .retain(|k, fragged| (fragged.1 > negotiation_timeout_cutoff && sources.insert(k.0.clone())) || Arc::ptr_eq(fragged, &f)); + + // Then, if we are still at or over the limit, drop 10% of remaining entries at random. + if defrag.len() >= self.max_incomplete_session_queue_size { + let mut rn = random::next_u32_secure(); + defrag.retain(|_, fragged| { + rn = prng32(rn); + rn > (u32::MAX / 10) || Arc::ptr_eq(fragged, &f) + }); + } + } + + f + } + .0 + .lock() + .unwrap() + .assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + if let Some(assembled_packet) = assembled_packet.as_ref() { + self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter)); + assembled_packet.as_ref() + } else { + return Ok(ReceiveResult::Ok(None)); + } + } else { + incoming_packet_buf_arr = [incoming_physical_packet_buf]; + &incoming_packet_buf_arr + }; + return self.process_complete_incoming_packet( app, &mut send, @@ -559,16 +580,14 @@ impl Context { &mut check_accept_session, data_buf, incoming_counter, - &[incoming_packet_buf], + incoming_packet, packet_type, None, - incoming, + None, key_index, current_time, ); } - - return Ok(ReceiveResult::Ok(None)); } fn process_complete_incoming_packet< @@ -587,7 +606,7 @@ impl Context { fragments: &[Application::IncomingPacketBuffer], packet_type: u8, session: Option>>, - incoming: Option>, + incoming: Option>>, key_index: usize, current_time: i64, ) -> Result, Error> { @@ -651,9 +670,9 @@ impl Context { // If we got a valid data packet from Bob, this means we can cancel any offers // that are still oustanding for initialization. - match &state.current_offer { + match &state.outgoing_offer { Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => { - state.current_offer = Offer::None; + state.outgoing_offer = Offer::None; } _ => {} } @@ -730,7 +749,7 @@ impl Context { } let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?; - let alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?; + let alice_session_id = SessionId::new_from_array(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?; let header_protection_key = Secret(pkt.header_protection_key); // Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create @@ -761,7 +780,7 @@ impl Context { let ack: &mut BobNoiseXKAck = byte_array_as_proto_buffer_mut(&mut ack_packet)?; ack.session_protocol_version = SESSION_PROTOCOL_VERSION; ack.bob_noise_e = bob_noise_e; - ack.bob_session_id = *bob_session_id.as_bytes(); + ack.bob_session_id = bob_session_id.to_bytes(); ack.bob_hk_ciphertext = bob_hk_ciphertext; // Encrypt main section of reply and attach tag. @@ -802,6 +821,7 @@ impl Context { hk, bob_noise_e_secret, header_protection_key: Secret(pkt.header_protection_key), + defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), }), ); debug_assert!(!sessions.active.contains_key(&bob_session_id)); @@ -847,7 +867,7 @@ impl Context { return Err(Error::OutOfSequence); } - if let Offer::NoiseXKInit(outgoing_offer) = &state.current_offer { + if let Offer::NoiseXKInit(outgoing_offer) = &state.outgoing_offer { let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?; // Derive noise_es_ee from Bob's ephemeral public key. @@ -875,7 +895,7 @@ impl Context { let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?; - if let Some(bob_session_id) = SessionId::new_from_bytes(&pkt.bob_session_id) { + if let Some(bob_session_id) = SessionId::new_from_array(&pkt.bob_session_id) { // Complete Noise_XKpsk3 by mixing in noise_se followed by the PSK. The PSK as far as // the Noise pattern is concerned is the result of mixing the externally supplied PSK // with the Kyber1024 shared secret (hk). Kyber is treated as part of the PSK because @@ -948,7 +968,7 @@ impl Context { )); debug_assert!(state.keys[1].is_none()); state.current_key = 0; - state.current_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck { + state.outgoing_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck { last_retry_time: AtomicI64::new(current_time), ack, ack_len, @@ -1071,7 +1091,7 @@ impl Context { None, ], current_key: 0, - current_offer: Offer::None, + outgoing_offer: Offer::None, }), defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), }); @@ -1108,81 +1128,74 @@ impl Context { if let Some(session) = session { let state = session.state.read().unwrap(); - if let Some(remote_session_id) = state.remote_session_id { - if let Some(key) = state.keys[key_index].as_ref() { - // Only the current "Alice" accepts rekeys initiated by the current "Bob." These roles - // flip with each rekey event. - if !key.my_turn_to_rekey { - let mut c = key.get_receive_cipher(incoming_counter); - c.reset_init_gcm(&incoming_message_nonce); - c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]); - let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]); - drop(c); + if let (Some(remote_session_id), Some(key)) = (state.remote_session_id, state.keys[key_index].as_ref()) { + if !key.my_turn_to_rekey && { + let mut c = key.get_receive_cipher(incoming_counter); + c.reset_init_gcm(&incoming_message_nonce); + c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]); + c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]) + } { + let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); + if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) { + let bob_e_secret = P384KeyPair::generate(); + let next_session_key = hmac_sha512_secret( + key.ratchet_key.as_bytes(), + bob_e_secret.agree(&alice_e).ok_or(Error::FailedAuthentication)?.as_bytes(), + ); - if aead_authentication_ok { - let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); - if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) { - let bob_e_secret = P384KeyPair::generate(); - let next_session_key = hmac_sha512_secret( - key.ratchet_key.as_bytes(), - bob_e_secret.agree(&alice_e).ok_or(Error::FailedAuthentication)?.as_bytes(), - ); + // Packet fully authenticated + if session.update_receive_window(incoming_counter) { + let mut reply_buf = [0u8; RekeyAck::SIZE]; + let reply: &mut RekeyAck = byte_array_as_proto_buffer_mut(&mut reply_buf).unwrap(); + reply.session_protocol_version = SESSION_PROTOCOL_VERSION; + reply.bob_e = *bob_e_secret.public_key_bytes(); + reply.next_key_fingerprint = SHA384::hash(next_session_key.as_bytes()); - // Packet fully authenticated - if session.update_receive_window(incoming_counter) { - let mut reply_buf = [0u8; RekeyAck::SIZE]; - let reply: &mut RekeyAck = byte_array_as_proto_buffer_mut(&mut reply_buf).unwrap(); - reply.session_protocol_version = SESSION_PROTOCOL_VERSION; - reply.bob_e = *bob_e_secret.public_key_bytes(); - reply.next_key_fingerprint = SHA384::hash(next_session_key.as_bytes()); + let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + set_packet_header( + &mut reply_buf, + 1, + 0, + PACKET_TYPE_REKEY_ACK, + u64::from(remote_session_id), + state.current_key, + counter, + ); - let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); - set_packet_header( - &mut reply_buf, - 1, - 0, - PACKET_TYPE_REKEY_ACK, - u64::from(remote_session_id), - state.current_key, - counter, - ); + let mut c = key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_ACK, counter)); + c.crypt_in_place(&mut reply_buf[RekeyAck::ENC_START..RekeyAck::AUTH_START]); + reply_buf[RekeyAck::AUTH_START..].copy_from_slice(&c.finish_encrypt()); + drop(c); - let mut c = key.get_send_cipher(counter)?; - c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_ACK, counter)); - c.crypt_in_place(&mut reply_buf[RekeyAck::ENC_START..RekeyAck::AUTH_START]); - reply_buf[RekeyAck::AUTH_START..].copy_from_slice(&c.finish_encrypt()); - drop(c); + session + .header_protection_cipher + .encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(Some(&session), &mut reply_buf); - session - .header_protection_cipher - .encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - send(Some(&session), &mut reply_buf); + // The new "Bob" doesn't know yet if Alice has received the new key, so the + // new key is recorded as the "alt" (key_index ^ 1) but the current key is + // not advanced yet. This happens automatically the first time we receive a + // valid packet with the new key. + let next_ratchet_count = key.ratchet_count + 1; + drop(state); + let mut state = session.state.write().unwrap(); + let _ = state.keys[key_index ^ 1].replace(SessionKey::new::( + next_session_key, + next_ratchet_count, + current_time, + counter, + false, + false, + )); - // The new "Bob" doesn't know yet if Alice has received the new key, so the - // new key is recorded as the "alt" (key_index ^ 1) but the current key is - // not advanced yet. This happens automatically the first time we receive a - // valid packet with the new key. - let next_ratchet_count = key.ratchet_count + 1; - drop(state); - let mut state = session.state.write().unwrap(); - let _ = state.keys[key_index ^ 1].replace(SessionKey::new::( - next_session_key, - next_ratchet_count, - current_time, - counter, - false, - false, - )); - - drop(state); - return Ok(ReceiveResult::Ok(Some(session))); - } else { - return Err(Error::OutOfSequence); - } - } + drop(state); + return Ok(ReceiveResult::Ok(Some(session))); + } else { + return Err(Error::OutOfSequence); } - return Err(Error::FailedAuthentication); } + return Err(Error::FailedAuthentication); } } return Err(Error::OutOfSequence); @@ -1201,59 +1214,52 @@ impl Context { if let Some(session) = session { let state = session.state.read().unwrap(); - if let Offer::RekeyInit(alice_e_secret, _) = &state.current_offer { - if let Some(key) = state.keys[key_index].as_ref() { - // Only the current "Bob" initiates rekeys and expects this ACK. - if key.my_turn_to_rekey { - let mut c = key.get_receive_cipher(incoming_counter); - c.reset_init_gcm(&incoming_message_nonce); - c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]); - let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]); - drop(c); + if let (Offer::RekeyInit(alice_e_secret, _), Some(key)) = (&state.outgoing_offer, state.keys[key_index].as_ref()) { + if key.my_turn_to_rekey && { + let mut c = key.get_receive_cipher(incoming_counter); + c.reset_init_gcm(&incoming_message_nonce); + c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]); + c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]) + } { + let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); + if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) { + let next_session_key = hmac_sha512_secret( + key.ratchet_key.as_bytes(), + alice_e_secret.agree(&bob_e).ok_or(Error::FailedAuthentication)?.as_bytes(), + ); - if aead_authentication_ok { - // Packet fully authenticated + if secure_eq(&pkt.next_key_fingerprint, &SHA384::hash(next_session_key.as_bytes())) { + if session.update_receive_window(incoming_counter) { + // The new "Alice" knows Bob has the key since this is an ACK, so she can go + // ahead and set current_key to the new key. Then when she sends something + // to Bob the other side will automatically advance to the new key as well. + let next_ratchet_count = key.ratchet_count + 1; + drop(state); + let next_key_index = key_index ^ 1; + let mut state = session.state.write().unwrap(); + let _ = state.keys[next_key_index].replace(SessionKey::new::( + next_session_key, + next_ratchet_count, + current_time, + session.send_counter.load(Ordering::Relaxed), + true, + true, + )); + state.current_key = next_key_index; // this is an ACK so it's confirmed + state.outgoing_offer = Offer::None; - let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); - if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) { - let next_session_key = hmac_sha512_secret( - key.ratchet_key.as_bytes(), - alice_e_secret.agree(&bob_e).ok_or(Error::FailedAuthentication)?.as_bytes(), - ); - - if secure_eq(&pkt.next_key_fingerprint, &SHA384::hash(next_session_key.as_bytes())) { - if session.update_receive_window(incoming_counter) { - // The new "Alice" knows Bob has the key since this is an ACK, so she can go - // ahead and set current_key to the new key. Then when she sends something - // to Bob the other side will automatically advance to the new key as well. - let next_ratchet_count = key.ratchet_count + 1; - drop(state); - let next_key_index = key_index ^ 1; - let mut state = session.state.write().unwrap(); - let _ = state.keys[next_key_index].replace(SessionKey::new::( - next_session_key, - next_ratchet_count, - current_time, - session.send_counter.load(Ordering::Relaxed), - true, - true, - )); - state.current_key = next_key_index; // this is an ACK so it's confirmed - state.current_offer = Offer::None; - - drop(state); - return Ok(ReceiveResult::Ok(Some(session))); - } else { - return Err(Error::OutOfSequence); - } - } + drop(state); + return Ok(ReceiveResult::Ok(Some(session))); + } else { + return Err(Error::OutOfSequence); } } - return Err(Error::FailedAuthentication); } } + return Err(Error::FailedAuthentication); + } else { + return Err(Error::OutOfSequence); } - return Err(Error::OutOfSequence); } else { return Err(Error::UnknownLocalSessionId); } @@ -1277,51 +1283,49 @@ impl Session { pub fn send(&self, mut send: SendFunction, mtu_sized_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> { debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); let state = self.state.read().unwrap(); - if let Some(remote_session_id) = state.remote_session_id { - if let Some(session_key) = state.keys[state.current_key].as_ref() { - let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) { + let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); - let mut c = session_key.get_send_cipher(counter)?; - c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter)); + let mut c = session_key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter)); - let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize; - let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE; - let last_fragment_no = fragment_count - 1; + let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize; + let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE; + let last_fragment_no = fragment_count - 1; - for fragment_no in 0..fragment_count { - let chunk_size = fragment_max_chunk_size.min(data.len()); - let mut fragment_size = chunk_size + HEADER_SIZE; + for fragment_no in 0..fragment_count { + let chunk_size = fragment_max_chunk_size.min(data.len()); + let mut fragment_size = chunk_size + HEADER_SIZE; - set_packet_header( - mtu_sized_buffer, - fragment_count, - fragment_no, - PACKET_TYPE_DATA, - u64::from(remote_session_id), - state.current_key, - counter, - ); + set_packet_header( + mtu_sized_buffer, + fragment_count, + fragment_no, + PACKET_TYPE_DATA, + u64::from(remote_session_id), + state.current_key, + counter, + ); - c.crypt(&data[..chunk_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]); - data = &data[chunk_size..]; + c.crypt(&data[..chunk_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]); + data = &data[chunk_size..]; - if fragment_no == last_fragment_no { - debug_assert!(data.is_empty()); - let tagged_fragment_size = fragment_size + AES_GCM_TAG_SIZE; - mtu_sized_buffer[fragment_size..tagged_fragment_size].copy_from_slice(&c.finish_encrypt()); - fragment_size = tagged_fragment_size; - } - - self.header_protection_cipher - .encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - send(&mut mtu_sized_buffer[..fragment_size]); + if fragment_no == last_fragment_no { + debug_assert!(data.is_empty()); + let tagged_fragment_size = fragment_size + AES_GCM_TAG_SIZE; + mtu_sized_buffer[fragment_size..tagged_fragment_size].copy_from_slice(&c.finish_encrypt()); + fragment_size = tagged_fragment_size; } - debug_assert!(data.is_empty()); - drop(c); - - return Ok(()); + self.header_protection_cipher + .encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(&mut mtu_sized_buffer[..fragment_size]); } + debug_assert!(data.is_empty()); + + drop(c); + + return Ok(()); } return Err(Error::SessionNotEstablished); } @@ -1329,23 +1333,26 @@ impl Session { /// Send a NOP to the other side (e.g. for keep alive). pub fn send_nop(&self, mut send: SendFunction) -> Result<(), Error> { let state = self.state.read().unwrap(); - if let Some(remote_session_id) = state.remote_session_id { - if let Some(session_key) = state.keys[state.current_key].as_ref() { - let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); - let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; - let mut c = session_key.get_send_cipher(counter)?; - c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter)); - nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt()); - drop(c); - set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter); - self.header_protection_cipher - .encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - send(&mut nop); - } + if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) { + let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; + let mut c = session_key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter)); + nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt()); + drop(c); + set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter); + self.header_protection_cipher + .encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(&mut nop); } return Err(Error::SessionNotEstablished); } + /// Set the current physical MTU that this session should use to send packets. + pub fn set_physical_mtu(&self, mtu: usize) { + self.state.write().unwrap().physical_mtu = mtu; + } + /// Check whether this session is established. pub fn established(&self) -> bool { let state = self.state.read().unwrap(); @@ -1403,7 +1410,7 @@ impl Session { send(&mut rekey_buf); drop(state); - self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, current_time); + self.state.write().unwrap().outgoing_offer = Offer::RekeyInit(rekey_e, current_time); } } } @@ -1578,6 +1585,7 @@ impl SessionKey { } } + #[inline(always)] fn get_send_cipher<'a>(&'a self, counter: u64) -> Result>, Error> { if counter < self.expire_at_counter { Ok(self.send_cipher_pool[(counter as usize) % GCM_CIPHER_POOL_SIZE].lock().unwrap()) @@ -1586,6 +1594,7 @@ impl SessionKey { } } + #[inline(always)] fn get_receive_cipher<'a>(&'a self, counter: u64) -> MutexGuard<'a, AesGcm> { self.receive_cipher_pool[(counter as usize) % GCM_CIPHER_POOL_SIZE].lock().unwrap() }