fixed multithreading bug

This commit is contained in:
monica 2023-03-10 00:22:53 -05:00
parent 285aab8080
commit eb0425a28f
No known key found for this signature in database
GPG key ID: ADCCDBBE0E3D3B3B

View file

@ -639,7 +639,10 @@ impl<Application: ApplicationLayer> Context<Application> {
drop(c);
if aead_authentication_ok {
if session.update_receive_window(incoming_counter) {
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
// Update the current key to point to this key if it's newer, since having received
// a packet encrypted with it proves that the other side has successfully derived it
// as well.
@ -677,9 +680,6 @@ impl<Application: ApplicationLayer> Context<Application> {
} else {
return Ok(ReceiveResult::Ok(Some(session)));
}
} else {
return Err(Error::OutOfSequence);
}
}
}
@ -895,15 +895,18 @@ impl<Application: ApplicationLayer> Context<Application> {
let hk = pqc_kyber::decapsulate(&pkt.bob_hk_ciphertext, outgoing_offer.alice_hk_secret.as_bytes())
.map_err(|_| Error::FailedAuthentication)
.map(|k| Secret(k))?;
let noise_se = app.get_local_s_keypair().agree(&bob_noise_e).ok_or(Error::FailedAuthentication)?;
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
let noise_es_ee_se_hk_psk = hmac_sha512_secret::<BASE_KEY_SIZE>(
hmac_sha512_secret::<BASE_KEY_SIZE>(
noise_es_ee.as_bytes(),
app.get_local_s_keypair()
.agree(&bob_noise_e)
.ok_or(Error::FailedAuthentication)?
.as_bytes(),
)
.as_bytes(),
noise_se.as_bytes(),
).as_bytes(),
hmac_sha512_secret::<BASE_KEY_SIZE>(session.psk.as_bytes(), hk.as_bytes()).as_bytes(),
);
@ -1066,7 +1069,7 @@ impl<Application: ApplicationLayer> Context<Application> {
application_data,
psk,
send_counter: AtomicU64::new(2), // 1 was already used during negotiation
receive_window: std::array::from_fn(|_| AtomicU64::new(0)),
receive_window: std::array::from_fn(|_| AtomicU64::new(incoming_counter)),
header_protection_cipher: Aes::new(&incoming.header_protection_key),
state: RwLock::new(State {
remote_session_id: Some(incoming.alice_session_id),
@ -1124,6 +1127,11 @@ impl<Application: ApplicationLayer> Context<Application> {
drop(c);
if aead_authentication_ok {
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) {
let bob_e_secret = P384KeyPair::generate();
@ -1211,6 +1219,11 @@ impl<Application: ApplicationLayer> Context<Application> {
drop(c);
if aead_authentication_ok {
// Packet fully authenticated
if !session.update_receive_window(incoming_counter) {
return Err(Error::OutOfSequence)
}
let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) {
let next_session_key = hmac_sha512_secret(
@ -1230,7 +1243,7 @@ impl<Application: ApplicationLayer> Context<Application> {
next_session_key,
next_ratchet_count,
current_time,
session.send_counter.load(Ordering::Acquire),
session.send_counter.load(Ordering::Relaxed),
true,
true,
));
@ -1405,13 +1418,13 @@ impl<Application: ApplicationLayer> Session<Application> {
/// Get the next outgoing counter value.
#[inline(always)]
fn get_next_outgoing_counter(&self) -> Option<NonZeroU64> {
NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::SeqCst))
NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::Relaxed))
}
/// Check the receive window without mutating state.
#[inline(always)]
fn check_receive_window(&self, counter: u64) -> bool {
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Acquire);
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Relaxed);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
}
@ -1419,7 +1432,7 @@ impl<Application: ApplicationLayer> Session<Application> {
/// This should only be called after the packet is authenticated.
#[inline(always)]
fn update_receive_window(&self, counter: u64) -> bool {
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::AcqRel);
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::Relaxed);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
}
}