From 40945cf6c92bc0cb07aee25d716b3c4255481a6c Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Thu, 2 Mar 2023 19:09:31 -0500 Subject: [PATCH] Rework defragmentation, and it now tolerates very poor link quality pretty well. --- utils/src/gatherarray.rs | 6 +- zssp/src/applicationlayer.rs | 8 + zssp/src/fragged.rs | 103 ++++++++ zssp/src/lib.rs | 1 + zssp/src/main.rs | 164 +++++++----- zssp/src/proto.rs | 25 +- zssp/src/zssp.rs | 500 ++++++++++++++++++++--------------- 7 files changed, 514 insertions(+), 293 deletions(-) create mode 100644 zssp/src/fragged.rs diff --git a/utils/src/gatherarray.rs b/utils/src/gatherarray.rs index 091a77354..b0b0a3539 100644 --- a/utils/src/gatherarray.rs +++ b/utils/src/gatherarray.rs @@ -38,7 +38,7 @@ impl GatherArray { /// Add an item to the array if we don't have this index anymore, returning complete array if all parts are here. #[inline(always)] - pub fn add(&mut self, index: u8, value: T) -> Option> { + pub fn add_return_when_satisfied(&mut self, index: u8, value: T) -> Option> { if index < self.goal { let mut have = self.have_bits; let got = 1u64.wrapping_shl(index as u32); @@ -91,9 +91,9 @@ mod tests { for goal in 2u8..64u8 { let mut m = GatherArray::::new(goal); for x in 0..(goal - 1) { - assert!(m.add(x, x).is_none()); + assert!(m.add_return_when_satisfied(x, x).is_none()); } - let r = m.add(goal - 1, goal - 1).unwrap(); + let r = m.add_return_when_satisfied(goal - 1, goal - 1).unwrap(); for x in 0..goal { assert_eq!(r.as_ref()[x as usize], x); } diff --git a/zssp/src/applicationlayer.rs b/zssp/src/applicationlayer.rs index b9a1f1723..1c89ff72f 100644 --- a/zssp/src/applicationlayer.rs +++ b/zssp/src/applicationlayer.rs @@ -6,6 +6,8 @@ * https://www.zerotier.com/ */ +use std::hash::Hash; + use zerotier_crypto::p384::P384KeyPair; /// Trait to implement to integrate the session into an application. @@ -65,6 +67,12 @@ pub trait ApplicationLayer: Sized { /// for a short period of time when assembling fragmented packets on the receive path. type IncomingPacketBuffer: AsRef<[u8]> + AsMut<[u8]>; + /// Opaque type for whatever constitutes a physical path to the application. + /// + /// A physical path could be an IP address or IP plus device in the case of UDP, a socket in the + /// case of TCP, etc. + type PhysicalPath: PartialEq + Eq + Hash + Clone; + /// Get a reference to this host's static public key blob. /// /// This must contain a NIST P-384 public key but can contain other information. In ZeroTier this diff --git a/zssp/src/fragged.rs b/zssp/src/fragged.rs new file mode 100644 index 000000000..e1b0629e6 --- /dev/null +++ b/zssp/src/fragged.rs @@ -0,0 +1,103 @@ +use std::mem::{needs_drop, size_of, zeroed, MaybeUninit}; +use std::ptr::slice_from_raw_parts; + +/// Fast packet defragmenter +pub struct Fragged { + have: u64, + counter: u64, + frags: [MaybeUninit; MAX_FRAGMENTS], +} + +pub struct Assembled([MaybeUninit; MAX_FRAGMENTS], usize); + +impl AsRef<[Fragment]> for Assembled { + #[inline(always)] + fn as_ref(&self) -> &[Fragment] { + unsafe { &*slice_from_raw_parts(self.0.as_ptr().cast::(), self.1) } + } +} + +impl Drop for Assembled { + #[inline(always)] + fn drop(&mut self) { + for i in 0..self.1 { + unsafe { + self.0.get_unchecked_mut(i).assume_init_drop(); + } + } + } +} + +impl Fragged { + pub fn new() -> Self { + debug_assert!(MAX_FRAGMENTS <= 64); + debug_assert_eq!(size_of::>(), size_of::()); + debug_assert_eq!( + size_of::<[MaybeUninit; MAX_FRAGMENTS]>(), + size_of::<[Fragment; MAX_FRAGMENTS]>() + ); + unsafe { zeroed() } + } + + pub fn assemble( + &mut self, + counter: u64, + fragment: Fragment, + fragment_no: u8, + fragment_count: u8, + ) -> Option> { + if fragment_no < fragment_count && (fragment_count as usize) <= MAX_FRAGMENTS { + debug_assert!((fragment_count as usize) <= MAX_FRAGMENTS); + debug_assert!((fragment_no as usize) < MAX_FRAGMENTS); + + let mut have = self.have; + if counter != self.counter { + self.counter = counter; + if needs_drop::() { + let mut i = 0; + while have != 0 { + if (have & 1) != 0 { + debug_assert!(i < MAX_FRAGMENTS); + unsafe { self.frags.get_unchecked_mut(i).assume_init_drop() }; + } + have = have.wrapping_shr(1); + i += 1; + } + } else { + have = 0; + } + } + + unsafe { + self.frags.get_unchecked_mut(fragment_no as usize).write(fragment); + } + + let want = 0xffffffffffffffffu64.wrapping_shr((64 - fragment_count) as u32); + have |= 1u64.wrapping_shl(fragment_no as u32); + if (have & want) == want { + self.have = 0; + return Some(Assembled(unsafe { std::mem::transmute_copy(&self.frags) }, fragment_count as usize)); + } else { + self.have = have; + } + } + return None; + } +} + +impl Drop for Fragged { + fn drop(&mut self) { + if needs_drop::() { + let mut have = self.have; + let mut i = 0; + while have != 0 { + if (have & 1) != 0 { + debug_assert!(i < MAX_FRAGMENTS); + unsafe { self.frags.get_unchecked_mut(i).assume_init_drop() }; + } + have = have.wrapping_shr(1); + i += 1; + } + } + } +} diff --git a/zssp/src/lib.rs b/zssp/src/lib.rs index 718e5adb4..d4b41ca62 100644 --- a/zssp/src/lib.rs +++ b/zssp/src/lib.rs @@ -8,6 +8,7 @@ mod applicationlayer; mod error; +mod fragged; mod proto; mod sessionid; mod zssp; diff --git a/zssp/src/main.rs b/zssp/src/main.rs index 0567541db..02b86b6ae 100644 --- a/zssp/src/main.rs +++ b/zssp/src/main.rs @@ -1,9 +1,12 @@ +use std::iter::ExactSizeIterator; +use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc; use std::thread; use std::time::Duration; use zerotier_crypto::p384::{P384KeyPair, P384PublicKey}; +use zerotier_crypto::random; use zerotier_crypto::secret::Secret; use zerotier_utils::hex; use zerotier_utils::ms_monotonic; @@ -23,8 +26,8 @@ impl zssp::ApplicationLayer for TestApplication { const RETRY_INTERVAL: i64 = 500; type Data = (); - type IncomingPacketBuffer = Vec; + type PhysicalPath = usize; fn get_local_s_public_blob(&self) -> &[u8] { self.identity_key.public_key_bytes() @@ -37,6 +40,7 @@ impl zssp::ApplicationLayer for TestApplication { fn alice_main( run: &AtomicBool, + packet_success_rate: u32, alice_app: &TestApplication, bob_app: &TestApplication, alice_out: mpsc::SyncSender>, @@ -46,7 +50,7 @@ fn alice_main( let mut data_buf = [0u8; 65536]; let mut next_service = ms_monotonic() + 500; let mut last_ratchet_count = 0; - let test_data = [1u8; 10000]; + let test_data = [1u8; TEST_MTU * 10]; let mut up = false; let alice_session = context @@ -71,31 +75,35 @@ fn alice_main( loop { let pkt = alice_in.try_recv(); if let Ok(pkt) = pkt { - //println!("bob >> alice {}", pkt.len()); - match context.receive( - alice_app, - || true, - |s_public, _| Some((P384PublicKey::from_bytes(s_public).unwrap(), Secret::default(), ())), - |_, b| { - let _ = alice_out.send(b.to_vec()); - }, - &mut data_buf, - pkt, - TEST_MTU, - current_time, - ) { - Ok(zssp::ReceiveResult::Ok) => { - //println!("[alice] ok"); - } - Ok(zssp::ReceiveResult::OkData(_, _)) => { - //println!("[alice] received {}", data.len()); - } - Ok(zssp::ReceiveResult::OkNewSession(s)) => { - println!("[alice] new session {}", s.id.to_string()); - } - Ok(zssp::ReceiveResult::Rejected) => {} - Err(e) => { - println!("[alice] ERROR {}", e.to_string()); + if (random::xorshift64_random() as u32) <= packet_success_rate { + match context.receive( + alice_app, + || true, + |s_public, _| Some((P384PublicKey::from_bytes(s_public).unwrap(), Secret::default(), ())), + |_, b| { + let _ = alice_out.send(b.to_vec()); + }, + &0, + &mut data_buf, + pkt, + TEST_MTU, + current_time, + ) { + Ok(zssp::ReceiveResult::Ok) => { + //println!("[alice] ok"); + } + Ok(zssp::ReceiveResult::OkData(_, _)) => { + //println!("[alice] received {}", data.len()); + } + Ok(zssp::ReceiveResult::OkNewSession(s)) => { + println!("[alice] new session {}", s.id.to_string()); + } + Ok(zssp::ReceiveResult::Rejected) => {} + Err(e) => { + println!("[alice] ERROR {}", e.to_string()); + //run.store(false, Ordering::SeqCst); + //break; + } } } } else { @@ -116,12 +124,14 @@ fn alice_main( let _ = alice_out.send(b.to_vec()); }, &mut data_buf[..TEST_MTU], - &test_data[..1400 + ((zerotier_crypto::random::xorshift64_random() as usize) % (test_data.len() - 1400))], + &test_data[..1400 + ((random::xorshift64_random() as usize) % (test_data.len() - 1400))], ) .is_ok()); } else { if alice_session.established() { up = true; + } else { + thread::sleep(Duration::from_millis(10)); } } @@ -140,6 +150,7 @@ fn alice_main( fn bob_main( run: &AtomicBool, + packet_success_rate: u32, _alice_app: &TestApplication, bob_app: &TestApplication, bob_out: mpsc::SyncSender>, @@ -160,42 +171,46 @@ fn bob_main( let current_time = ms_monotonic(); if let Ok(pkt) = pkt { - //println!("alice >> bob {}", pkt.len()); - match context.receive( - bob_app, - || true, - |s_public, _| Some((P384PublicKey::from_bytes(s_public).unwrap(), Secret::default(), ())), - |_, b| { - let _ = bob_out.send(b.to_vec()); - }, - &mut data_buf, - pkt, - TEST_MTU, - current_time, - ) { - Ok(zssp::ReceiveResult::Ok) => { - //println!("[bob] ok"); - } - Ok(zssp::ReceiveResult::OkData(s, data)) => { - //println!("[bob] received {}", data.len()); - assert!(s - .send( - |b| { - let _ = bob_out.send(b.to_vec()); - }, - &mut data_buf_2, - data.as_mut(), - ) - .is_ok()); - transferred += data.len() as u64 * 2; // *2 because we are also sending this many bytes back - } - Ok(zssp::ReceiveResult::OkNewSession(s)) => { - println!("[bob] new session {}", s.id.to_string()); - let _ = bob_session.replace(s); - } - Ok(zssp::ReceiveResult::Rejected) => {} - Err(e) => { - println!("[bob] ERROR {}", e.to_string()); + if (random::xorshift64_random() as u32) <= packet_success_rate { + match context.receive( + bob_app, + || true, + |s_public, _| Some((P384PublicKey::from_bytes(s_public).unwrap(), Secret::default(), ())), + |_, b| { + let _ = bob_out.send(b.to_vec()); + }, + &0, + &mut data_buf, + pkt, + TEST_MTU, + current_time, + ) { + Ok(zssp::ReceiveResult::Ok) => { + //println!("[bob] ok"); + } + Ok(zssp::ReceiveResult::OkData(s, data)) => { + //println!("[bob] received {}", data.len()); + assert!(s + .send( + |b| { + let _ = bob_out.send(b.to_vec()); + }, + &mut data_buf_2, + data.as_mut(), + ) + .is_ok()); + transferred += data.len() as u64 * 2; // *2 because we are also sending this many bytes back + } + Ok(zssp::ReceiveResult::OkNewSession(s)) => { + println!("[bob] new session {}", s.id.to_string()); + let _ = bob_session.replace(s); + } + Ok(zssp::ReceiveResult::Rejected) => {} + Err(e) => { + println!("[bob] ERROR {}", e.to_string()); + //run.store(false, Ordering::SeqCst); + //break; + } } } } @@ -211,10 +226,12 @@ fn bob_main( let speed_metric_elapsed = current_time - last_speed_metric; if speed_metric_elapsed >= 1000 { last_speed_metric = current_time; - println!( - "[bob] throughput: {} MiB/sec (combined input and output)", - ((transferred as f64) / 1048576.0) / ((speed_metric_elapsed as f64) / 1000.0) - ); + if transferred > 0 { + println!( + "[bob] throughput: {} MiB/sec (combined input and output)", + ((transferred as f64) / 1048576.0) / ((speed_metric_elapsed as f64) / 1000.0) + ); + } transferred = 0; } @@ -240,9 +257,16 @@ fn main() { let (alice_out, bob_in) = mpsc::sync_channel::>(1024); let (bob_out, alice_in) = mpsc::sync_channel::>(1024); + let args = std::env::args(); + let packet_success_rate = if args.len() <= 1 { + u32::MAX + } else { + ((u32::MAX as f64) * f64::from_str(args.last().unwrap().as_str()).unwrap()) as u32 + }; + thread::scope(|ts| { - let alice_thread = ts.spawn(|| alice_main(&run, &alice_app, &bob_app, alice_out, alice_in)); - let bob_thread = ts.spawn(|| bob_main(&run, &alice_app, &bob_app, bob_out, bob_in)); + let alice_thread = ts.spawn(|| alice_main(&run, packet_success_rate, &alice_app, &bob_app, alice_out, alice_in)); + let bob_thread = ts.spawn(|| bob_main(&run, packet_success_rate, &alice_app, &bob_app, bob_out, bob_in)); thread::sleep(Duration::from_secs(60 * 10)); diff --git a/zssp/src/proto.rs b/zssp/src/proto.rs index 32510bd44..ec8ce9991 100644 --- a/zssp/src/proto.rs +++ b/zssp/src/proto.rs @@ -24,24 +24,29 @@ pub const MIN_TRANSPORT_MTU: usize = 128; /// Maximum combined size of static public blob and metadata. pub const MAX_INIT_PAYLOAD_SIZE: usize = MAX_NOISE_HANDSHAKE_SIZE - ALICE_NOISE_XK_ACK_MIN_SIZE; +/// Version 0: Noise_XK with NIST P-384 plus Kyber1024 hybrid exchange on session init. pub(crate) const SESSION_PROTOCOL_VERSION: u8 = 0x00; -pub(crate) const COUNTER_WINDOW_MAX_OOO: usize = 16; +/// Maximum window over which packets may be reordered. +pub(crate) const COUNTER_WINDOW_MAX_OOO: usize = 32; + +/// Maximum number of counter steps that the counter is allowed to skip ahead. pub(crate) const COUNTER_WINDOW_MAX_SKIP_AHEAD: u64 = 16777216; -pub(crate) const PACKET_TYPE_DATA: u8 = 0; -pub(crate) const PACKET_TYPE_ALICE_NOISE_XK_INIT: u8 = 1; -pub(crate) const PACKET_TYPE_BOB_NOISE_XK_ACK: u8 = 2; -pub(crate) const PACKET_TYPE_ALICE_NOISE_XK_ACK: u8 = 3; -pub(crate) const PACKET_TYPE_REKEY_INIT: u8 = 4; -pub(crate) const PACKET_TYPE_REKEY_ACK: u8 = 5; +pub(crate) const PACKET_TYPE_NOP: u8 = 0; +pub(crate) const PACKET_TYPE_DATA: u8 = 1; +pub(crate) const PACKET_TYPE_ALICE_NOISE_XK_INIT: u8 = 2; +pub(crate) const PACKET_TYPE_BOB_NOISE_XK_ACK: u8 = 3; +pub(crate) const PACKET_TYPE_ALICE_NOISE_XK_ACK: u8 = 4; +pub(crate) const PACKET_TYPE_REKEY_INIT: u8 = 5; +pub(crate) const PACKET_TYPE_REKEY_ACK: u8 = 6; pub(crate) const HEADER_SIZE: usize = 16; pub(crate) const HEADER_PROTECT_ENCRYPT_START: usize = 6; pub(crate) const HEADER_PROTECT_ENCRYPT_END: usize = 22; -pub(crate) const KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION: u8 = b'X'; // intermediate keys used in key exchanges -pub(crate) const KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION: u8 = b'x'; // intermediate keys used in key exchanges +pub(crate) const KBKDF_KEY_USAGE_LABEL_INIT_ENCRYPTION: u8 = b'x'; // AES-CTR encryption during initial setup +pub(crate) const KBKDF_KEY_USAGE_LABEL_INIT_AUTHENTICATION: u8 = b'X'; // HMAC-SHA384 during initial setup pub(crate) const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A'; // AES-GCM in A->B direction pub(crate) const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; // AES-GCM in B->A direction pub(crate) const KBKDF_KEY_USAGE_LABEL_RATCHET: u8 = b'R'; // Key used in derivatin of next session key @@ -50,6 +55,7 @@ pub(crate) const MAX_FRAGMENTS: usize = 48; // hard protocol max: 63 pub(crate) const MAX_NOISE_HANDSHAKE_FRAGMENTS: usize = 16; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc. pub(crate) const MAX_NOISE_HANDSHAKE_SIZE: usize = MAX_NOISE_HANDSHAKE_FRAGMENTS * MIN_TRANSPORT_MTU; +/// Size of keys used during derivation, mixing, etc. process. pub(crate) const BASE_KEY_SIZE: usize = 64; pub(crate) const AES_256_KEY_SIZE: usize = 32; @@ -162,7 +168,6 @@ impl RekeyAck { pub(crate) trait ProtocolFlatBuffer {} impl ProtocolFlatBuffer for AliceNoiseXKInit {} impl ProtocolFlatBuffer for BobNoiseXKAck {} -//impl ProtocolFlatBuffer for NoiseXKAliceStaticAck {} impl ProtocolFlatBuffer for RekeyInit {} impl ProtocolFlatBuffer for RekeyAck {} diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 423d48828..b510cdba0 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -15,19 +15,18 @@ use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; use std::sync::{Arc, Mutex, RwLock, Weak}; use zerotier_crypto::aes::{Aes, AesCtr, AesGcm}; -use zerotier_crypto::hash::{hmac_sha512, HMACSHA384, HMAC_SHA384_SIZE, SHA384, SHA384_HASH_SIZE}; +use zerotier_crypto::hash::{hmac_sha512, HMACSHA384, HMAC_SHA384_SIZE, SHA384}; use zerotier_crypto::p384::{P384KeyPair, P384PublicKey, P384_ECDH_SHARED_SECRET_SIZE}; use zerotier_crypto::secret::Secret; use zerotier_crypto::{random, secure_eq}; use zerotier_utils::arrayvec::ArrayVec; -use zerotier_utils::gatherarray::GatherArray; -use zerotier_utils::ringbuffermap::RingBufferMap; -use pqc_kyber::{KYBER_CIPHERTEXTBYTES, KYBER_SECRETKEYBYTES, KYBER_SSBYTES}; +use pqc_kyber::{KYBER_SECRETKEYBYTES, KYBER_SSBYTES}; use crate::applicationlayer::ApplicationLayer; use crate::error::Error; +use crate::fragged::Fragged; use crate::proto::*; use crate::sessionid::SessionId; @@ -37,7 +36,12 @@ use crate::sessionid::SessionId; /// defragment incoming packets that are not yet associated with a session. pub struct Context { max_incomplete_session_queue_size: usize, - defrag: Mutex, 256, 256>>, + defrag: Mutex< + HashMap< + (Application::PhysicalPath, u64), + Arc, i64)>>, + >, + >, sessions: RwLock>, } @@ -80,7 +84,7 @@ pub struct Session { receive_window: [AtomicU64; COUNTER_WINDOW_MAX_OOO], header_protection_cipher: Aes, state: RwLock, - defrag: Mutex, 16, 16>>, + defrag: [Mutex>; COUNTER_WINDOW_MAX_OOO], } /// Most of the mutable parts of a session state. @@ -91,20 +95,16 @@ struct State { current_offer: Offer, } -/// State related to an incoming session not yet fully established. struct IncomingIncompleteSession { timestamp: i64, - request_hash: [u8; SHA384_HASH_SIZE], alice_session_id: SessionId, bob_session_id: SessionId, noise_es_ee: Secret, - bob_hk_ciphertext: [u8; KYBER_CIPHERTEXTBYTES], hk: Secret, header_protection_key: Secret, bob_noise_e_secret: P384KeyPair, } -/// State related to an outgoing session attempt. struct OutgoingSessionInit { last_retry_time: AtomicI64, alice_noise_e_secret: P384KeyPair, @@ -114,14 +114,19 @@ struct OutgoingSessionInit { init_packet: [u8; AliceNoiseXKInit::SIZE], } -/// Latest outgoing offer, either an outgoing attempt or a rekey attempt. +struct OutgoingSessionAck { + last_retry_time: AtomicI64, + ack: [u8; MAX_NOISE_HANDSHAKE_SIZE], + ack_size: usize, +} + enum Offer { None, NoiseXKInit(Box), - RekeyInit(P384KeyPair, [u8; RekeyInit::SIZE], AtomicI64), + NoiseXKAck(Box), + RekeyInit(P384KeyPair, i64), } -/// An ephemeral session key with expiration info. struct SessionKey { ratchet_key: Secret, // Key used in derivation of the next session key receive_key: Secret, // Receive side AES-GCM key @@ -134,6 +139,7 @@ struct SessionKey { expire_at_counter: u64, // Hard error when this counter value is reached or exceeded ratchet_count: u64, // Number of rekey events bob: bool, // Was this side "Bob" in this exchange? + confirmed: bool, // Is this key confirmed by the other side? } impl Context { @@ -141,7 +147,7 @@ impl Context { pub fn new(max_incomplete_session_queue_size: usize) -> Self { Self { max_incomplete_session_queue_size, - defrag: Mutex::new(RingBufferMap::new(random::next_u32_secure())), + defrag: Mutex::new(HashMap::new()), sessions: RwLock::new(SessionsById { active: HashMap::with_capacity(64), incoming: HashMap::with_capacity(64), @@ -173,19 +179,14 @@ impl Context { for (id, s) in sessions.active.iter() { if let Some(session) = s.upgrade() { let state = session.state.read().unwrap(); - match &state.current_offer { - Offer::None => { - if let Some(key) = state.keys[state.current_key].as_ref() { - if key.bob - && (current_time >= key.rekey_at_time - || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) - { - drop(state); - session.initiate_rekey(|b| send(&session, b), current_time); - } - } - } + if match &state.current_offer { + Offer::None => true, Offer::NoiseXKInit(offer) => { + // If there's an outstanding attempt to open a session, retransmit this periodically + // in case the initial packet doesn't make it. Note that we currently don't have + // retransmission for the intermediate steps, so a new session may still fail if the + // packet loss rate is huge. The application layer has its own logic to keep trying + // under those conditions. if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { offer.last_retry_time.store(current_time, Ordering::Relaxed); let _ = send_with_fragmentation( @@ -199,11 +200,37 @@ impl Context { None, ); } + false } - Offer::RekeyInit(_, rekey_packet, last_retry_time) => { - if last_retry_time.load(Ordering::Relaxed) < retry_cutoff { - last_retry_time.store(current_time, Ordering::Relaxed); - send(&session, &mut (rekey_packet.clone())); + Offer::NoiseXKAck(ack) => { + // We also keep retransmitting the final ACK until we get a valid DATA or NOP packet + // from Bob, otherwise we could get a half open session. + if ack.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { + ack.last_retry_time.store(current_time, Ordering::Relaxed); + let _ = send_with_fragmentation( + |b| send(&session, b), + &mut (ack.ack.clone())[..ack.ack_size], + mtu, + PACKET_TYPE_ALICE_NOISE_XK_ACK, + state.remote_session_id, + 0, + 2, + Some(&session.header_protection_cipher), + ); + } + false + } + Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff, + } { + // Check whether we need to rekey if there is no pending offer or if the last rekey + // offer was before retry_cutoff (checked in the 'match' above). + if let Some(key) = state.keys[state.current_key].as_ref() { + if key.bob + && (current_time >= key.rekey_at_time + || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) + { + drop(state); + session.initiate_rekey(|b| send(&session, b), current_time); } } } @@ -297,7 +324,7 @@ impl Context { init_packet: [0u8; AliceNoiseXKInit::SIZE], })), }), - defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), + defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), }); sessions.active.insert(local_session_id, Arc::downgrade(&session)); @@ -310,7 +337,7 @@ impl Context { let init_packet = if let Offer::NoiseXKInit(offer) = &mut state.current_offer { &mut offer.init_packet } else { - panic!(); + panic!(); // should be impossible }; let init: &mut AliceNoiseXKInit = byte_array_as_proto_buffer_mut(init_packet).unwrap(); @@ -321,12 +348,12 @@ impl Context { init.header_protection_key = header_protection_key.0; aes_ctr_crypt_one_time_use_key( - kbkdf::(noise_es.as_bytes()).as_bytes(), + kbkdf::(noise_es.as_bytes()).as_bytes(), &mut init_packet[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START], ); let hmac = hmac_sha384_2( - kbkdf::(noise_es.as_bytes()).as_bytes(), + kbkdf::(noise_es.as_bytes()).as_bytes(), &create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_INIT, 1), &init_packet[HEADER_SIZE..AliceNoiseXKInit::AUTH_START], ); @@ -385,6 +412,7 @@ impl Context { mut check_allow_incoming_session: CheckAllowIncomingSession, mut check_accept_session: CheckAcceptSession, mut send: SendFunction, + source: &Application::PhysicalPath, data_buf: &'b mut [u8], mut incoming_packet_buf: Application::IncomingPacketBuffer, mtu: usize, @@ -414,31 +442,27 @@ impl Context { if session.check_receive_window(incoming_counter) { if fragment_count > 1 { - if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { - let mut defrag = session.defrag.lock().unwrap(); - let fragment_gather_array = defrag.get_or_create_mut(&incoming_counter, || GatherArray::new(fragment_count)); - if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { - drop(defrag); // release lock - return self.process_complete_incoming_packet( - app, - &mut send, - &mut check_allow_incoming_session, - &mut check_accept_session, - data_buf, - incoming_counter, - assembled_packet.as_ref(), - packet_type, - Some(session), - None, - key_index, - mtu, - current_time, - ); - } else { - return Ok(ReceiveResult::Ok); - } + let mut fragged = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO].lock().unwrap(); + if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) + { + drop(fragged); + return self.process_complete_incoming_packet( + app, + &mut send, + &mut check_allow_incoming_session, + &mut check_accept_session, + data_buf, + incoming_counter, + assembled_packet.as_ref(), + packet_type, + Some(session), + None, + key_index, + mtu, + current_time, + ); } else { - return Err(Error::InvalidPacket); + return Ok(ReceiveResult::Ok); } } else { return self.process_complete_incoming_packet( @@ -476,10 +500,19 @@ impl Context { let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet); if fragment_count > 1 { - let mut defrag = self.defrag.lock().unwrap(); - let fragment_gather_array = defrag.get_or_create_mut(&incoming_counter, || GatherArray::new(fragment_count)); - if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { - drop(defrag); // release lock + let fragged_m = { + let mut defrag = self.defrag.lock().unwrap(); + defrag + .entry((source.clone(), incoming_counter)) + .or_insert_with(|| Arc::new(Mutex::new((Fragged::new(), current_time)))) + .clone() + }; + let mut fragged = fragged_m.lock().unwrap(); + if let Some(assembled_packet) = fragged + .0 + .assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) + { + self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter)); return self.process_complete_incoming_packet( app, &mut send, @@ -543,7 +576,7 @@ impl Context { // Generate incoming message nonce for decryption and authentication. let incoming_message_nonce = create_message_nonce(packet_type, incoming_counter); - if packet_type == PACKET_TYPE_DATA { + if packet_type <= PACKET_TYPE_DATA { if let Some(session) = session { let state = session.state.read().unwrap(); if let Some(key) = state.keys[key_index].as_ref() { @@ -590,22 +623,41 @@ impl Context { // Update the current key to point to this key if it's newer, since having received // a packet encrypted with it proves that the other side has successfully derived it // as well. - if state.current_key == key_index { + if state.current_key == key_index && key.confirmed { drop(state); } else { - let key_created_at_counter = key.created_at_counter; + let current_key_created_at_counter = key.created_at_counter; + drop(state); let mut state = session.state.write().unwrap(); - if let Some(other_session_key) = state.keys[state.current_key].as_ref() { - if other_session_key.created_at_counter < key_created_at_counter { + + if state.current_key != key_index { + if let Some(other_session_key) = state.keys[state.current_key].as_ref() { + if other_session_key.created_at_counter < current_key_created_at_counter { + state.current_key = key_index; + } + } else { state.current_key = key_index; } } else { - state.current_key = key_index; + state.keys[key_index].as_mut().unwrap().confirmed = true; + } + + // If we got a valid data packet from Bob, this means we can cancel any offers + // that are still oustanding for initialization. + match &state.current_offer { + Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => { + state.current_offer = Offer::None; + } + _ => {} } } - return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len])); + if packet_type == PACKET_TYPE_DATA { + return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len])); + } else { + println!("nop"); + } } else { return Err(Error::OutOfSequence); } @@ -650,120 +702,104 @@ impl Context { if incoming_counter != 1 || session.is_some() { return Err(Error::OutOfSequence); } - - // Hash the init packet so we can check to see if it's just being retransmitted. Alice may - // attempt to retransmit this packet until she receives a response. - let request_hash = SHA384::hash(&pkt_assembled); - - let (alice_session_id, mut bob_session_id, noise_es_ee, bob_hk_ciphertext, header_protection_key, bob_noise_e); - if let Some(incoming) = incoming { - // If we've already seen this exact packet before, just recall the same state so we send the - // same response. - if secure_eq(&request_hash, &incoming.request_hash) { - alice_session_id = incoming.alice_session_id; - bob_session_id = incoming.bob_session_id; - noise_es_ee = incoming.noise_es_ee.clone(); - bob_hk_ciphertext = incoming.bob_hk_ciphertext; - header_protection_key = incoming.header_protection_key.clone(); - bob_noise_e = *incoming.bob_noise_e_secret.public_key_bytes(); - } else { - return Err(Error::FailedAuthentication); - } - } else { - // Otherwise parse the packet, authenticate, generate keys, etc. and record state in an - // incoming state object until this phase of the negotiation is done. - let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?; - let alice_noise_e = P384PublicKey::from_bytes(&pkt.alice_noise_e).ok_or(Error::FailedAuthentication)?; - let noise_es = app.get_local_s_keypair().agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?; - - // Authenticate packet and also prove that Alice knows our static public key. - if !secure_eq( - &pkt.hmac_es, - &hmac_sha384_2( - kbkdf::(noise_es.as_bytes()).as_bytes(), - &incoming_message_nonce, - &pkt_assembled[HEADER_SIZE..AliceNoiseXKInit::AUTH_START], - ), - ) { - return Err(Error::FailedAuthentication); - } - - // Let application filter incoming connection attempt by whatever criteria it wants. - if !check_allow_incoming_session() { - return Ok(ReceiveResult::Rejected); - } - - // Decrypt encrypted part of payload. - aes_ctr_crypt_one_time_use_key( - kbkdf::(noise_es.as_bytes()).as_bytes(), - &mut pkt_assembled[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START], - ); - - let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?; - alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?; - header_protection_key = Secret(pkt.header_protection_key); - - // Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create - // a Kyber ciphertext to send back to Alice. - let bob_noise_e_secret = P384KeyPair::generate(); - bob_noise_e = bob_noise_e_secret.public_key_bytes().clone(); - noise_es_ee = Secret(hmac_sha512( - noise_es.as_bytes(), - bob_noise_e_secret - .agree(&alice_noise_e) - .ok_or(Error::FailedAuthentication)? - .as_bytes(), - )); - let (hk_ct, hk) = pqc_kyber::encapsulate(&pkt.alice_hk_public, &mut random::SecureRandom::default()) - .map_err(|_| Error::FailedAuthentication) - .map(|(ct, hk)| (ct, Secret(hk)))?; - bob_hk_ciphertext = hk_ct; - - let mut sessions = self.sessions.write().unwrap(); - - loop { - bob_session_id = SessionId::random(); - if !sessions.active.contains_key(&bob_session_id) && !sessions.incoming.contains_key(&bob_session_id) { - break; - } - } - - if sessions.incoming.len() >= self.max_incomplete_session_queue_size { - // If this queue is too big, we remove the latest entry and replace it. The latest - // is used because under flood conditions this is most likely to be another bogus - // entry. If we find one that is actually timed out, that one is replaced instead. - let mut newest = i64::MIN; - let mut replace_id = None; - let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; - for (id, s) in sessions.incoming.iter() { - if s.timestamp <= cutoff_time { - replace_id = Some(*id); - break; - } else if s.timestamp >= newest { - newest = s.timestamp; - replace_id = Some(*id); - } - } - let _ = sessions.incoming.remove(replace_id.as_ref().unwrap()); - } - - // Reserve session ID on this side and record incomplete session state. - sessions.incoming.insert( - bob_session_id, - Arc::new(IncomingIncompleteSession { - timestamp: current_time, - request_hash, - alice_session_id, - bob_session_id, - noise_es_ee: noise_es_ee.clone(), - bob_hk_ciphertext, - hk, - bob_noise_e_secret, - header_protection_key: Secret(pkt.header_protection_key), - }), - ); + if pkt_assembled.len() != AliceNoiseXKInit::SIZE { + return Err(Error::InvalidPacket); } + // Otherwise parse the packet, authenticate, generate keys, etc. and record state in an + // incoming state object until this phase of the negotiation is done. + let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?; + let alice_noise_e = P384PublicKey::from_bytes(&pkt.alice_noise_e).ok_or(Error::FailedAuthentication)?; + let noise_es = app.get_local_s_keypair().agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?; + + // Authenticate packet and also prove that Alice knows our static public key. + if !secure_eq( + &pkt.hmac_es, + &hmac_sha384_2( + kbkdf::(noise_es.as_bytes()).as_bytes(), + &incoming_message_nonce, + &pkt_assembled[HEADER_SIZE..AliceNoiseXKInit::AUTH_START], + ), + ) { + return Err(Error::FailedAuthentication); + } + + // Let application filter incoming connection attempt by whatever criteria it wants. + if !check_allow_incoming_session() { + return Ok(ReceiveResult::Rejected); + } + + // Decrypt encrypted part of payload. + aes_ctr_crypt_one_time_use_key( + kbkdf::(noise_es.as_bytes()).as_bytes(), + &mut pkt_assembled[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START], + ); + + let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?; + let alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?; + let header_protection_key = Secret(pkt.header_protection_key); + + // Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create + // a Kyber ciphertext to send back to Alice. + let bob_noise_e_secret = P384KeyPair::generate(); + let bob_noise_e = bob_noise_e_secret.public_key_bytes().clone(); + let noise_es_ee = Secret(hmac_sha512( + noise_es.as_bytes(), + bob_noise_e_secret + .agree(&alice_noise_e) + .ok_or(Error::FailedAuthentication)? + .as_bytes(), + )); + let (bob_hk_ciphertext, hk) = pqc_kyber::encapsulate(&pkt.alice_hk_public, &mut random::SecureRandom::default()) + .map_err(|_| Error::FailedAuthentication) + .map(|(ct, hk)| (ct, Secret(hk)))?; + + let mut sessions = self.sessions.write().unwrap(); + + let mut bob_session_id; + loop { + bob_session_id = SessionId::random(); + if !sessions.active.contains_key(&bob_session_id) && !sessions.incoming.contains_key(&bob_session_id) { + break; + } + } + + if sessions.incoming.len() >= self.max_incomplete_session_queue_size { + // If this queue is too big, we remove the latest entry and replace it. The latest + // is used because under flood conditions this is most likely to be another bogus + // entry. If we find one that is actually timed out, that one is replaced instead. + let mut newest = i64::MIN; + let mut replace_id = None; + let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; + for (id, s) in sessions.incoming.iter() { + if s.timestamp <= cutoff_time { + replace_id = Some(*id); + break; + } else if s.timestamp >= newest { + newest = s.timestamp; + replace_id = Some(*id); + } + } + let _ = sessions.incoming.remove(replace_id.as_ref().unwrap()); + } + + // Reserve session ID on this side and record incomplete session state. + sessions.incoming.insert( + bob_session_id, + Arc::new(IncomingIncompleteSession { + timestamp: current_time, + alice_session_id, + bob_session_id, + noise_es_ee: noise_es_ee.clone(), + hk, + bob_noise_e_secret, + header_protection_key: Secret(pkt.header_protection_key), + }), + ); + debug_assert!(!sessions.active.contains_key(&bob_session_id)); + + drop(sessions); + // Create Bob's ephemeral counter-offer reply. let mut ack_packet = [0u8; BobNoiseXKAck::SIZE]; let ack: &mut BobNoiseXKAck = byte_array_as_proto_buffer_mut(&mut ack_packet)?; @@ -774,13 +810,13 @@ impl Context { // Encrypt main section of reply. aes_ctr_crypt_one_time_use_key( - kbkdf::(noise_es_ee.as_bytes()).as_bytes(), + kbkdf::(noise_es_ee.as_bytes()).as_bytes(), &mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START], ); // Add HMAC-SHA384 to reply packet. let reply_hmac = hmac_sha384_2( - kbkdf::(noise_es_ee.as_bytes()).as_bytes(), + kbkdf::(noise_es_ee.as_bytes()).as_bytes(), &create_message_nonce(PACKET_TYPE_BOB_NOISE_XK_ACK, 1), &ack_packet[HEADER_SIZE..BobNoiseXKAck::AUTH_START], ); @@ -812,9 +848,18 @@ impl Context { if incoming_counter != 1 || incoming.is_some() { return Err(Error::OutOfSequence); } + if pkt_assembled.len() != BobNoiseXKAck::SIZE { + return Err(Error::InvalidPacket); + } if let Some(session) = session { let state = session.state.read().unwrap(); + + // This doesn't make sense if the session is up. + if state.keys[state.current_key].is_some() { + return Err(Error::OutOfSequence); + } + if let Offer::NoiseXKInit(outgoing_offer) = &state.current_offer { let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?; @@ -830,7 +875,7 @@ impl Context { )); let noise_es_ee_kex_hmac_key = - kbkdf::(noise_es_ee.as_bytes()); + kbkdf::(noise_es_ee.as_bytes()); // Authenticate Bob's reply and the validity of bob_noise_e. if !secure_eq( @@ -846,7 +891,7 @@ impl Context { // Decrypt encrypted portion of message. aes_ctr_crypt_one_time_use_key( - kbkdf::(noise_es_ee.as_bytes()).as_bytes(), + kbkdf::(noise_es_ee.as_bytes()).as_bytes(), &mut pkt_assembled[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START], ); let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?; @@ -915,7 +960,7 @@ impl Context { // key exchange. Bob won't be able to do this until he decrypts and parses Alice's // identity, so the first HMAC is to let him authenticate that first. let hmac_es_ee_se_hk_psk = hmac_sha384_2( - kbkdf::(noise_es_ee_se_hk_psk.as_bytes()) + kbkdf::(noise_es_ee_se_hk_psk.as_bytes()) .as_bytes(), &reply_message_nonce, &reply_buffer[HEADER_SIZE..reply_len], @@ -933,9 +978,15 @@ impl Context { current_time, 2, false, + false, )); + debug_assert!(state.keys[1].is_none()); state.current_key = 0; - state.current_offer = Offer::None; + state.current_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck { + last_retry_time: AtomicI64::new(current_time), + ack: reply_buffer, + ack_size: reply_len, + })); } send_with_fragmentation( @@ -980,18 +1031,13 @@ impl Context { } if let Some(incoming) = incoming { - // Check timeout, negotiations aren't allowed to take longer than this. - if (current_time - incoming.timestamp) > Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS { - return Err(Error::UnknownLocalSessionId); - } - // Check the first HMAC to verify against the currently known noise_es_ee key, which verifies // that this reply is part of this session. let auth_start = pkt_assembled.len() - ALICE_NOISE_XK_ACK_AUTH_SIZE; if !secure_eq( &pkt_assembled[auth_start..pkt_assembled.len() - HMAC_SHA384_SIZE], &hmac_sha384_2( - kbkdf::(incoming.noise_es_ee.as_bytes()) + kbkdf::(incoming.noise_es_ee.as_bytes()) .as_bytes(), &incoming_message_nonce, &pkt_assembled[HEADER_SIZE..auth_start], @@ -1065,7 +1111,7 @@ impl Context { if !secure_eq( &pkt_assembly_buffer_copy[auth_start + HMAC_SHA384_SIZE..pkt_assembled.len()], &hmac_sha384_2( - kbkdf::(noise_es_ee_se_hk_psk.as_bytes()) + kbkdf::(noise_es_ee_se_hk_psk.as_bytes()) .as_bytes(), &incoming_message_nonce, &pkt_assembly_buffer_copy[HEADER_SIZE..auth_start + HMAC_SHA384_SIZE], @@ -1084,22 +1130,30 @@ impl Context { state: RwLock::new(State { remote_session_id: Some(incoming.alice_session_id), keys: [ - Some(SessionKey::new::(noise_es_ee_se_hk_psk, 1, current_time, 2, true)), + Some(SessionKey::new::( + noise_es_ee_se_hk_psk, + 1, + current_time, + 2, + true, + true, + )), None, ], current_key: 0, current_offer: Offer::None, }), - defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), + defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), }); - // Promote this from an incomplete session to an established session. { let mut sessions = self.sessions.write().unwrap(); sessions.incoming.remove(&incoming.bob_session_id); sessions.active.insert(incoming.bob_session_id, Arc::downgrade(&session)); } + let _ = session.send_nop(|b| send(Some(&session), b)); + return Ok(ReceiveResult::OkNewSession(session)); } else { return Err(Error::UnknownLocalSessionId); @@ -1177,6 +1231,7 @@ impl Context { current_time, counter, false, + false, )); return Ok(ReceiveResult::Ok); @@ -1202,7 +1257,7 @@ impl Context { if let Some(session) = session { let state = session.state.read().unwrap(); - if let Offer::RekeyInit(alice_e_secret, _, _) = &state.current_offer { + if let Offer::RekeyInit(alice_e_secret, _) = &state.current_offer { if let Some(key) = state.keys[key_index].as_ref() { // Only the current "Bob" initiates rekeys and expects this ACK. if key.bob { @@ -1234,6 +1289,7 @@ impl Context { current_time, session.send_counter.load(Ordering::Acquire), true, + true, )); state.current_key = next_key_index; // this is an ACK so it's confirmed state.current_offer = Offer::None; @@ -1325,6 +1381,32 @@ impl Session { return Err(Error::SessionNotEstablished); } + /// Send a NOP to the other side (e.g. for keep alive). + pub fn send_nop(&self, mut send: SendFunction) -> Result<(), Error> { + let state = self.state.read().unwrap(); + if let Some(remote_session_id) = state.remote_session_id { + if let Some(session_key) = state.keys[state.current_key].as_ref() { + let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; + let mut c = session_key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter)); + nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt()); + session_key.return_send_cipher(c); + set_packet_header( + &mut nop, + 1, + 0, + PACKET_TYPE_NOP, + u64::from(remote_session_id), + state.current_key, + counter, + ); + send(&mut nop); + } + } + return Err(Error::SessionNotEstablished); + } + /// Check whether this session is established. pub fn established(&self) -> bool { let state = self.state.read().unwrap(); @@ -1381,7 +1463,7 @@ impl Session { .encrypt_block_in_place(&mut rekey_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); send(&mut rekey_buf); - self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, rekey_buf, AtomicI64::new(current_time)); + self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, current_time); } } } @@ -1448,15 +1530,11 @@ fn set_packet_header( #[inline(always)] fn parse_packet_header(incoming_packet: &[u8]) -> (usize, u8, u8, u8, u64) { let raw_header_a = u16::from_le_bytes(incoming_packet[6..8].try_into().unwrap()); - let key_index = (raw_header_a & 1) as usize; - let packet_type = (raw_header_a.wrapping_shr(1) & 7) as u8; - let fragment_count = ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8; - let fragment_no = raw_header_a.wrapping_shr(10) as u8; ( - key_index, - packet_type, - fragment_count, - fragment_no, + (raw_header_a & 1) as usize, + (raw_header_a.wrapping_shr(1) & 7) as u8, + ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8, + raw_header_a.wrapping_shr(10) as u8, u64::from_le_bytes(incoming_packet[8..16].try_into().unwrap()), ) } @@ -1532,11 +1610,12 @@ impl SessionKey { ratchet_count: u64, current_time: i64, current_counter: u64, - role_is_bob: bool, + bob: bool, + confirmed: bool, ) -> Self { let a2b = kbkdf::(key.as_bytes()); let b2a = kbkdf::(key.as_bytes()); - let (receive_key, send_key) = if role_is_bob { + let (receive_key, send_key) = if bob { (a2b, b2a) } else { (b2a, a2b) @@ -1557,7 +1636,8 @@ impl SessionKey { rekey_at_counter: current_counter.checked_add(Application::REKEY_AFTER_USES).unwrap(), expire_at_counter: current_counter.checked_add(Application::EXPIRE_AFTER_USES).unwrap(), ratchet_count, - bob: role_is_bob, + bob, + confirmed, } }