From 5d72aabe170b02d154d283e6805f8b37f4f797a8 Mon Sep 17 00:00:00 2001 From: mamoniot Date: Tue, 27 Dec 2022 21:09:01 -0500 Subject: [PATCH] got all tests to pass --- zssp/src/counter.rs | 8 +- zssp/src/tests.rs | 2 +- zssp/src/zssp.rs | 175 ++++++++++++++++++++++++++------------------ 3 files changed, 109 insertions(+), 76 deletions(-) diff --git a/zssp/src/counter.rs b/zssp/src/counter.rs index 2f7fa0749..8b5c9f789 100644 --- a/zssp/src/counter.rs +++ b/zssp/src/counter.rs @@ -17,9 +17,10 @@ impl Counter { // helps randomize packet contents a bit. Self(AtomicU32::new(1u32)) } + #[inline(always)] - pub fn reset_after_initial_offer(&self) { - self.0.store(2u32, Ordering::SeqCst); + pub fn reset_for_initial_offer(&self) { + self.0.store(1u32, Ordering::SeqCst); } /// Get the value most recently used to send a packet. @@ -63,7 +64,8 @@ impl CounterWindow { pub fn new_invalid() -> Self { Self(std::array::from_fn(|_| AtomicU32::new(u32::MAX))) } - pub fn reset_after_initial_offer(&self) { + pub fn reset_for_initial_offer(&self) { + let o = true; for i in 0..COUNTER_MAX_ALLOWED_OOO { self.0[i].store(0, Ordering::SeqCst) } diff --git a/zssp/src/tests.rs b/zssp/src/tests.rs index b15a5e3f4..747b86dfb 100644 --- a/zssp/src/tests.rs +++ b/zssp/src/tests.rs @@ -254,7 +254,7 @@ mod tests { w.message_received(c); continue; } else { - w.reset_after_initial_offer(); + w.reset_for_initial_offer(); counter = 1u32; history = Vec::new(); continue; diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 9fd678801..4112b9eb4 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -112,12 +112,12 @@ pub struct Session { /// An arbitrary application defined object associated with each session pub application_data: Application::Data, - send_counter: Counter, // Outgoing packet counter and nonce state - receive_window: [CounterWindow; 2], // Receive window for anti-replay and deduplication + header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) + receive_windows: [CounterWindow; 2], // Receive window for anti-replay and deduplication + send_counters: [Counter; 2], // Outgoing packet counter and nonce state + state: RwLock, // Mutable parts of state (other than defrag buffers) psk: Secret<64>, // Arbitrary PSK provided by external code noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key - header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) - state: RwLock, // Mutable parts of state (other than defrag buffers) remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob) remote_s_public_p384_bytes: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key @@ -127,7 +127,7 @@ pub struct Session { struct SessionMutableState { remote_session_id: Option, // The other side's 48-bit session ID session_keys: [Option; 2], // Buffers to store current, next, and last active key - cur_session_key_id: bool, // Pointer used for keys[] circular buffer + cur_session_key_id: bool, // Pointer used for keys[] circular buffer offer: Option, // Most recent ephemeral offer sent to remote last_remote_offer: i64, // Time of most recent ephemeral offer (ms) } @@ -236,6 +236,7 @@ impl Session { &mut send, send_counter.next(), false, + false, local_session_id, None, app.get_local_s_public_blob(), @@ -254,11 +255,9 @@ impl Session { return Ok(Self { id: local_session_id, application_data, - send_counter, - receive_window: [CounterWindow::new(), CounterWindow::new_invalid()], - psk: psk.clone(), - noise_ss, header_check_cipher, + receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], + send_counters: [send_counter, Counter::new()], state: RwLock::new(SessionMutableState { remote_session_id: None, session_keys: [None, None], @@ -266,6 +265,8 @@ impl Session { offer, last_remote_offer: i64::MIN, }), + psk: psk.clone(), + noise_ss, remote_s_public_blob_hash: bob_s_public_blob_hash, remote_s_public_p384_bytes: bob_s_public.as_bytes().clone(), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), @@ -296,7 +297,7 @@ impl Session { let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; // This outgoing packet's nonce counter value. - let counter = self.send_counter.next(); + let counter = self.send_counters[state.cur_session_key_id as usize].next(); //////////////////////////////////////////////////////////////// // packet encoding for post-noise transport @@ -403,10 +404,11 @@ impl Session { force_rekey: bool, ) { let state = self.state.read().unwrap(); + let current_key_id = state.cur_session_key_id; if (force_rekey || state.session_keys[state.cur_session_key_id as usize] .as_ref() - .map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time))) + .map_or(true, |key| key.lifetime.should_rekey(self.send_counters[current_key_id as usize].current(), current_time))) && state .offer .as_ref() @@ -416,8 +418,9 @@ impl Session { let mut offer = None; if send_ephemeral_offer( &mut send, - CounterValue::get_initial_offer_counter(), - !state.cur_session_key_id, + self.send_counters[current_key_id as usize].next(), + current_key_id, + !current_key_id, self.id, state.remote_session_id, app.get_local_s_public_blob(), @@ -425,7 +428,7 @@ impl Session { &remote_s_public, &self.remote_s_public_blob_hash, &self.noise_ss, - state.session_keys[state.cur_session_key_id as usize].as_ref(), + state.session_keys[current_key_id as usize].as_ref(), if state.remote_session_id.is_some() { Some(&self.header_check_cipher) } else { @@ -492,12 +495,13 @@ impl ReceiveContext { { if let Some(session) = app.lookup_session(local_session_id) { if verify_header_check_code(incoming_packet, &session.header_check_cipher) { - if session.receive_window[key_id as usize].message_received(counter) { + if session.receive_windows[key_id as usize].message_received(counter) { let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); if fragment_count > 1 { if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { let mut defrag = session.defrag.lock().unwrap(); - let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); + // by using the counter + the key_id as the key we can prevent packet collisions, this only works if defrag hashes + let fragment_gather_array = defrag.get_or_create_mut(&raw_counter, || GatherArray::new(fragment_count)); if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { drop(defrag); // release lock return self.receive_complete( @@ -672,7 +676,7 @@ impl ReceiveContext { session_key.return_receive_cipher(c); if aead_authentication_ok { - if session.receive_window[key_id as usize].message_authenticated(counter) { + if session.receive_windows[key_id as usize].message_authenticated(counter) { if packet_type == PACKET_TYPE_DATA { return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); } else { @@ -832,7 +836,7 @@ impl ReceiveContext { // Perform checks and match ratchet key if there's an existing session, or gate (via host) and // then create new sessions. - let (new_session, ratchet_key, last_ratchet_count) = if let Some(session) = session.as_ref() { + let (new_session, reply_counter, new_key_id, ratchet_key, last_ratchet_count) = if let Some(session) = session.as_ref() { // Existing session identity must match the one in this offer. if !secure_eq(&session.remote_s_public_blob_hash, &SHA384::hash(&alice_s_public_blob)) { return Err(Error::FailedAuthentication); @@ -843,10 +847,12 @@ impl ReceiveContext { let mut ratchet_key = None; let mut last_ratchet_count = 0; let state = session.state.read().unwrap(); - if state.cur_session_key_id == key_id { - return Ok(ReceiveResult::Ignored); // alice is requesting to overwrite the current key, reject it + //key_id here is the key id of the key being rekeyed and replaced + //it must be equal to the current session key, and not the previous session key + if state.cur_session_key_id != key_id { + return Ok(ReceiveResult::Ignored); } - if let Some(k) = state.session_keys[state.cur_session_key_id as usize].as_ref() { + if let Some(k) = state.session_keys[key_id as usize].as_ref() { if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { ratchet_key = Some(k.ratchet_key.clone()); last_ratchet_count = k.ratchet_count; @@ -856,34 +862,41 @@ impl ReceiveContext { return Ok(ReceiveResult::Ignored); // old packet? } - (None, ratchet_key, last_ratchet_count) + (None, session.send_counters[key_id as usize].next(), !key_id, ratchet_key, last_ratchet_count) } else { + if key_id != false { + return Ok(ReceiveResult::Ignored); + } if let Some((new_session_id, psk, associated_object)) = app.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata) { let header_check_cipher = Aes::new( kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::(), ); + let send_counter = Counter::new(); + let reply_counter = send_counter.next(); ( Some(Session:: { id: new_session_id, application_data: associated_object, - send_counter: Counter::new(), - receive_window: [CounterWindow::new_invalid(), CounterWindow::new_invalid()], - psk, - noise_ss, header_check_cipher, + receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], + send_counters: [send_counter, Counter::new()], state: RwLock::new(SessionMutableState { remote_session_id: Some(alice_session_id), - session_keys: [None, None], - cur_session_key_id: key_id, + session_keys: [None, None],//this is the only value which will be writen later + cur_session_key_id: false, offer: None, last_remote_offer: current_time, }), + psk, + noise_ss, remote_s_public_blob_hash: SHA384::hash(&alice_s_public_blob), remote_s_public_p384_bytes: alice_s_public.as_bytes().clone(), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), }), + reply_counter, + false, None, 0, ) @@ -946,7 +959,6 @@ impl ReceiveContext { //////////////////////////////////////////////////////////////// let mut reply_buf = [0_u8; KEX_BUF_LEN]; - let reply_counter = CounterValue::get_initial_offer_counter(); let mut idx = HEADER_SIZE; idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?; @@ -1016,33 +1028,40 @@ impl ReceiveContext { ); idx = safe_write_all(&mut reply_buf, idx, &hmac)?; let packet_end = idx; + if session.receive_windows[key_id as usize].message_authenticated(counter) { + //initial key offers should only check this if this is a rekey + let session_key = SessionKey::new( + session_key, + Role::Bob, + current_time, + last_ratchet_count + 1, + hybrid_kk.is_some(), + ); - let session_key = SessionKey::new( - session_key, - Role::Bob, - current_time, - last_ratchet_count + 1, - hybrid_kk.is_some(), - ); + //TODO: check for correct orderings + let mut state = session.state.write().unwrap(); + let _ = state.session_keys[new_key_id as usize].replace(session_key); + if existing_session.is_some() { + let _ = state.remote_session_id.replace(alice_session_id); + state.cur_session_key_id = new_key_id; + session.send_counters[new_key_id as usize].reset_for_initial_offer(); + session.receive_windows[new_key_id as usize].reset_for_initial_offer(); + } + drop(state); - let mut state = session.state.write().unwrap(); - let _ = state.remote_session_id.replace(alice_session_id); - let _ = state.session_keys[key_id as usize].replace(session_key); - session.send_counter.reset_after_initial_offer(); - state.cur_session_key_id = key_id; - session.receive_window[key_id as usize].reset_after_initial_offer(); - session.receive_window[key_id as usize].message_authenticated(counter); - drop(state); + // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. - // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. + send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher); - send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher); - - if let Some(new_session) = new_session { - return Ok(ReceiveResult::OkNewSession(new_session)); + if let Some(new_session) = new_session { + return Ok(ReceiveResult::OkNewSession(new_session)); + } else { + return Ok(ReceiveResult::Ok); + } } else { - return Ok(ReceiveResult::Ok); + return Ok(ReceiveResult::Ignored); } + } PACKET_TYPE_KEY_COUNTER_OFFER => { @@ -1096,8 +1115,9 @@ impl ReceiveContext { parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?; // Check that this is a counter offer to the original offer we sent. - if !(offer.id.eq(offer_id) & (offer.key_id == key_id)) { + if !(offer.id.eq(offer_id)) { return Ok(ReceiveResult::Ignored); + } // Kyber1024 key agreement if enabled. @@ -1138,27 +1158,36 @@ impl ReceiveContext { ) { return Err(Error::FailedAuthentication); } + if session.receive_windows[key_id as usize].message_authenticated(counter) { + // Alice has now completed and validated the full hybrid exchange. - // Alice has now completed and validated the full hybrid exchange. + let session_key = SessionKey::new( + session_key, + Role::Alice, + current_time, + last_ratchet_count + 1, + hybrid_kk.is_some(), + ); - let session_key = SessionKey::new( - session_key, - Role::Alice, - current_time, - last_ratchet_count + 1, - hybrid_kk.is_some(), - ); + let new_key_id = offer.key_id; + let is_new_session = offer.ratchet_count == 0; + drop(state); + //TODO: check for correct orderings + let mut state = session.state.write().unwrap(); + let _ = state.remote_session_id.replace(bob_session_id); + let _ = state.session_keys[new_key_id as usize].replace(session_key); + if !is_new_session { + //when an brand new key offer is sent, it is sent using the new_key_id==false counter, we cannot reset it in that case. + state.cur_session_key_id = new_key_id; + session.send_counters[new_key_id as usize].reset_for_initial_offer(); + session.receive_windows[new_key_id as usize].reset_for_initial_offer(); + } + let _ = state.offer.take(); - drop(state); - let mut state = session.state.write().unwrap(); - let _ = state.remote_session_id.replace(bob_session_id); - let _ = state.session_keys[key_id as usize].replace(session_key); - session.send_counter.reset_after_initial_offer(); - state.cur_session_key_id = key_id; - session.receive_window[key_id as usize].message_authenticated(counter); - let _ = state.offer.take(); - - return Ok(ReceiveResult::Ok); + return Ok(ReceiveResult::Ok); + } else { + return Ok(ReceiveResult::Ignored); + } } } @@ -1174,10 +1203,12 @@ impl ReceiveContext { } /// Create an send an ephemeral offer, populating ret_ephemeral_offer on success. +/// If there is no current session key set `current_key_id == new_key_id == false` fn send_ephemeral_offer( send: &mut SendFunction, counter: CounterValue, - key_id: bool, + current_key_id: bool, + new_key_id: bool, alice_session_id: SessionId, bob_session_id: Option, alice_s_public_blob: &[u8], @@ -1263,7 +1294,7 @@ fn send_ephemeral_offer( PACKET_TYPE_INITIAL_KEY_OFFER, bob_session_id, counter, - key_id, + current_key_id, )?; let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32()); @@ -1317,7 +1348,7 @@ fn send_ephemeral_offer( *ret_ephemeral_offer = Some(EphemeralOffer { id, - key_id, + key_id: new_key_id, creation_time: current_time, ratchet_count, ratchet_key,