From eb0425a28fd701b3a474bd50fe73533eb179f5c0 Mon Sep 17 00:00:00 2001 From: monica Date: Fri, 10 Mar 2023 00:22:53 -0500 Subject: [PATCH] fixed multithreading bug --- zssp/src/zssp.rs | 97 +++++++++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 42 deletions(-) diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index ab6324f78..856c6bb9a 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -639,46 +639,46 @@ impl Context { drop(c); if aead_authentication_ok { - if session.update_receive_window(incoming_counter) { - // Update the current key to point to this key if it's newer, since having received - // a packet encrypted with it proves that the other side has successfully derived it - // as well. - if state.current_key == key_index && key.confirmed { - drop(state); - } else { - let current_key_created_at_counter = key.created_at_counter; + // Packet fully authenticated + if !session.update_receive_window(incoming_counter) { + return Err(Error::OutOfSequence) + } + // Update the current key to point to this key if it's newer, since having received + // a packet encrypted with it proves that the other side has successfully derived it + // as well. + if state.current_key == key_index && key.confirmed { + drop(state); + } else { + let current_key_created_at_counter = key.created_at_counter; - drop(state); - let mut state = session.state.write().unwrap(); + drop(state); + let mut state = session.state.write().unwrap(); - if state.current_key != key_index { - if let Some(other_session_key) = state.keys[state.current_key].as_ref() { - if other_session_key.created_at_counter < current_key_created_at_counter { - state.current_key = key_index; - } - } else { + if state.current_key != key_index { + if let Some(other_session_key) = state.keys[state.current_key].as_ref() { + if other_session_key.created_at_counter < current_key_created_at_counter { state.current_key = key_index; } - } - state.keys[key_index].as_mut().unwrap().confirmed = true; - - // If we got a valid data packet from Bob, this means we can cancel any offers - // that are still oustanding for initialization. - match &state.current_offer { - Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => { - state.current_offer = Offer::None; - } - _ => {} + } else { + state.current_key = key_index; } } + state.keys[key_index].as_mut().unwrap().confirmed = true; - if packet_type == PACKET_TYPE_DATA { - return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len])); - } else { - return Ok(ReceiveResult::Ok(Some(session))); + // If we got a valid data packet from Bob, this means we can cancel any offers + // that are still oustanding for initialization. + match &state.current_offer { + Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => { + state.current_offer = Offer::None; + } + _ => {} } + } + + if packet_type == PACKET_TYPE_DATA { + return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len])); } else { - return Err(Error::OutOfSequence); + return Ok(ReceiveResult::Ok(Some(session))); } } } @@ -895,15 +895,18 @@ impl Context { let hk = pqc_kyber::decapsulate(&pkt.bob_hk_ciphertext, outgoing_offer.alice_hk_secret.as_bytes()) .map_err(|_| Error::FailedAuthentication) .map(|k| Secret(k))?; + let noise_se = app.get_local_s_keypair().agree(&bob_noise_e).ok_or(Error::FailedAuthentication)?; + + // Packet fully authenticated + if !session.update_receive_window(incoming_counter) { + return Err(Error::OutOfSequence) + } + let noise_es_ee_se_hk_psk = hmac_sha512_secret::( hmac_sha512_secret::( noise_es_ee.as_bytes(), - app.get_local_s_keypair() - .agree(&bob_noise_e) - .ok_or(Error::FailedAuthentication)? - .as_bytes(), - ) - .as_bytes(), + noise_se.as_bytes(), + ).as_bytes(), hmac_sha512_secret::(session.psk.as_bytes(), hk.as_bytes()).as_bytes(), ); @@ -1066,7 +1069,7 @@ impl Context { application_data, psk, send_counter: AtomicU64::new(2), // 1 was already used during negotiation - receive_window: std::array::from_fn(|_| AtomicU64::new(0)), + receive_window: std::array::from_fn(|_| AtomicU64::new(incoming_counter)), header_protection_cipher: Aes::new(&incoming.header_protection_key), state: RwLock::new(State { remote_session_id: Some(incoming.alice_session_id), @@ -1124,6 +1127,11 @@ impl Context { drop(c); if aead_authentication_ok { + // Packet fully authenticated + if !session.update_receive_window(incoming_counter) { + return Err(Error::OutOfSequence) + } + let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) { let bob_e_secret = P384KeyPair::generate(); @@ -1211,6 +1219,11 @@ impl Context { drop(c); if aead_authentication_ok { + // Packet fully authenticated + if !session.update_receive_window(incoming_counter) { + return Err(Error::OutOfSequence) + } + let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) { let next_session_key = hmac_sha512_secret( @@ -1230,7 +1243,7 @@ impl Context { next_session_key, next_ratchet_count, current_time, - session.send_counter.load(Ordering::Acquire), + session.send_counter.load(Ordering::Relaxed), true, true, )); @@ -1405,13 +1418,13 @@ impl Session { /// Get the next outgoing counter value. #[inline(always)] fn get_next_outgoing_counter(&self) -> Option { - NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::SeqCst)) + NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::Relaxed)) } /// Check the receive window without mutating state. #[inline(always)] fn check_receive_window(&self, counter: u64) -> bool { - let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Acquire); + let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Relaxed); prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD } @@ -1419,7 +1432,7 @@ impl Session { /// This should only be called after the packet is authenticated. #[inline(always)] fn update_receive_window(&self, counter: u64) -> bool { - let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::AcqRel); + let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::Relaxed); prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD } }