fixed multithreading bug

This commit is contained in:
monica 2023-03-10 00:22:53 -05:00
parent 285aab8080
commit eb0425a28f
No known key found for this signature in database
GPG key ID: ADCCDBBE0E3D3B3B

View file

@ -639,46 +639,46 @@ impl<Application: ApplicationLayer> Context<Application> {
drop(c);
if aead_authentication_ok {
if session.update_receive_window(incoming_counter) {
// Update the current key to point to this key if it's newer, since having received
// a packet encrypted with it proves that the other side has successfully derived it
// as well.
if state.current_key == key_index && key.confirmed {
drop(state);
} else {
let current_key_created_at_counter = key.created_at_counter;
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
// Update the current key to point to this key if it's newer, since having received
// a packet encrypted with it proves that the other side has successfully derived it
// as well.
if state.current_key == key_index && key.confirmed {
drop(state);
} else {
let current_key_created_at_counter = key.created_at_counter;
drop(state);
let mut state = session.state.write().unwrap();
drop(state);
let mut state = session.state.write().unwrap();
if state.current_key != key_index {
if let Some(other_session_key) = state.keys[state.current_key].as_ref() {
if other_session_key.created_at_counter < current_key_created_at_counter {
state.current_key = key_index;
}
} else {
if state.current_key != key_index {
if let Some(other_session_key) = state.keys[state.current_key].as_ref() {
if other_session_key.created_at_counter < current_key_created_at_counter {
state.current_key = key_index;
}
}
state.keys[key_index].as_mut().unwrap().confirmed = true;
// If we got a valid data packet from Bob, this means we can cancel any offers
// that are still oustanding for initialization.
match &state.current_offer {
Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => {
state.current_offer = Offer::None;
}
_ => {}
} else {
state.current_key = key_index;
}
}
state.keys[key_index].as_mut().unwrap().confirmed = true;
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len]));
} else {
return Ok(ReceiveResult::Ok(Some(session)));
// If we got a valid data packet from Bob, this means we can cancel any offers
// that are still oustanding for initialization.
match &state.current_offer {
Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => {
state.current_offer = Offer::None;
}
_ => {}
}
}
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len]));
} else {
return Err(Error::OutOfSequence);
return Ok(ReceiveResult::Ok(Some(session)));
}
}
}
@ -895,15 +895,18 @@ impl<Application: ApplicationLayer> Context<Application> {
let hk = pqc_kyber::decapsulate(&pkt.bob_hk_ciphertext, outgoing_offer.alice_hk_secret.as_bytes())
.map_err(|_| Error::FailedAuthentication)
.map(|k| Secret(k))?;
let noise_se = app.get_local_s_keypair().agree(&bob_noise_e).ok_or(Error::FailedAuthentication)?;
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
let noise_es_ee_se_hk_psk = hmac_sha512_secret::<BASE_KEY_SIZE>(
hmac_sha512_secret::<BASE_KEY_SIZE>(
noise_es_ee.as_bytes(),
app.get_local_s_keypair()
.agree(&bob_noise_e)
.ok_or(Error::FailedAuthentication)?
.as_bytes(),
)
.as_bytes(),
noise_se.as_bytes(),
).as_bytes(),
hmac_sha512_secret::<BASE_KEY_SIZE>(session.psk.as_bytes(), hk.as_bytes()).as_bytes(),
);
@ -1066,7 +1069,7 @@ impl<Application: ApplicationLayer> Context<Application> {
application_data,
psk,
send_counter: AtomicU64::new(2), // 1 was already used during negotiation
receive_window: std::array::from_fn(|_| AtomicU64::new(0)),
receive_window: std::array::from_fn(|_| AtomicU64::new(incoming_counter)),
header_protection_cipher: Aes::new(&incoming.header_protection_key),
state: RwLock::new(State {
remote_session_id: Some(incoming.alice_session_id),
@ -1124,6 +1127,11 @@ impl<Application: ApplicationLayer> Context<Application> {
drop(c);
if aead_authentication_ok {
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) {
let bob_e_secret = P384KeyPair::generate();
@ -1211,6 +1219,11 @@ impl<Application: ApplicationLayer> Context<Application> {
drop(c);
if aead_authentication_ok {
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) {
let next_session_key = hmac_sha512_secret(
@ -1230,7 +1243,7 @@ impl<Application: ApplicationLayer> Context<Application> {
next_session_key,
next_ratchet_count,
current_time,
session.send_counter.load(Ordering::Acquire),
session.send_counter.load(Ordering::Relaxed),
true,
true,
));
@ -1405,13 +1418,13 @@ impl<Application: ApplicationLayer> Session<Application> {
/// Get the next outgoing counter value.
#[inline(always)]
fn get_next_outgoing_counter(&self) -> Option<NonZeroU64> {
NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::SeqCst))
NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::Relaxed))
}
/// Check the receive window without mutating state.
#[inline(always)]
fn check_receive_window(&self, counter: u64) -> bool {
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Acquire);
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Relaxed);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
}
@ -1419,7 +1432,7 @@ impl<Application: ApplicationLayer> Session<Application> {
/// This should only be called after the packet is authenticated.
#[inline(always)]
fn update_receive_window(&self, counter: u64) -> bool {
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::AcqRel);
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::Relaxed);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
}
}