fixed multithreading bug

This commit is contained in:
monica 2023-03-10 00:22:53 -05:00
parent 285aab8080
commit eb0425a28f
No known key found for this signature in database
GPG key ID: ADCCDBBE0E3D3B3B

View file

@ -639,46 +639,46 @@ impl<Application: ApplicationLayer> Context<Application> {
drop(c); drop(c);
if aead_authentication_ok { if aead_authentication_ok {
if session.update_receive_window(incoming_counter) { // Packet fully authenticated
// Update the current key to point to this key if it's newer, since having received if !session.update_receive_window(incoming_counter) {
// a packet encrypted with it proves that the other side has successfully derived it return Err(Error::OutOfSequence)
// as well. }
if state.current_key == key_index && key.confirmed { // Update the current key to point to this key if it's newer, since having received
drop(state); // a packet encrypted with it proves that the other side has successfully derived it
} else { // as well.
let current_key_created_at_counter = key.created_at_counter; if state.current_key == key_index && key.confirmed {
drop(state);
} else {
let current_key_created_at_counter = key.created_at_counter;
drop(state); drop(state);
let mut state = session.state.write().unwrap(); let mut state = session.state.write().unwrap();
if state.current_key != key_index { if state.current_key != key_index {
if let Some(other_session_key) = state.keys[state.current_key].as_ref() { if let Some(other_session_key) = state.keys[state.current_key].as_ref() {
if other_session_key.created_at_counter < current_key_created_at_counter { if other_session_key.created_at_counter < current_key_created_at_counter {
state.current_key = key_index;
}
} else {
state.current_key = key_index; state.current_key = key_index;
} }
} } else {
state.keys[key_index].as_mut().unwrap().confirmed = true; state.current_key = key_index;
// If we got a valid data packet from Bob, this means we can cancel any offers
// that are still oustanding for initialization.
match &state.current_offer {
Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => {
state.current_offer = Offer::None;
}
_ => {}
} }
} }
state.keys[key_index].as_mut().unwrap().confirmed = true;
if packet_type == PACKET_TYPE_DATA { // If we got a valid data packet from Bob, this means we can cancel any offers
return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len])); // that are still oustanding for initialization.
} else { match &state.current_offer {
return Ok(ReceiveResult::Ok(Some(session))); Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => {
state.current_offer = Offer::None;
}
_ => {}
} }
}
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len]));
} else { } else {
return Err(Error::OutOfSequence); return Ok(ReceiveResult::Ok(Some(session)));
} }
} }
} }
@ -895,15 +895,18 @@ impl<Application: ApplicationLayer> Context<Application> {
let hk = pqc_kyber::decapsulate(&pkt.bob_hk_ciphertext, outgoing_offer.alice_hk_secret.as_bytes()) let hk = pqc_kyber::decapsulate(&pkt.bob_hk_ciphertext, outgoing_offer.alice_hk_secret.as_bytes())
.map_err(|_| Error::FailedAuthentication) .map_err(|_| Error::FailedAuthentication)
.map(|k| Secret(k))?; .map(|k| Secret(k))?;
let noise_se = app.get_local_s_keypair().agree(&bob_noise_e).ok_or(Error::FailedAuthentication)?;
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
let noise_es_ee_se_hk_psk = hmac_sha512_secret::<BASE_KEY_SIZE>( let noise_es_ee_se_hk_psk = hmac_sha512_secret::<BASE_KEY_SIZE>(
hmac_sha512_secret::<BASE_KEY_SIZE>( hmac_sha512_secret::<BASE_KEY_SIZE>(
noise_es_ee.as_bytes(), noise_es_ee.as_bytes(),
app.get_local_s_keypair() noise_se.as_bytes(),
.agree(&bob_noise_e) ).as_bytes(),
.ok_or(Error::FailedAuthentication)?
.as_bytes(),
)
.as_bytes(),
hmac_sha512_secret::<BASE_KEY_SIZE>(session.psk.as_bytes(), hk.as_bytes()).as_bytes(), hmac_sha512_secret::<BASE_KEY_SIZE>(session.psk.as_bytes(), hk.as_bytes()).as_bytes(),
); );
@ -1066,7 +1069,7 @@ impl<Application: ApplicationLayer> Context<Application> {
application_data, application_data,
psk, psk,
send_counter: AtomicU64::new(2), // 1 was already used during negotiation send_counter: AtomicU64::new(2), // 1 was already used during negotiation
receive_window: std::array::from_fn(|_| AtomicU64::new(0)), receive_window: std::array::from_fn(|_| AtomicU64::new(incoming_counter)),
header_protection_cipher: Aes::new(&incoming.header_protection_key), header_protection_cipher: Aes::new(&incoming.header_protection_key),
state: RwLock::new(State { state: RwLock::new(State {
remote_session_id: Some(incoming.alice_session_id), remote_session_id: Some(incoming.alice_session_id),
@ -1124,6 +1127,11 @@ impl<Application: ApplicationLayer> Context<Application> {
drop(c); drop(c);
if aead_authentication_ok { if aead_authentication_ok {
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) { if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) {
let bob_e_secret = P384KeyPair::generate(); let bob_e_secret = P384KeyPair::generate();
@ -1211,6 +1219,11 @@ impl<Application: ApplicationLayer> Context<Application> {
drop(c); drop(c);
if aead_authentication_ok { if aead_authentication_ok {
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) { if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) {
let next_session_key = hmac_sha512_secret( let next_session_key = hmac_sha512_secret(
@ -1230,7 +1243,7 @@ impl<Application: ApplicationLayer> Context<Application> {
next_session_key, next_session_key,
next_ratchet_count, next_ratchet_count,
current_time, current_time,
session.send_counter.load(Ordering::Acquire), session.send_counter.load(Ordering::Relaxed),
true, true,
true, true,
)); ));
@ -1405,13 +1418,13 @@ impl<Application: ApplicationLayer> Session<Application> {
/// Get the next outgoing counter value. /// Get the next outgoing counter value.
#[inline(always)] #[inline(always)]
fn get_next_outgoing_counter(&self) -> Option<NonZeroU64> { fn get_next_outgoing_counter(&self) -> Option<NonZeroU64> {
NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::SeqCst)) NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::Relaxed))
} }
/// Check the receive window without mutating state. /// Check the receive window without mutating state.
#[inline(always)] #[inline(always)]
fn check_receive_window(&self, counter: u64) -> bool { fn check_receive_window(&self, counter: u64) -> bool {
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Acquire); let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Relaxed);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
} }
@ -1419,7 +1432,7 @@ impl<Application: ApplicationLayer> Session<Application> {
/// This should only be called after the packet is authenticated. /// This should only be called after the packet is authenticated.
#[inline(always)] #[inline(always)]
fn update_receive_window(&self, counter: u64) -> bool { fn update_receive_window(&self, counter: u64) -> bool {
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::AcqRel); let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::Relaxed);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
} }
} }