completed audit for threadsafety

This commit is contained in:
mamoniot 2022-12-27 22:00:49 -05:00
parent bf3591f593
commit d3d7cc1a3c
2 changed files with 21 additions and 17 deletions

View file

@ -3,10 +3,8 @@ use std::sync::atomic::{Ordering, AtomicU32};
use crate::constants::COUNTER_MAX_ALLOWED_OOO; use crate::constants::COUNTER_MAX_ALLOWED_OOO;
/// Outgoing packet counter with strictly ordered atomic semantics. /// Outgoing packet counter with strictly ordered atomic semantics.
/// Count sequence always starts at 1u32, it must never be allowed to overflow
/// ///
/// The counter used in packets is actually 32 bits, but using a 64-bit integer internally
/// lets us more safely implement key lifetime limits without confusing logic to handle 32-bit
/// wrap-around.
#[repr(transparent)] #[repr(transparent)]
pub(crate) struct Counter(AtomicU32); pub(crate) struct Counter(AtomicU32);

View file

@ -81,7 +81,8 @@ pub enum ReceiveResult<'a, H: ApplicationLayer> {
/// The session will have already been gated by the accept_new_session() method in the Host trait. /// The session will have already been gated by the accept_new_session() method in the Host trait.
OkNewSession(Session<H>), OkNewSession(Session<H>),
/// Packet appears valid but was ignored e.g. as a duplicate. /// Packet superficially appears valid but was ignored e.g. as a duplicate.
/// IMPORTANT: Authentication was not completed on this packet, so for the most part treat this the same as an Error::FailedAuthentication
Ignored, Ignored,
} }
@ -114,7 +115,6 @@ pub struct Session<Application: ApplicationLayer> {
header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) header_check_cipher: Aes, // Cipher used for header check codes (not Noise related)
receive_windows: [CounterWindow; 2], // Receive window for anti-replay and deduplication receive_windows: [CounterWindow; 2], // Receive window for anti-replay and deduplication
send_counters: [Counter; 2], // Outgoing packet counter and nonce state
state: RwLock<SessionMutableState>, // Mutable parts of state (other than defrag buffers) state: RwLock<SessionMutableState>, // Mutable parts of state (other than defrag buffers)
psk: Secret<64>, // Arbitrary PSK provided by external code psk: Secret<64>, // Arbitrary PSK provided by external code
noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key
@ -126,6 +126,7 @@ pub struct Session<Application: ApplicationLayer> {
struct SessionMutableState { struct SessionMutableState {
remote_session_id: Option<SessionId>, // The other side's 48-bit session ID remote_session_id: Option<SessionId>, // The other side's 48-bit session ID
send_counters: [Counter; 2], // Outgoing packet counter and nonce state
session_keys: [Option<SessionKey>; 2], // Buffers to store current, next, and last active key session_keys: [Option<SessionKey>; 2], // Buffers to store current, next, and last active key
cur_session_key_id: bool, // Pointer used for keys[] circular buffer cur_session_key_id: bool, // Pointer used for keys[] circular buffer
offer: Option<EphemeralOffer>, // Most recent ephemeral offer sent to remote offer: Option<EphemeralOffer>, // Most recent ephemeral offer sent to remote
@ -257,8 +258,8 @@ impl<Application: ApplicationLayer> Session<Application> {
application_data, application_data,
header_check_cipher, header_check_cipher,
receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()],
send_counters: [send_counter, Counter::new()],
state: RwLock::new(SessionMutableState { state: RwLock::new(SessionMutableState {
send_counters: [send_counter, Counter::new()],
remote_session_id: None, remote_session_id: None,
session_keys: [None, None], session_keys: [None, None],
cur_session_key_id: false, cur_session_key_id: false,
@ -297,7 +298,7 @@ impl<Application: ApplicationLayer> Session<Application> {
let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
// This outgoing packet's nonce counter value. // This outgoing packet's nonce counter value.
let counter = self.send_counters[state.cur_session_key_id as usize].next(); let counter = state.send_counters[state.cur_session_key_id as usize].next();
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
// packet encoding for post-noise transport // packet encoding for post-noise transport
@ -408,7 +409,7 @@ impl<Application: ApplicationLayer> Session<Application> {
if (force_rekey if (force_rekey
|| state.session_keys[state.cur_session_key_id as usize] || state.session_keys[state.cur_session_key_id as usize]
.as_ref() .as_ref()
.map_or(true, |key| key.lifetime.should_rekey(self.send_counters[current_key_id as usize].current(), current_time))) .map_or(true, |key| key.lifetime.should_rekey(state.send_counters[current_key_id as usize].current(), current_time)))
&& state && state
.offer .offer
.as_ref() .as_ref()
@ -418,7 +419,7 @@ impl<Application: ApplicationLayer> Session<Application> {
let mut offer = None; let mut offer = None;
if send_ephemeral_offer( if send_ephemeral_offer(
&mut send, &mut send,
self.send_counters[current_key_id as usize].next(), state.send_counters[current_key_id as usize].next(),
current_key_id, current_key_id,
!current_key_id, !current_key_id,
self.id, self.id,
@ -494,8 +495,8 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64) if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64)
{ {
if let Some(session) = app.lookup_session(local_session_id) { if let Some(session) = app.lookup_session(local_session_id) {
if verify_header_check_code(incoming_packet, &session.header_check_cipher) { if session.receive_windows[key_id as usize].message_received(counter) {
if session.receive_windows[key_id as usize].message_received(counter) { if verify_header_check_code(incoming_packet, &session.header_check_cipher) {
let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter);
if fragment_count > 1 { if fragment_count > 1 {
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
@ -541,11 +542,11 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
} }
} else { } else {
unlikely_branch(); unlikely_branch();
return Ok(ReceiveResult::Ignored); return Err(Error::FailedAuthentication);
} }
} else { } else {
unlikely_branch(); unlikely_branch();
return Err(Error::FailedAuthentication); return Ok(ReceiveResult::Ignored);
} }
} else { } else {
unlikely_branch(); unlikely_branch();
@ -862,7 +863,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
return Ok(ReceiveResult::Ignored); // old packet? return Ok(ReceiveResult::Ignored); // old packet?
} }
(None, session.send_counters[key_id as usize].next(), !key_id, ratchet_key, last_ratchet_count) (None, state.send_counters[state.cur_session_key_id as usize].next(), !key_id, ratchet_key, last_ratchet_count)
} else { } else {
if key_id != false { if key_id != false {
return Ok(ReceiveResult::Ignored); return Ok(ReceiveResult::Ignored);
@ -881,8 +882,8 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
application_data: associated_object, application_data: associated_object,
header_check_cipher, header_check_cipher,
receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()],
send_counters: [send_counter, Counter::new()],
state: RwLock::new(SessionMutableState { state: RwLock::new(SessionMutableState {
send_counters: [send_counter, Counter::new()],
remote_session_id: Some(alice_session_id), remote_session_id: Some(alice_session_id),
session_keys: [None, None],//this is the only value which will be writen later session_keys: [None, None],//this is the only value which will be writen later
cur_session_key_id: false, cur_session_key_id: false,
@ -1043,8 +1044,12 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
let _ = state.session_keys[new_key_id as usize].replace(session_key); let _ = state.session_keys[new_key_id as usize].replace(session_key);
if existing_session.is_some() { if existing_session.is_some() {
let _ = state.remote_session_id.replace(alice_session_id); let _ = state.remote_session_id.replace(alice_session_id);
//if this wasn't done inside a lock, a theoretical race condition exists where a thread uses the new key id before the counter is reset, or worse: a thread has held onto the previous key_id == new_key_id, and attempts to use the reset counter
//for this reason do not access send_counters without holding the read lock
state.cur_session_key_id = new_key_id; state.cur_session_key_id = new_key_id;
session.send_counters[new_key_id as usize].reset_for_initial_offer(); state.send_counters[new_key_id as usize].reset_for_initial_offer();
//receive_windows only has race conditions with the send counter of the remote party. It is theoretically possible that the local host receives counters under new_key_id while the receive_window is still in the process of resetting, but this is both very unlikely and can only ever cause dropped packets under the current implementation of receive_window.
//if receive_window is ever reimplemented, double check it maintains the above property.
session.receive_windows[new_key_id as usize].reset_for_initial_offer(); session.receive_windows[new_key_id as usize].reset_for_initial_offer();
} }
drop(state); drop(state);
@ -1178,8 +1183,9 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
let _ = state.session_keys[new_key_id as usize].replace(session_key); let _ = state.session_keys[new_key_id as usize].replace(session_key);
if !is_new_session { if !is_new_session {
//when an brand new key offer is sent, it is sent using the new_key_id==false counter, we cannot reset it in that case. //when an brand new key offer is sent, it is sent using the new_key_id==false counter, we cannot reset it in that case.
//NOTE: the following code should be properly threadsafe, see the large comment above at the end of KEY_OFFER decoding for more info
state.cur_session_key_id = new_key_id; state.cur_session_key_id = new_key_id;
session.send_counters[new_key_id as usize].reset_for_initial_offer(); state.send_counters[new_key_id as usize].reset_for_initial_offer();
session.receive_windows[new_key_id as usize].reset_for_initial_offer(); session.receive_windows[new_key_id as usize].reset_for_initial_offer();
} }
let _ = state.offer.take(); let _ = state.offer.take();