diff --git a/zssp/src/fragged.rs b/zssp/src/fragged.rs index 47370158b..c734ba0dc 100644 --- a/zssp/src/fragged.rs +++ b/zssp/src/fragged.rs @@ -45,12 +45,6 @@ impl Fragged { unsafe { zeroed() } } - /// Returns the counter value associated with the packet currently being assembled. - /// If no packet is currently being assembled it returns 0. - #[inline(always)] - pub fn counter(&self) -> u64 { - self.counter - } /// Add a fragment and return an assembled packet container if all fragments have been received. /// /// When a fully assembled packet is returned the internal state is reset and this object can diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 544368e6c..389dad465 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -10,7 +10,6 @@ // FIPS compliant Noise_XK with Jedi powers (Kyber1024) and built-in attack-resistant large payload (fragmentation) support. use std::collections::hash_map::RandomState; -//use std::collections::hash_map::DefaultHasher; use std::collections::HashMap; use std::hash::{BuildHasher, Hash, Hasher}; use std::num::NonZeroU64; @@ -38,33 +37,26 @@ const GCM_CIPHER_POOL_SIZE: usize = 4; /// /// Each application using ZSSP must create an instance of this to own sessions and /// defragment incoming packets that are not yet associated with a session. -pub struct Context { +pub struct Context<'a, Application: ApplicationLayer> { default_physical_mtu: AtomicUsize, - defrag_salt: RandomState, - defrag_has_pending: AtomicBool, // Allowed to be falsely positive - defrag: [Mutex<(Fragged, i64)>; MAX_INCOMPLETE_SESSION_QUEUE_SIZE], - sessions: RwLock>, -} - -/// Lookup maps for sessions within a session context. -struct SessionsById { - // Active sessions, automatically closed if the application no longer holds their Arc<>. - active: HashMap>>, - - // Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout. 
- incoming: HashMap>>, + dos_salt: RandomState, + init_has_pending: AtomicBool, // Allowed to be falsely positive + incoming_has_pending: AtomicBool, // Allowed to be falsely positive + init_defrag: Mutex<[(i64, u64, Fragged); MAX_INCOMPLETE_SESSION_QUEUE_SIZE]>, + incoming_sessions: RwLock<[(i64, u64, Option>>); MAX_INCOMPLETE_SESSION_QUEUE_SIZE]>, + active_sessions: RwLock>>>, } /// Result generated by the context packet receive function, with possible payloads. -pub enum ReceiveResult<'b, Application: ApplicationLayer> { +pub enum ReceiveResult<'a, 'b, Application: ApplicationLayer> { /// Packet was valid, but no action needs to be taken and no payload was delivered. - Ok(Option>>), + Ok(Option>>), /// Packet was valid and a data payload was decoded and authenticated. - OkData(Arc>, &'b mut [u8]), + OkData(Arc>, &'b mut [u8]), /// Packet was valid and a new session was created, with optional attached meta-data. - OkNewSession(Arc>, Option<&'b mut [u8]>), + OkNewSession(Arc>, Option<&'b mut [u8]>), /// Packet appears valid but was rejected by the application layer, e.g. a rejected new session attempt. Rejected, @@ -73,13 +65,21 @@ pub enum ReceiveResult<'b, Application: ApplicationLayer> { /// ZeroTier Secure Session Protocol (ZSSP) Session /// /// A FIPS/NIST compliant variant of Noise_XK with hybrid Kyber1024 PQ data forward secrecy. -pub struct Session { - /// This side's locally unique session ID +pub struct Session<'a, Application: ApplicationLayer> { + /// The receive vontext associated with this session, + /// only this context can receive messages from the remote peer. + /// **Read-only**. + pub context: &'a Context<'a, Application>, + /// This side's locally unique session ID. + /// **Read-only**. pub id: SessionId, - /// An arbitrary application defined object associated with each session + /// An arbitrary application defined object associated with each session. + /// **Read-only**. 
pub application_data: Application::Data, + /// The static public key of the remote peer. + /// **Read-only**. pub static_public_key: P384PublicKey, send_counter: AtomicU64, @@ -99,7 +99,6 @@ struct State { } struct IncomingIncompleteSession { - timestamp: i64, alice_session_id: SessionId, bob_session_id: SessionId, noise_h: [u8; NOISE_HASHLEN], @@ -146,20 +145,19 @@ struct SessionKey { confirmed: bool, // Is this key confirmed by the other side yet? } -impl Context { +impl<'a, Application: ApplicationLayer> Context<'a, Application> { /// Create a new session context. /// /// * `max_incomplete_session_queue_size` - Maximum number of incomplete sessions in negotiation phase pub fn new(default_physical_mtu: usize) -> Self { Self { default_physical_mtu: AtomicUsize::new(default_physical_mtu), - defrag_salt: RandomState::new(), - defrag_has_pending: AtomicBool::new(false), - defrag: std::array::from_fn(|_| Mutex::new((Fragged::new(), i64::MAX))), - sessions: RwLock::new(SessionsById { - active: HashMap::with_capacity(64), - incoming: HashMap::with_capacity(64), - }), + dos_salt: RandomState::new(), + init_has_pending: AtomicBool::new(false), + incoming_has_pending: AtomicBool::new(false), + init_defrag: Mutex::new(std::array::from_fn(|_| (i64::MAX, 0, Fragged::new()))), + active_sessions: RwLock::new(HashMap::with_capacity(64)), + incoming_sessions: RwLock::new(std::array::from_fn(|_| (i64::MAX, 0, None))), } } @@ -172,110 +170,102 @@ impl Context { /// * `send` - Function to send packets to remote sessions /// * `current_time` - Current monotonic time in milliseconds pub fn service>, &mut [u8])>(&self, mut send: SendFunction, current_time: i64) -> i64 { - let mut dead_active = Vec::new(); - let mut dead_pending = Vec::new(); let retry_cutoff = current_time - Application::RETRY_INTERVAL; let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; // Scan sessions in read lock mode, then lock more briefly in write mode to delete 
any dead entries that we found. - { - let sessions = self.sessions.read().unwrap(); - for (id, s) in sessions.active.iter() { - if let Some(session) = s.upgrade() { - let state = session.state.read().unwrap(); - if match &state.outgoing_offer { - Offer::None => true, - Offer::NoiseXKInit(offer) => { - // If there's an outstanding attempt to open a session, retransmit this periodically - // in case the initial packet doesn't make it. Note that we currently don't have - // retransmission for the intermediate steps, so a new session may still fail if the - // packet loss rate is huge. The application layer has its own logic to keep trying - // under those conditions. - if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { - offer.last_retry_time.store(current_time, Ordering::Relaxed); - let _ = send_with_fragmentation( - |b| send(&session, b), - &mut (offer.init_packet.clone()), - state.physical_mtu, - PACKET_TYPE_ALICE_NOISE_XK_INIT, - None, - 0, - random::next_u64_secure(), - None, - ); - } - false + let active_sessions = self.active_sessions.read().unwrap(); + for (_id, s) in active_sessions.iter() { + if let Some(session) = s.upgrade() { + let state = session.state.read().unwrap(); + if match &state.outgoing_offer { + Offer::None => true, + Offer::NoiseXKInit(offer) => { + // If there's an outstanding attempt to open a session, retransmit this periodically + // in case the initial packet doesn't make it. Note that we currently don't have + // retransmission for the intermediate steps, so a new session may still fail if the + // packet loss rate is huge. The application layer has its own logic to keep trying + // under those conditions. 
+ if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { + offer.last_retry_time.store(current_time, Ordering::Relaxed); + let _ = send_with_fragmentation( + |b| send(&session, b), + &mut (offer.init_packet.clone()), + state.physical_mtu, + PACKET_TYPE_ALICE_NOISE_XK_INIT, + None, + 0, + random::next_u64_secure(), + None, + ); } - Offer::NoiseXKAck(ack) => { - // We also keep retransmitting the final ACK until we get a valid DATA or NOP packet - // from Bob, otherwise we could get a half open session. - if ack.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { - ack.last_retry_time.store(current_time, Ordering::Relaxed); - let _ = send_with_fragmentation( - |b| send(&session, b), - &mut (ack.ack.clone())[..ack.ack_len], - state.physical_mtu, - PACKET_TYPE_ALICE_NOISE_XK_ACK, - state.remote_session_id, - 0, - 2, - Some(&session.header_protection_cipher), - ); - } - false + false + } + Offer::NoiseXKAck(ack) => { + // We also keep retransmitting the final ACK until we get a valid DATA or NOP packet + // from Bob, otherwise we could get a half open session. + if ack.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { + ack.last_retry_time.store(current_time, Ordering::Relaxed); + let _ = send_with_fragmentation( + |b| send(&session, b), + &mut (ack.ack.clone())[..ack.ack_len], + state.physical_mtu, + PACKET_TYPE_ALICE_NOISE_XK_ACK, + state.remote_session_id, + 0, + 2, + Some(&session.header_protection_cipher), + ); } - Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff, - } { - // Check whether we need to rekey if there is no pending offer or if the last rekey - // offer was before retry_cutoff (checked in the 'match' above). 
- if let Some(key) = state.keys[state.current_key].as_ref() { - if key.my_turn_to_rekey - && (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) - { - drop(state); - session.initiate_rekey(|b| send(&session, b), current_time); - } + false + } + Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff, + } { + // Check whether we need to rekey if there is no pending offer or if the last rekey + // offer was before retry_cutoff (checked in the 'match' above). + if let Some(key) = state.keys[state.current_key].as_ref() { + if key.my_turn_to_rekey + && (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) + { + drop(state); + session.initiate_rekey(|b| send(&session, b), current_time); } } - } else { - dead_active.push(*id); - } - } - - for (id, incoming) in sessions.incoming.iter() { - if incoming.timestamp <= negotiation_timeout_cutoff { - dead_pending.push(*id); } } } + drop(active_sessions); + // Only check for expiration if we have a pending packet. // This check is allowed to have false positives for simplicity's sake. 
- if self.defrag_has_pending.swap(false, Ordering::Relaxed) { + if self.init_has_pending.swap(false, Ordering::Relaxed) { let mut has_pending = false; - for m in &self.defrag { - let mut pending = m.lock().unwrap(); - if pending.1 <= negotiation_timeout_cutoff { - pending.1 = i64::MAX; - pending.0.drop_in_place(); - } else if pending.0.counter() != 0 { + for pending in &mut *self.init_defrag.lock().unwrap() { + if pending.0 <= negotiation_timeout_cutoff { + pending.0 = i64::MAX; + pending.2.drop_in_place(); + } else if pending.0 != i64::MAX { has_pending = true; } } if has_pending { - self.defrag_has_pending.store(true, Ordering::Relaxed); + self.init_has_pending.store(true, Ordering::Relaxed); } } - - if !dead_active.is_empty() || !dead_pending.is_empty() { - let mut sessions = self.sessions.write().unwrap(); - for id in dead_active.iter() { - sessions.active.remove(id); + if self.incoming_has_pending.swap(false, Ordering::Relaxed) { + let mut has_pending = false; + for pending in self.incoming_sessions.write().unwrap().iter_mut() { + if pending.0 <= negotiation_timeout_cutoff { + pending.0 = i64::MAX; + pending.2 = None; + } else if pending.0 != i64::MAX { + has_pending = true; + } } - for id in dead_pending.iter() { - sessions.incoming.remove(id); + if has_pending { + self.incoming_has_pending.store(true, Ordering::Relaxed); } } - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS.min(Application::RETRY_INTERVAL) } @@ -294,7 +284,7 @@ impl Context { /// * `application_data` - Arbitrary opaque data to include with session object /// * `current_time` - Current monotonic time in milliseconds pub fn open( - &self, + &'a self, app: &Application, mut send: SendFunction, mtu: usize, @@ -304,7 +294,7 @@ impl Context { metadata: Option>, application_data: Application::Data, current_time: i64, - ) -> Result>, Error> { + ) -> Result>, Error> { if (metadata.as_ref().map(|md| md.len()).unwrap_or(0) + app.get_local_s_public_blob().len()) > MAX_INIT_PAYLOAD_SIZE { return 
Err(Error::DataTooLarge); } @@ -321,17 +311,29 @@ impl Context { let mut gcm = AesGcm::new(&kbkdf256::(&noise_ck_es)); let (local_session_id, session) = { - let mut sessions = self.sessions.write().unwrap(); + let mut active_sessions = self.active_sessions.write().unwrap(); + let incoming_sessions = self.incoming_sessions.read().unwrap(); + // Pick an unused session ID on this side. let mut local_session_id; + let mut hashed_id; loop { local_session_id = SessionId::random(); - if !sessions.active.contains_key(&local_session_id) && !sessions.incoming.contains_key(&local_session_id) { + let mut hasher = self.dos_salt.build_hasher(); + hasher.write_u64(local_session_id.into()); + hashed_id = hasher.finish(); + let (_, is_used) = lookup( + &*incoming_sessions, + hashed_id, + current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS, + ); + if !is_used && !active_sessions.contains_key(&local_session_id) { break; } } let session = Arc::new(Session { + context: &self, id: local_session_id, application_data, static_public_key: remote_s_public_p384, @@ -357,7 +359,7 @@ impl Context { defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), }); - sessions.active.insert(local_session_id, Arc::downgrade(&session)); + active_sessions.insert(local_session_id, Arc::downgrade(&session)); (local_session_id, session) }; @@ -442,7 +444,7 @@ impl Context { CheckAllowIncomingSession: FnMut() -> bool, CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>, >( - &self, + &'a self, app: &Application, mut check_allow_incoming_session: CheckAllowIncomingSession, mut check_accept_session: CheckAcceptSession, @@ -451,18 +453,17 @@ impl Context { data_buf: &'b mut [u8], mut incoming_physical_packet_buf: Application::IncomingPacketBuffer, current_time: i64, - ) -> Result, Error> { + ) -> Result, Error> { let incoming_physical_packet: &mut [u8] = incoming_physical_packet_buf.as_mut(); if incoming_physical_packet.len() < MIN_PACKET_SIZE { return 
Err(Error::InvalidPacket); } if let Some(local_session_id) = SessionId::new_from_bytes(&incoming_physical_packet[0..SessionId::SIZE]) { - let sessions = self.sessions.read().unwrap(); - if let Some(session) = sessions.active.get(&local_session_id).and_then(|s| s.upgrade()) { - drop(sessions); - debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id)); - + let active_sessions = self.active_sessions.read().unwrap(); + let session = active_sessions.get(&local_session_id).and_then(|s| s.upgrade()); + drop(active_sessions); + if let Some(session) = session { session .header_protection_cipher .decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); @@ -501,10 +502,23 @@ impl Context { } else { return Err(Error::OutOfSequence); } - } else if let Some(incoming) = sessions.incoming.get(&local_session_id).cloned() { - drop(sessions); - debug_assert!(!self.sessions.read().unwrap().active.contains_key(&local_session_id)); + } else if let Some(incoming) = { + let incoming_sessions = self.incoming_sessions.read().unwrap(); + let mut hasher = self.dos_salt.build_hasher(); + hasher.write_u64(local_session_id.into()); + let hashed_id = hasher.finish(); + let (idx, is_old) = lookup( + &*incoming_sessions, + hashed_id, + current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS, + ); + if is_old { + incoming_sessions[idx].2.clone() + } else { + None + } + } { Aes::new(&incoming.header_protection_key) .decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); @@ -549,43 +563,26 @@ impl Context { // incoming_counter is expected to be a random u64 generated by the remote peer. 
// Using just incoming_counter to defragment would be good DOS resistance, // but why not make it harder by hasing it with a random salt and the physical path as well. - let mut hasher = self.defrag_salt.build_hasher(); + let mut hasher = self.dos_salt.build_hasher(); source.hash(&mut hasher); hasher.write_u64(incoming_counter); let hashed_counter = hasher.finish(); - let idx0 = (hashed_counter as usize) % MAX_INCOMPLETE_SESSION_QUEUE_SIZE; - let idx1 = (hashed_counter as usize) / MAX_INCOMPLETE_SESSION_QUEUE_SIZE % MAX_INCOMPLETE_SESSION_QUEUE_SIZE; - // Open hash lookup of just 2 slots. - // By only checking 2 slots we avoid a full table lookup while also minimizing the chance that 2 offers collide. - // To DOS, an adversary would either need to volumetrically spam the defrag table to keep all slots full - // or replay Alice's packet header from a spoofed physical path before Alice's packet is fully assembled. - // Volumetric spam is quite difficult since without the `defrag_salt` value an adversary cannot - // control which slots their fragments index to. And since Alice's packet header has a randomly - // generated counter value replaying it in time requires extreme amounts of network control. 
- let (slot0, timestamp0) = &mut *self.defrag[idx0].lock().unwrap(); - if slot0.counter() == hashed_counter { - assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count); - if assembled.is_some() { - *timestamp0 = i64::MAX; - } - } else { - let (slot1, timestamp1) = &mut *self.defrag[idx1].lock().unwrap(); - if slot1.counter() == hashed_counter { - assembled = slot1.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count); - if assembled.is_some() { - *timestamp1 = i64::MAX; - } - } else if slot0.counter() == 0 { - *timestamp0 = current_time; - self.defrag_has_pending.store(true, Ordering::Relaxed); - assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count); - } else { - // slot1 is either occupied or empty so we overwrite whatever is there to make more room. - *timestamp1 = current_time; - self.defrag_has_pending.store(true, Ordering::Relaxed); - assembled = slot1.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count); - } + let mut defrag = self.init_defrag.lock().unwrap(); + let (idx, is_old) = lookup( + &*defrag, + hashed_counter, + current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS, + ); + assembled = defrag[idx] + .2 + .assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + if assembled.is_some() { + defrag[idx].0 = i64::MAX; + } else if !is_old { + defrag[idx].0 = current_time; + defrag[idx].1 = hashed_counter; + self.init_has_pending.store(true, Ordering::Relaxed); } if let Some(assembled_packet) = &assembled { @@ -620,7 +617,7 @@ impl Context { CheckAllowIncomingSession: FnMut() -> bool, CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>, >( - &self, + &'a self, app: &Application, send: &mut SendFunction, check_allow_incoming_session: &mut CheckAllowIncomingSession, @@ -629,12 +626,13 @@ impl Context { incoming_counter: u64, 
fragments: &[Application::IncomingPacketBuffer], packet_type: u8, - session: Option>>, + session: Option>>, incoming: Option>>, key_index: usize, current_time: i64, - ) -> Result, Error> { + ) -> Result, Error> { debug_assert!(fragments.len() >= 1); + debug_assert!(incoming.is_none() || session.is_none()); // Generate incoming message nonce for decryption and authentication. let incoming_message_nonce = create_message_nonce(packet_type, incoming_counter); @@ -790,16 +788,29 @@ impl Context { .map_err(|_| Error::FailedAuthentication) .map(|(ct, hk)| (ct, Secret(hk)))?; - let mut sessions = self.sessions.write().unwrap(); + let mut incoming_sessions = self.incoming_sessions.write().unwrap(); + let active_sessions = self.active_sessions.read().unwrap(); // Pick an unused session ID on this side. let mut bob_session_id; + let mut hashed_id; + let mut bob_incoming_idx; loop { bob_session_id = SessionId::random(); - if !sessions.active.contains_key(&bob_session_id) && !sessions.incoming.contains_key(&bob_session_id) { + let mut hasher = self.dos_salt.build_hasher(); + hasher.write_u64(bob_session_id.into()); + hashed_id = hasher.finish(); + let is_used; + (bob_incoming_idx, is_used) = lookup( + &*incoming_sessions, + hashed_id, + current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS, + ); + if !is_used && !active_sessions.contains_key(&bob_session_id) { break; } } + drop(active_sessions); // Create Bob's ephemeral counter-offer reply. let mut ack_packet = [0u8; BobNoiseXKAck::SIZE]; @@ -816,44 +827,23 @@ impl Context { gcm.crypt_in_place(&mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]); ack_packet[BobNoiseXKAck::AUTH_START..BobNoiseXKAck::AUTH_START + AES_GCM_TAG_SIZE].copy_from_slice(&gcm.finish_encrypt()); - // If this queue is too big, we remove the latest entry and replace it. The latest - // is used because under flood conditions this is most likely to be another bogus - // entry. 
If we find one that is actually timed out, that one is replaced instead. - if sessions.incoming.len() >= MAX_INCOMPLETE_SESSION_QUEUE_SIZE { - let mut newest = i64::MIN; - let mut replace_id = None; - let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; - for (id, s) in sessions.incoming.iter() { - if s.timestamp <= cutoff_time { - replace_id = Some(*id); - break; - } else if s.timestamp >= newest { - newest = s.timestamp; - replace_id = Some(*id); - } - } - let _ = sessions.incoming.remove(replace_id.as_ref().unwrap()); - } - // Reserve session ID on this side and record incomplete session state. - sessions.incoming.insert( + incoming_sessions[bob_incoming_idx].0 = current_time; + incoming_sessions[bob_incoming_idx].1 = hashed_id; + incoming_sessions[bob_incoming_idx].2 = Some(Arc::new(IncomingIncompleteSession { + alice_session_id, bob_session_id, - Arc::new(IncomingIncompleteSession { - timestamp: current_time, - alice_session_id, - bob_session_id, - noise_h: mix_hash(&mix_hash(&noise_h_next, &bob_noise_e), &ack_packet[HEADER_SIZE..]), - noise_ck_es_ee, - hk, - bob_noise_e_secret, - header_protection_key: Secret(pkt.header_protection_key), - defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), - }), - ); - debug_assert!(!sessions.active.contains_key(&bob_session_id)); + noise_h: mix_hash(&mix_hash(&noise_h_next, &bob_noise_e), &ack_packet[HEADER_SIZE..]), + noise_ck_es_ee, + hk, + bob_noise_e_secret, + header_protection_key: Secret(pkt.header_protection_key), + defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), + })); + self.incoming_has_pending.store(true, Ordering::Relaxed); // Release lock - drop(sessions); + drop(incoming_sessions); send_with_fragmentation( |b| send(None, b), @@ -1068,7 +1058,21 @@ impl Context { // Check session acceptance and fish Alice's NIST P-384 static public key out of her static public blob. 
let check_result = check_accept_session(alice_static_public_blob); if check_result.is_none() { - self.sessions.write().unwrap().incoming.remove(&incoming.bob_session_id); + let mut hasher = self.dos_salt.build_hasher(); + hasher.write_u64(incoming.bob_session_id.into()); + let hashed_id = hasher.finish(); + let mut incoming_sessions = self.incoming_sessions.write().unwrap(); + let (bob_incoming_idx, is_old) = lookup( + &*incoming_sessions, + hashed_id, + current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS, + ); + // Might have been removed already + if is_old { + incoming_sessions[bob_incoming_idx].0 = i64::MAX; + incoming_sessions[bob_incoming_idx].1 = 0; + incoming_sessions[bob_incoming_idx].2 = None; + } return Ok(ReceiveResult::Rejected); } let (alice_noise_s, psk, application_data) = check_result.unwrap(); @@ -1104,6 +1108,7 @@ impl Context { data_buf[..alice_meta_data.len()].copy_from_slice(alice_meta_data); let session = Arc::new(Session { + context: &self, id: incoming.bob_session_id, application_data, static_public_key: alice_noise_s, @@ -1125,9 +1130,24 @@ impl Context { // Promote incoming session to active. 
{ - let mut sessions = self.sessions.write().unwrap(); - sessions.incoming.remove(&incoming.bob_session_id); - sessions.active.insert(incoming.bob_session_id, Arc::downgrade(&session)); + let mut hasher = self.dos_salt.build_hasher(); + hasher.write_u64(incoming.bob_session_id.into()); + let hashed_id = hasher.finish(); + let mut incoming_sessions = self.incoming_sessions.write().unwrap(); + let (bob_incoming_idx, is_present) = lookup( + &*incoming_sessions, + hashed_id, + current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS, + ); + if is_present { + incoming_sessions[bob_incoming_idx].0 = i64::MAX; + incoming_sessions[bob_incoming_idx].1 = 0; + incoming_sessions[bob_incoming_idx].2 = None; + } + self.active_sessions + .write() + .unwrap() + .insert(incoming.bob_session_id, Arc::downgrade(&session)); } let _ = session.send_nop(|b| send(Some(&session), b)); @@ -1195,7 +1215,7 @@ impl Context { reply.bob_e = *bob_e_secret.public_key_bytes(); reply.next_key_fingerprint = SHA512::hash(noise_ck_psk_es_ee_se.as_bytes()); - let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + let counter = session.get_next_outgoing_counter()?.get(); set_packet_header( &mut reply_buf, 1, @@ -1335,7 +1355,7 @@ impl Context { } } -impl Session { +impl<'a, Application: ApplicationLayer> Session<'a, Application> { /// Send data over the session. 
/// /// * `send` - Function to call to send physical packet(s) @@ -1346,13 +1366,13 @@ impl Session { debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); let state = self.state.read().unwrap(); if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) { - let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + let counter = self.get_next_outgoing_counter()?.get(); let mut c = session_key.get_send_cipher(counter)?; c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter)); - let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize; let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE; + let fragment_count = (data.len() + AES_GCM_TAG_SIZE + (fragment_max_chunk_size - 1)) / fragment_max_chunk_size; let last_fragment_no = fragment_count - 1; for fragment_no in 0..fragment_count { @@ -1396,7 +1416,7 @@ impl Session { pub fn send_nop(&self, mut send: SendFunction) -> Result<(), Error> { let state = self.state.read().unwrap(); if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) { - let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + let counter = self.get_next_outgoing_counter()?.get(); let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; let mut c = session_key.get_send_cipher(counter)?; c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter)); @@ -1447,7 +1467,7 @@ impl Session { let state = self.state.read().unwrap(); if let Some(remote_session_id) = state.remote_session_id { if let Some(key) = state.keys[state.current_key].as_ref() { - if let Some(counter) = self.get_next_outgoing_counter() { + if let Ok(counter) = self.get_next_outgoing_counter() { if let Ok(mut gcm) = key.get_send_cipher(counter.get()) { 
/// Two-slot open hash lookup over a fixed table of `(timestamp, key, value)` entries.
///
/// A slot is *empty* when its timestamp is `i64::MAX` and *expired* when its timestamp is
/// below `expiry`.
///
/// Returns `(index, found)`:
/// * `found == true`  - `index` holds a non-empty entry whose key equals `key`.
/// * `found == false` - `key` is not present and `index` is the slot the caller should
///   (over)write if it wants to insert `key`.
///
/// Only two candidate slots derived from `key` are ever inspected.
/// By only checking 2 slots we avoid a full table lookup while also minimizing the chance that 2 offers collide.
/// To DOS, an adversary would either need to volumetrically spam the defrag table to keep all slots full
/// or replay Alice's packet header from a spoofed physical path before Alice's packet is fully assembled.
/// Volumetric spam is quite difficult since without the `dos_salt` value an adversary cannot
/// control which slots their fragments index to. And since Alice's packet header has a randomly
/// generated counter value replaying it in time requires extreme amounts of network control.
///
/// # Panics
/// Panics if `table` is empty.
#[inline]
fn lookup<T>(table: &[(i64, u64, T)], key: u64, expiry: i64) -> (usize, bool) {
    let len = table.len();
    assert!(len > 0, "lookup requires a non-empty table");

    // Truncating `key` on 32-bit targets is fine: it is only used to derive slot indices.
    let idx0 = (key as usize) % len;
    let idx1 = if len == 1 {
        // Degenerate single-slot table: both candidates are the same slot.
        0
    } else {
        // Second candidate is drawn from [0, len - 2]; on collision with idx0 fall back to the
        // last slot, which idx0 cannot be in that case, so the two candidates always differ.
        let candidate = (key as usize) / len % (len - 1);
        if candidate == idx0 {
            len - 1
        } else {
            candidate
        }
    };

    // Freed slots keep whatever key they last held (and fresh slots hold key 0), so a key match
    // only counts when the slot is actually occupied.
    let holds_key = |idx: usize| table[idx].0 != i64::MAX && table[idx].1 == key;
    // A slot can be reused without hurting anyone if it is empty or already past its expiry.
    let is_free = |idx: usize| table[idx].0 == i64::MAX || table[idx].0 < expiry;

    if holds_key(idx0) {
        (idx0, true)
    } else if holds_key(idx1) {
        (idx1, true)
    } else if is_free(idx0) {
        (idx0, false)
    } else if is_free(idx1) {
        // Never evict a live entry while the other candidate is empty or expired.
        (idx1, false)
    } else if table[idx0].0 > table[idx1].0 {
        // Both slots are live: evict the youngest. Even in worst case flood conditions this
        // guarantees *at least* one slot can never be evicted until expiry, giving it plenty of
        // time to be fully processed.
        (idx0, false)
    } else {
        // Both slots are live and slot1 is the youngest (or they are the same age).
        (idx1, false)
    }
}