got all tests to pass

This commit is contained in:
mamoniot 2022-12-27 21:09:01 -05:00
parent 8d1efcdffa
commit 5d72aabe17
3 changed files with 109 additions and 76 deletions

View file

@ -17,9 +17,10 @@ impl Counter {
// helps randomize packet contents a bit.
Self(AtomicU32::new(1u32))
}
#[inline(always)]
pub fn reset_after_initial_offer(&self) {
self.0.store(2u32, Ordering::SeqCst);
pub fn reset_for_initial_offer(&self) {
self.0.store(1u32, Ordering::SeqCst);
}
/// Get the value most recently used to send a packet.
@ -63,7 +64,8 @@ impl CounterWindow {
pub fn new_invalid() -> Self {
Self(std::array::from_fn(|_| AtomicU32::new(u32::MAX)))
}
pub fn reset_after_initial_offer(&self) {
pub fn reset_for_initial_offer(&self) {
let o = true;
for i in 0..COUNTER_MAX_ALLOWED_OOO {
self.0[i].store(0, Ordering::SeqCst)
}

View file

@ -254,7 +254,7 @@ mod tests {
w.message_received(c);
continue;
} else {
w.reset_after_initial_offer();
w.reset_for_initial_offer();
counter = 1u32;
history = Vec::new();
continue;

View file

@ -112,12 +112,12 @@ pub struct Session<Application: ApplicationLayer> {
/// An arbitrary application defined object associated with each session
pub application_data: Application::Data,
send_counter: Counter, // Outgoing packet counter and nonce state
receive_window: [CounterWindow; 2], // Receive window for anti-replay and deduplication
header_check_cipher: Aes, // Cipher used for header check codes (not Noise related)
receive_windows: [CounterWindow; 2], // Receive window for anti-replay and deduplication
send_counters: [Counter; 2], // Outgoing packet counter and nonce state
state: RwLock<SessionMutableState>, // Mutable parts of state (other than defrag buffers)
psk: Secret<64>, // Arbitrary PSK provided by external code
noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key
header_check_cipher: Aes, // Cipher used for header check codes (not Noise related)
state: RwLock<SessionMutableState>, // Mutable parts of state (other than defrag buffers)
remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob)
remote_s_public_p384_bytes: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
@ -236,6 +236,7 @@ impl<Application: ApplicationLayer> Session<Application> {
&mut send,
send_counter.next(),
false,
false,
local_session_id,
None,
app.get_local_s_public_blob(),
@ -254,11 +255,9 @@ impl<Application: ApplicationLayer> Session<Application> {
return Ok(Self {
id: local_session_id,
application_data,
send_counter,
receive_window: [CounterWindow::new(), CounterWindow::new_invalid()],
psk: psk.clone(),
noise_ss,
header_check_cipher,
receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()],
send_counters: [send_counter, Counter::new()],
state: RwLock::new(SessionMutableState {
remote_session_id: None,
session_keys: [None, None],
@ -266,6 +265,8 @@ impl<Application: ApplicationLayer> Session<Application> {
offer,
last_remote_offer: i64::MIN,
}),
psk: psk.clone(),
noise_ss,
remote_s_public_blob_hash: bob_s_public_blob_hash,
remote_s_public_p384_bytes: bob_s_public.as_bytes().clone(),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
@ -296,7 +297,7 @@ impl<Application: ApplicationLayer> Session<Application> {
let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
// This outgoing packet's nonce counter value.
let counter = self.send_counter.next();
let counter = self.send_counters[state.cur_session_key_id as usize].next();
////////////////////////////////////////////////////////////////
// packet encoding for post-noise transport
@ -403,10 +404,11 @@ impl<Application: ApplicationLayer> Session<Application> {
force_rekey: bool,
) {
let state = self.state.read().unwrap();
let current_key_id = state.cur_session_key_id;
if (force_rekey
|| state.session_keys[state.cur_session_key_id as usize]
.as_ref()
.map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time)))
.map_or(true, |key| key.lifetime.should_rekey(self.send_counters[current_key_id as usize].current(), current_time)))
&& state
.offer
.as_ref()
@ -416,8 +418,9 @@ impl<Application: ApplicationLayer> Session<Application> {
let mut offer = None;
if send_ephemeral_offer(
&mut send,
CounterValue::get_initial_offer_counter(),
!state.cur_session_key_id,
self.send_counters[current_key_id as usize].next(),
current_key_id,
!current_key_id,
self.id,
state.remote_session_id,
app.get_local_s_public_blob(),
@ -425,7 +428,7 @@ impl<Application: ApplicationLayer> Session<Application> {
&remote_s_public,
&self.remote_s_public_blob_hash,
&self.noise_ss,
state.session_keys[state.cur_session_key_id as usize].as_ref(),
state.session_keys[current_key_id as usize].as_ref(),
if state.remote_session_id.is_some() {
Some(&self.header_check_cipher)
} else {
@ -492,12 +495,13 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
{
if let Some(session) = app.lookup_session(local_session_id) {
if verify_header_check_code(incoming_packet, &session.header_check_cipher) {
if session.receive_window[key_id as usize].message_received(counter) {
if session.receive_windows[key_id as usize].message_received(counter) {
let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter);
if fragment_count > 1 {
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
let mut defrag = session.defrag.lock().unwrap();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
// by using the counter + the key_id as the key we can prevent packet collisions, this only works if defrag hashes
let fragment_gather_array = defrag.get_or_create_mut(&raw_counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock
return self.receive_complete(
@ -672,7 +676,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
session_key.return_receive_cipher(c);
if aead_authentication_ok {
if session.receive_window[key_id as usize].message_authenticated(counter) {
if session.receive_windows[key_id as usize].message_authenticated(counter) {
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else {
@ -832,7 +836,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
// Perform checks and match ratchet key if there's an existing session, or gate (via host) and
// then create new sessions.
let (new_session, ratchet_key, last_ratchet_count) = if let Some(session) = session.as_ref() {
let (new_session, reply_counter, new_key_id, ratchet_key, last_ratchet_count) = if let Some(session) = session.as_ref() {
// Existing session identity must match the one in this offer.
if !secure_eq(&session.remote_s_public_blob_hash, &SHA384::hash(&alice_s_public_blob)) {
return Err(Error::FailedAuthentication);
@ -843,10 +847,12 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
let mut ratchet_key = None;
let mut last_ratchet_count = 0;
let state = session.state.read().unwrap();
if state.cur_session_key_id == key_id {
return Ok(ReceiveResult::Ignored); // alice is requesting to overwrite the current key, reject it
//key_id here is the key id of the key being rekeyed and replaced
//it must be equal to the current session key, and not the previous session key
if state.cur_session_key_id != key_id {
return Ok(ReceiveResult::Ignored);
}
if let Some(k) = state.session_keys[state.cur_session_key_id as usize].as_ref() {
if let Some(k) = state.session_keys[key_id as usize].as_ref() {
if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) {
ratchet_key = Some(k.ratchet_key.clone());
last_ratchet_count = k.ratchet_count;
@ -856,34 +862,41 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
return Ok(ReceiveResult::Ignored); // old packet?
}
(None, ratchet_key, last_ratchet_count)
(None, session.send_counters[key_id as usize].next(), !key_id, ratchet_key, last_ratchet_count)
} else {
if key_id != false {
return Ok(ReceiveResult::Ignored);
}
if let Some((new_session_id, psk, associated_object)) =
app.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata)
{
let header_check_cipher = Aes::new(
kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>(),
);
let send_counter = Counter::new();
let reply_counter = send_counter.next();
(
Some(Session::<Application> {
id: new_session_id,
application_data: associated_object,
send_counter: Counter::new(),
receive_window: [CounterWindow::new_invalid(), CounterWindow::new_invalid()],
psk,
noise_ss,
header_check_cipher,
receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()],
send_counters: [send_counter, Counter::new()],
state: RwLock::new(SessionMutableState {
remote_session_id: Some(alice_session_id),
session_keys: [None, None],
cur_session_key_id: key_id,
session_keys: [None, None],//this is the only value which will be writen later
cur_session_key_id: false,
offer: None,
last_remote_offer: current_time,
}),
psk,
noise_ss,
remote_s_public_blob_hash: SHA384::hash(&alice_s_public_blob),
remote_s_public_p384_bytes: alice_s_public.as_bytes().clone(),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
}),
reply_counter,
false,
None,
0,
)
@ -946,7 +959,6 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
////////////////////////////////////////////////////////////////
let mut reply_buf = [0_u8; KEX_BUF_LEN];
let reply_counter = CounterValue::get_initial_offer_counter();
let mut idx = HEADER_SIZE;
idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?;
@ -1016,7 +1028,8 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
);
idx = safe_write_all(&mut reply_buf, idx, &hmac)?;
let packet_end = idx;
if session.receive_windows[key_id as usize].message_authenticated(counter) {
//initial key offers should only check this if this is a rekey
let session_key = SessionKey::new(
session_key,
Role::Bob,
@ -1025,13 +1038,15 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
hybrid_kk.is_some(),
);
//TODO: check for correct orderings
let mut state = session.state.write().unwrap();
let _ = state.session_keys[new_key_id as usize].replace(session_key);
if existing_session.is_some() {
let _ = state.remote_session_id.replace(alice_session_id);
let _ = state.session_keys[key_id as usize].replace(session_key);
session.send_counter.reset_after_initial_offer();
state.cur_session_key_id = key_id;
session.receive_window[key_id as usize].reset_after_initial_offer();
session.receive_window[key_id as usize].message_authenticated(counter);
state.cur_session_key_id = new_key_id;
session.send_counters[new_key_id as usize].reset_for_initial_offer();
session.receive_windows[new_key_id as usize].reset_for_initial_offer();
}
drop(state);
// Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it.
@ -1043,6 +1058,10 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
} else {
return Ok(ReceiveResult::Ok);
}
} else {
return Ok(ReceiveResult::Ignored);
}
}
PACKET_TYPE_KEY_COUNTER_OFFER => {
@ -1096,8 +1115,9 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?;
// Check that this is a counter offer to the original offer we sent.
if !(offer.id.eq(offer_id) & (offer.key_id == key_id)) {
if !(offer.id.eq(offer_id)) {
return Ok(ReceiveResult::Ignored);
}
// Kyber1024 key agreement if enabled.
@ -1138,7 +1158,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
) {
return Err(Error::FailedAuthentication);
}
if session.receive_windows[key_id as usize].message_authenticated(counter) {
// Alice has now completed and validated the full hybrid exchange.
let session_key = SessionKey::new(
@ -1149,16 +1169,25 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
hybrid_kk.is_some(),
);
let new_key_id = offer.key_id;
let is_new_session = offer.ratchet_count == 0;
drop(state);
//TODO: check for correct orderings
let mut state = session.state.write().unwrap();
let _ = state.remote_session_id.replace(bob_session_id);
let _ = state.session_keys[key_id as usize].replace(session_key);
session.send_counter.reset_after_initial_offer();
state.cur_session_key_id = key_id;
session.receive_window[key_id as usize].message_authenticated(counter);
let _ = state.session_keys[new_key_id as usize].replace(session_key);
if !is_new_session {
//when an brand new key offer is sent, it is sent using the new_key_id==false counter, we cannot reset it in that case.
state.cur_session_key_id = new_key_id;
session.send_counters[new_key_id as usize].reset_for_initial_offer();
session.receive_windows[new_key_id as usize].reset_for_initial_offer();
}
let _ = state.offer.take();
return Ok(ReceiveResult::Ok);
} else {
return Ok(ReceiveResult::Ignored);
}
}
}
@ -1174,10 +1203,12 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
}
/// Create an send an ephemeral offer, populating ret_ephemeral_offer on success.
/// If there is no current session key set `current_key_id == new_key_id == false`
fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
send: &mut SendFunction,
counter: CounterValue,
key_id: bool,
current_key_id: bool,
new_key_id: bool,
alice_session_id: SessionId,
bob_session_id: Option<SessionId>,
alice_s_public_blob: &[u8],
@ -1263,7 +1294,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
PACKET_TYPE_INITIAL_KEY_OFFER,
bob_session_id,
counter,
key_id,
current_key_id,
)?;
let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32());
@ -1317,7 +1348,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
*ret_ephemeral_offer = Some(EphemeralOffer {
id,
key_id,
key_id: new_key_id,
creation_time: current_time,
ratchet_count,
ratchet_key,