passed tests

This commit is contained in:
mamoniot 2023-03-25 11:50:27 -04:00
parent fd055561ea
commit 951273f51c
No known key found for this signature in database
GPG key ID: ADCCDBBE0E3D3B3B
4 changed files with 278 additions and 210 deletions

View file

@ -45,12 +45,6 @@ impl<Fragment, const MAX_FRAGMENTS: usize> Fragged<Fragment, MAX_FRAGMENTS> {
unsafe { zeroed() } unsafe { zeroed() }
} }
/// Returns the counter value associated with the packet currently being assembled.
/// If no packet is currently being assembled it returns 0.
#[inline(always)]
pub fn counter(&self) -> u64 {
self.counter
}
/// Add a fragment and return an assembled packet container if all fragments have been received. /// Add a fragment and return an assembled packet container if all fragments have been received.
/// ///
/// When a fully assembled packet is returned the internal state is reset and this object can /// When a fully assembled packet is returned the internal state is reset and this object can

View file

@ -9,6 +9,7 @@
mod applicationlayer; mod applicationlayer;
mod error; mod error;
mod fragged; mod fragged;
mod priority_queue;
mod proto; mod proto;
mod sessionid; mod sessionid;
mod zssp; mod zssp;

View file

@ -0,0 +1,52 @@
/// A binary min-heap of (timestamp, id) pairs, used as a simple event queue.
///
/// The heap is stored flat in a `Vec` using the standard 0-indexed layout:
/// the children of node `i` live at `2*i + 1` and `2*i + 2`, and the parent
/// of node `i` is `(i - 1) / 2`. `heap[0]` is always the earliest event.
///
/// NOTE(review): the original used 1-indexed (Knuth-style) index formulas
/// (`child = 2*i`, `parent = i/2`) on a 0-indexed `Vec`, which broke the heap
/// invariant (node 0's "child" was node 0 itself) and could make `pump`
/// return non-minimum events. Fixed to the 0-indexed formulas.
pub struct EventQueue {
    heap: Vec<(i64, u64)>
}
impl EventQueue {
    /// Creates an empty event queue.
    pub fn new() -> Self {
        Self {
            heap: Vec::new(),
        }
    }
    /// Pops a single event from the queue if one exists to be run at or
    /// before `current_time`. Returns `None` when the queue is empty or the
    /// earliest event is still in the future.
    pub fn pump(&mut self, current_time: i64) -> Option<(i64, u64)> {
        if !self.heap.is_empty() && self.heap[0].0 <= current_time {
            // Remove the root (minimum) and move the last element into its
            // place, then sift it down to restore the heap invariant.
            let ret = self.heap.swap_remove(0);
            let mut parent = 0;
            loop {
                let child0 = 2 * parent + 1;
                if child0 >= self.heap.len() {
                    break; // `parent` is a leaf.
                }
                let child1 = child0 + 1;
                // Pick the smaller child (or the only child).
                let child_min = if child1 < self.heap.len() && self.heap[child1].0 < self.heap[child0].0 {
                    child1
                } else {
                    child0
                };
                if self.heap[child_min].0 < self.heap[parent].0 {
                    self.heap.swap(parent, child_min);
                    parent = child_min;
                } else {
                    break; // Invariant restored.
                }
            }
            Some(ret)
        } else {
            None
        }
    }
    /// Pushes an event onto the queue with the given timestamp and id.
    pub fn push(&mut self, timestamp: i64, id: u64) {
        let mut idx = self.heap.len();
        self.heap.push((timestamp, id));
        // Sift the new element up until its parent is no later than it.
        while idx > 0 {
            let parent = (idx - 1) / 2;
            if self.heap[parent].0 > self.heap[idx].0 {
                self.heap.swap(parent, idx);
                idx = parent;
            } else {
                break; // Invariant restored.
            }
        }
    }
}

View file

@ -10,7 +10,6 @@
// FIPS compliant Noise_XK with Jedi powers (Kyber1024) and built-in attack-resistant large payload (fragmentation) support. // FIPS compliant Noise_XK with Jedi powers (Kyber1024) and built-in attack-resistant large payload (fragmentation) support.
use std::collections::hash_map::RandomState; use std::collections::hash_map::RandomState;
//use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap; use std::collections::HashMap;
use std::hash::{BuildHasher, Hash, Hasher}; use std::hash::{BuildHasher, Hash, Hasher};
use std::num::NonZeroU64; use std::num::NonZeroU64;
@ -38,33 +37,26 @@ const GCM_CIPHER_POOL_SIZE: usize = 4;
/// ///
/// Each application using ZSSP must create an instance of this to own sessions and /// Each application using ZSSP must create an instance of this to own sessions and
/// defragment incoming packets that are not yet associated with a session. /// defragment incoming packets that are not yet associated with a session.
pub struct Context<Application: ApplicationLayer> { pub struct Context<'a, Application: ApplicationLayer> {
default_physical_mtu: AtomicUsize, default_physical_mtu: AtomicUsize,
defrag_salt: RandomState, dos_salt: RandomState,
defrag_has_pending: AtomicBool, // Allowed to be falsely positive defrag_has_pending: AtomicBool, // Allowed to be falsely positive
defrag: [Mutex<(Fragged<Application::IncomingPacketBuffer, MAX_NOISE_HANDSHAKE_FRAGMENTS>, i64)>; MAX_INCOMPLETE_SESSION_QUEUE_SIZE], incoming_has_pending: AtomicBool, // Allowed to be falsely positive
sessions: RwLock<SessionsById<Application>>, defrag: Mutex<[(i64, u64, Fragged<Application::IncomingPacketBuffer, MAX_NOISE_HANDSHAKE_FRAGMENTS>); MAX_INCOMPLETE_SESSION_QUEUE_SIZE]>,
} active_sessions: RwLock<HashMap<SessionId, Weak<Session<'a, Application>>>>,
incoming_sessions: RwLock<[(i64, u64, Option<Arc<IncomingIncompleteSession<Application>>>); MAX_INCOMPLETE_SESSION_QUEUE_SIZE]>,
/// Lookup maps for sessions within a session context.
struct SessionsById<Application: ApplicationLayer> {
// Active sessions, automatically closed if the application no longer holds their Arc<>.
active: HashMap<SessionId, Weak<Session<Application>>>,
// Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout.
incoming: HashMap<SessionId, Arc<IncomingIncompleteSession<Application>>>,
} }
/// Result generated by the context packet receive function, with possible payloads. /// Result generated by the context packet receive function, with possible payloads.
pub enum ReceiveResult<'b, Application: ApplicationLayer> { pub enum ReceiveResult<'a, 'b, Application: ApplicationLayer> {
/// Packet was valid, but no action needs to be taken and no payload was delivered. /// Packet was valid, but no action needs to be taken and no payload was delivered.
Ok(Option<Arc<Session<Application>>>), Ok(Option<Arc<Session<'a, Application>>>),
/// Packet was valid and a data payload was decoded and authenticated. /// Packet was valid and a data payload was decoded and authenticated.
OkData(Arc<Session<Application>>, &'b mut [u8]), OkData(Arc<Session<'a, Application>>, &'b mut [u8]),
/// Packet was valid and a new session was created, with optional attached meta-data. /// Packet was valid and a new session was created, with optional attached meta-data.
OkNewSession(Arc<Session<Application>>, Option<&'b mut [u8]>), OkNewSession(Arc<Session<'a, Application>>, Option<&'b mut [u8]>),
/// Packet appears valid but was rejected by the application layer, e.g. a rejected new session attempt. /// Packet appears valid but was rejected by the application layer, e.g. a rejected new session attempt.
Rejected, Rejected,
@ -73,7 +65,8 @@ pub enum ReceiveResult<'b, Application: ApplicationLayer> {
/// ZeroTier Secure Session Protocol (ZSSP) Session /// ZeroTier Secure Session Protocol (ZSSP) Session
/// ///
/// A FIPS/NIST compliant variant of Noise_XK with hybrid Kyber1024 PQ data forward secrecy. /// A FIPS/NIST compliant variant of Noise_XK with hybrid Kyber1024 PQ data forward secrecy.
pub struct Session<Application: ApplicationLayer> { pub struct Session<'a, Application: ApplicationLayer> {
pub context: &'a Context<'a, Application>,
/// This side's locally unique session ID /// This side's locally unique session ID
pub id: SessionId, pub id: SessionId,
@ -99,7 +92,6 @@ struct State {
} }
struct IncomingIncompleteSession<Application: ApplicationLayer> { struct IncomingIncompleteSession<Application: ApplicationLayer> {
timestamp: i64,
alice_session_id: SessionId, alice_session_id: SessionId,
bob_session_id: SessionId, bob_session_id: SessionId,
noise_h: [u8; NOISE_HASHLEN], noise_h: [u8; NOISE_HASHLEN],
@ -146,20 +138,19 @@ struct SessionKey {
confirmed: bool, // Is this key confirmed by the other side yet? confirmed: bool, // Is this key confirmed by the other side yet?
} }
impl<Application: ApplicationLayer> Context<Application> { impl<'a, Application: ApplicationLayer> Context<'a, Application> {
/// Create a new session context. /// Create a new session context.
/// ///
/// * `max_incomplete_session_queue_size` - Maximum number of incomplete sessions in negotiation phase /// * `max_incomplete_session_queue_size` - Maximum number of incomplete sessions in negotiation phase
pub fn new(default_physical_mtu: usize) -> Self { pub fn new(default_physical_mtu: usize) -> Self {
Self { Self {
default_physical_mtu: AtomicUsize::new(default_physical_mtu), default_physical_mtu: AtomicUsize::new(default_physical_mtu),
defrag_salt: RandomState::new(), dos_salt: RandomState::new(),
defrag_has_pending: AtomicBool::new(false), defrag_has_pending: AtomicBool::new(false),
defrag: std::array::from_fn(|_| Mutex::new((Fragged::new(), i64::MAX))), incoming_has_pending: AtomicBool::new(false),
sessions: RwLock::new(SessionsById { defrag: Mutex::new(std::array::from_fn(|_| (i64::MAX, 0, Fragged::new()))),
active: HashMap::with_capacity(64), active_sessions: RwLock::new(HashMap::with_capacity(64)),
incoming: HashMap::with_capacity(64), incoming_sessions: RwLock::new(std::array::from_fn(|_| (i64::MAX, 0, None))),
}),
} }
} }
@ -173,91 +164,83 @@ impl<Application: ApplicationLayer> Context<Application> {
/// * `current_time` - Current monotonic time in milliseconds /// * `current_time` - Current monotonic time in milliseconds
pub fn service<SendFunction: FnMut(&Arc<Session<Application>>, &mut [u8])>(&self, mut send: SendFunction, current_time: i64) -> i64 { pub fn service<SendFunction: FnMut(&Arc<Session<Application>>, &mut [u8])>(&self, mut send: SendFunction, current_time: i64) -> i64 {
let mut dead_active = Vec::new(); let mut dead_active = Vec::new();
let mut dead_pending = Vec::new();
let retry_cutoff = current_time - Application::RETRY_INTERVAL; let retry_cutoff = current_time - Application::RETRY_INTERVAL;
let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
// Scan sessions in read lock mode, then lock more briefly in write mode to delete any dead entries that we found. // Scan sessions in read lock mode, then lock more briefly in write mode to delete any dead entries that we found.
{ let active_sessions = self.active_sessions.read().unwrap();
let sessions = self.sessions.read().unwrap(); for (id, s) in active_sessions.iter() {
for (id, s) in sessions.active.iter() { if let Some(session) = s.upgrade() {
if let Some(session) = s.upgrade() { let state = session.state.read().unwrap();
let state = session.state.read().unwrap(); if match &state.outgoing_offer {
if match &state.outgoing_offer { Offer::None => true,
Offer::None => true, Offer::NoiseXKInit(offer) => {
Offer::NoiseXKInit(offer) => { // If there's an outstanding attempt to open a session, retransmit this periodically
// If there's an outstanding attempt to open a session, retransmit this periodically // in case the initial packet doesn't make it. Note that we currently don't have
// in case the initial packet doesn't make it. Note that we currently don't have // retransmission for the intermediate steps, so a new session may still fail if the
// retransmission for the intermediate steps, so a new session may still fail if the // packet loss rate is huge. The application layer has its own logic to keep trying
// packet loss rate is huge. The application layer has its own logic to keep trying // under those conditions.
// under those conditions. if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { offer.last_retry_time.store(current_time, Ordering::Relaxed);
offer.last_retry_time.store(current_time, Ordering::Relaxed); let _ = send_with_fragmentation(
let _ = send_with_fragmentation( |b| send(&session, b),
|b| send(&session, b), &mut (offer.init_packet.clone()),
&mut (offer.init_packet.clone()), state.physical_mtu,
state.physical_mtu, PACKET_TYPE_ALICE_NOISE_XK_INIT,
PACKET_TYPE_ALICE_NOISE_XK_INIT, None,
None, 0,
0, random::next_u64_secure(),
random::next_u64_secure(), None,
None, );
);
}
false
} }
Offer::NoiseXKAck(ack) => { false
// We also keep retransmitting the final ACK until we get a valid DATA or NOP packet }
// from Bob, otherwise we could get a half open session. Offer::NoiseXKAck(ack) => {
if ack.last_retry_time.load(Ordering::Relaxed) < retry_cutoff { // We also keep retransmitting the final ACK until we get a valid DATA or NOP packet
ack.last_retry_time.store(current_time, Ordering::Relaxed); // from Bob, otherwise we could get a half open session.
let _ = send_with_fragmentation( if ack.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
|b| send(&session, b), ack.last_retry_time.store(current_time, Ordering::Relaxed);
&mut (ack.ack.clone())[..ack.ack_len], let _ = send_with_fragmentation(
state.physical_mtu, |b| send(&session, b),
PACKET_TYPE_ALICE_NOISE_XK_ACK, &mut (ack.ack.clone())[..ack.ack_len],
state.remote_session_id, state.physical_mtu,
0, PACKET_TYPE_ALICE_NOISE_XK_ACK,
2, state.remote_session_id,
Some(&session.header_protection_cipher), 0,
); 2,
} Some(&session.header_protection_cipher),
false );
} }
Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff, false
} { }
// Check whether we need to rekey if there is no pending offer or if the last rekey Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff,
// offer was before retry_cutoff (checked in the 'match' above). } {
if let Some(key) = state.keys[state.current_key].as_ref() { // Check whether we need to rekey if there is no pending offer or if the last rekey
if key.my_turn_to_rekey // offer was before retry_cutoff (checked in the 'match' above).
&& (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) if let Some(key) = state.keys[state.current_key].as_ref() {
{ if key.my_turn_to_rekey
drop(state); && (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter)
session.initiate_rekey(|b| send(&session, b), current_time); {
} drop(state);
session.initiate_rekey(|b| send(&session, b), current_time);
} }
} }
} else {
dead_active.push(*id);
}
}
for (id, incoming) in sessions.incoming.iter() {
if incoming.timestamp <= negotiation_timeout_cutoff {
dead_pending.push(*id);
} }
} else {
dead_active.push(*id);
} }
} }
drop(active_sessions);
// Only check for expiration if we have a pending packet. // Only check for expiration if we have a pending packet.
// This check is allowed to have false positives for simplicity's sake. // This check is allowed to have false positives for simplicity's sake.
if self.defrag_has_pending.swap(false, Ordering::Relaxed) { if self.defrag_has_pending.swap(false, Ordering::Relaxed) {
let mut has_pending = false; let mut has_pending = false;
for m in &self.defrag { for pending in &mut *self.defrag.lock().unwrap() {
let mut pending = m.lock().unwrap(); if pending.0 <= negotiation_timeout_cutoff {
if pending.1 <= negotiation_timeout_cutoff { pending.0 = i64::MAX;
pending.1 = i64::MAX; pending.2.drop_in_place();
pending.0.drop_in_place(); } else if pending.0 != i64::MAX {
} else if pending.0.counter() != 0 {
has_pending = true; has_pending = true;
} }
} }
@ -265,17 +248,20 @@ impl<Application: ApplicationLayer> Context<Application> {
self.defrag_has_pending.store(true, Ordering::Relaxed); self.defrag_has_pending.store(true, Ordering::Relaxed);
} }
} }
if self.incoming_has_pending.swap(false, Ordering::Relaxed) {
if !dead_active.is_empty() || !dead_pending.is_empty() { let mut has_pending = false;
let mut sessions = self.sessions.write().unwrap(); for pending in self.incoming_sessions.write().unwrap().iter_mut() {
for id in dead_active.iter() { if pending.0 <= negotiation_timeout_cutoff {
sessions.active.remove(id); pending.0 = i64::MAX;
pending.2 = None;
} else if pending.0 != i64::MAX {
has_pending = true;
}
} }
for id in dead_pending.iter() { if has_pending {
sessions.incoming.remove(id); self.incoming_has_pending.store(true, Ordering::Relaxed);
} }
} }
Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS.min(Application::RETRY_INTERVAL) Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS.min(Application::RETRY_INTERVAL)
} }
@ -294,7 +280,7 @@ impl<Application: ApplicationLayer> Context<Application> {
/// * `application_data` - Arbitrary opaque data to include with session object /// * `application_data` - Arbitrary opaque data to include with session object
/// * `current_time` - Current monotonic time in milliseconds /// * `current_time` - Current monotonic time in milliseconds
pub fn open<SendFunction: FnMut(&mut [u8])>( pub fn open<SendFunction: FnMut(&mut [u8])>(
&self, &'a self,
app: &Application, app: &Application,
mut send: SendFunction, mut send: SendFunction,
mtu: usize, mtu: usize,
@ -304,7 +290,7 @@ impl<Application: ApplicationLayer> Context<Application> {
metadata: Option<Vec<u8>>, metadata: Option<Vec<u8>>,
application_data: Application::Data, application_data: Application::Data,
current_time: i64, current_time: i64,
) -> Result<Arc<Session<Application>>, Error> { ) -> Result<Arc<Session<'a, Application>>, Error> {
if (metadata.as_ref().map(|md| md.len()).unwrap_or(0) + app.get_local_s_public_blob().len()) > MAX_INIT_PAYLOAD_SIZE { if (metadata.as_ref().map(|md| md.len()).unwrap_or(0) + app.get_local_s_public_blob().len()) > MAX_INIT_PAYLOAD_SIZE {
return Err(Error::DataTooLarge); return Err(Error::DataTooLarge);
} }
@ -321,17 +307,25 @@ impl<Application: ApplicationLayer> Context<Application> {
let mut gcm = AesGcm::new(&kbkdf256::<KBKDF_KEY_USAGE_LABEL_KEX_ES>(&noise_ck_es)); let mut gcm = AesGcm::new(&kbkdf256::<KBKDF_KEY_USAGE_LABEL_KEX_ES>(&noise_ck_es));
let (local_session_id, session) = { let (local_session_id, session) = {
let mut sessions = self.sessions.write().unwrap(); let mut active_sessions = self.active_sessions.write().unwrap();
let incoming_sessions = self.incoming_sessions.read().unwrap();
// Pick an unused session ID on this side.
let mut local_session_id; let mut local_session_id;
let mut hashed_id;
loop { loop {
local_session_id = SessionId::random(); local_session_id = SessionId::random();
if !sessions.active.contains_key(&local_session_id) && !sessions.incoming.contains_key(&local_session_id) { let mut hasher = self.dos_salt.build_hasher();
hasher.write_u64(local_session_id.into());
hashed_id = hasher.finish();
let (_, is_used) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
if !is_used && !active_sessions.contains_key(&local_session_id) {
break; break;
} }
} }
let session = Arc::new(Session { let session = Arc::new(Session {
context: &self,
id: local_session_id, id: local_session_id,
application_data, application_data,
static_public_key: remote_s_public_p384, static_public_key: remote_s_public_p384,
@ -357,7 +351,7 @@ impl<Application: ApplicationLayer> Context<Application> {
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
}); });
sessions.active.insert(local_session_id, Arc::downgrade(&session)); active_sessions.insert(local_session_id, Arc::downgrade(&session));
(local_session_id, session) (local_session_id, session)
}; };
@ -442,7 +436,7 @@ impl<Application: ApplicationLayer> Context<Application> {
CheckAllowIncomingSession: FnMut() -> bool, CheckAllowIncomingSession: FnMut() -> bool,
CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>, CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>,
>( >(
&self, &'a self,
app: &Application, app: &Application,
mut check_allow_incoming_session: CheckAllowIncomingSession, mut check_allow_incoming_session: CheckAllowIncomingSession,
mut check_accept_session: CheckAcceptSession, mut check_accept_session: CheckAcceptSession,
@ -451,18 +445,17 @@ impl<Application: ApplicationLayer> Context<Application> {
data_buf: &'b mut [u8], data_buf: &'b mut [u8],
mut incoming_physical_packet_buf: Application::IncomingPacketBuffer, mut incoming_physical_packet_buf: Application::IncomingPacketBuffer,
current_time: i64, current_time: i64,
) -> Result<ReceiveResult<'b, Application>, Error> { ) -> Result<ReceiveResult<'a, 'b, Application>, Error> {
let incoming_physical_packet: &mut [u8] = incoming_physical_packet_buf.as_mut(); let incoming_physical_packet: &mut [u8] = incoming_physical_packet_buf.as_mut();
if incoming_physical_packet.len() < MIN_PACKET_SIZE { if incoming_physical_packet.len() < MIN_PACKET_SIZE {
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
if let Some(local_session_id) = SessionId::new_from_bytes(&incoming_physical_packet[0..SessionId::SIZE]) { if let Some(local_session_id) = SessionId::new_from_bytes(&incoming_physical_packet[0..SessionId::SIZE]) {
let sessions = self.sessions.read().unwrap(); let active_sessions = self.active_sessions.read().unwrap();
if let Some(session) = sessions.active.get(&local_session_id).and_then(|s| s.upgrade()) { let session = active_sessions.get(&local_session_id).and_then(|s| s.upgrade());
drop(sessions); drop(active_sessions);
debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id)); if let Some(session) = session {
session session
.header_protection_cipher .header_protection_cipher
.decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); .decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
@ -501,10 +494,19 @@ impl<Application: ApplicationLayer> Context<Application> {
} else { } else {
return Err(Error::OutOfSequence); return Err(Error::OutOfSequence);
} }
} else if let Some(incoming) = sessions.incoming.get(&local_session_id).cloned() { } else if let Some(incoming) = {
drop(sessions); let incoming_sessions = self.incoming_sessions.read().unwrap();
debug_assert!(!self.sessions.read().unwrap().active.contains_key(&local_session_id)); let mut hasher = self.dos_salt.build_hasher();
hasher.write_u64(local_session_id.into());
let hashed_id = hasher.finish();
let (idx, is_old) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
if is_old {
incoming_sessions[idx].2.clone()
} else {
None
}
} {
Aes::new(&incoming.header_protection_key) Aes::new(&incoming.header_protection_key)
.decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); .decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet);
@ -549,43 +551,20 @@ impl<Application: ApplicationLayer> Context<Application> {
// incoming_counter is expected to be a random u64 generated by the remote peer. // incoming_counter is expected to be a random u64 generated by the remote peer.
// Using just incoming_counter to defragment would be good DOS resistance, // Using just incoming_counter to defragment would be good DOS resistance,
// but why not make it harder by hashing it with a random salt and the physical path as well. // but why not make it harder by hashing it with a random salt and the physical path as well.
let mut hasher = self.defrag_salt.build_hasher(); let mut hasher = self.dos_salt.build_hasher();
source.hash(&mut hasher); source.hash(&mut hasher);
hasher.write_u64(incoming_counter); hasher.write_u64(incoming_counter);
let hashed_counter = hasher.finish(); let hashed_counter = hasher.finish();
let idx0 = (hashed_counter as usize) % MAX_INCOMPLETE_SESSION_QUEUE_SIZE;
let idx1 = (hashed_counter as usize) / MAX_INCOMPLETE_SESSION_QUEUE_SIZE % MAX_INCOMPLETE_SESSION_QUEUE_SIZE;
// Open hash lookup of just 2 slots. let mut defrag = self.defrag.lock().unwrap();
// By only checking 2 slots we avoid a full table lookup while also minimizing the chance that 2 offers collide. let (idx, is_old) = lookup(&*defrag, hashed_counter, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
// To DOS, an adversary would either need to volumetrically spam the defrag table to keep all slots full assembled = defrag[idx].2.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
// or replay Alice's packet header from a spoofed physical path before Alice's packet is fully assembled. if assembled.is_some() {
// Volumetric spam is quite difficult since without the `defrag_salt` value an adversary cannot defrag[idx].0 = i64::MAX;
// control which slots their fragments index to. And since Alice's packet header has a randomly } else if !is_old {
// generated counter value replaying it in time requires extreme amounts of network control. defrag[idx].0 = current_time;
let (slot0, timestamp0) = &mut *self.defrag[idx0].lock().unwrap(); defrag[idx].1 = hashed_counter;
if slot0.counter() == hashed_counter { self.defrag_has_pending.store(true, Ordering::Relaxed);
assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
if assembled.is_some() {
*timestamp0 = i64::MAX;
}
} else {
let (slot1, timestamp1) = &mut *self.defrag[idx1].lock().unwrap();
if slot1.counter() == hashed_counter {
assembled = slot1.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
if assembled.is_some() {
*timestamp1 = i64::MAX;
}
} else if slot0.counter() == 0 {
*timestamp0 = current_time;
self.defrag_has_pending.store(true, Ordering::Relaxed);
assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
} else {
// slot1 is either occupied or empty so we overwrite whatever is there to make more room.
*timestamp1 = current_time;
self.defrag_has_pending.store(true, Ordering::Relaxed);
assembled = slot1.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
}
} }
if let Some(assembled_packet) = &assembled { if let Some(assembled_packet) = &assembled {
@ -620,7 +599,7 @@ impl<Application: ApplicationLayer> Context<Application> {
CheckAllowIncomingSession: FnMut() -> bool, CheckAllowIncomingSession: FnMut() -> bool,
CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>, CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>,
>( >(
&self, &'a self,
app: &Application, app: &Application,
send: &mut SendFunction, send: &mut SendFunction,
check_allow_incoming_session: &mut CheckAllowIncomingSession, check_allow_incoming_session: &mut CheckAllowIncomingSession,
@ -629,12 +608,13 @@ impl<Application: ApplicationLayer> Context<Application> {
incoming_counter: u64, incoming_counter: u64,
fragments: &[Application::IncomingPacketBuffer], fragments: &[Application::IncomingPacketBuffer],
packet_type: u8, packet_type: u8,
session: Option<Arc<Session<Application>>>, session: Option<Arc<Session<'a, Application>>>,
incoming: Option<Arc<IncomingIncompleteSession<Application>>>, incoming: Option<Arc<IncomingIncompleteSession<Application>>>,
key_index: usize, key_index: usize,
current_time: i64, current_time: i64,
) -> Result<ReceiveResult<'b, Application>, Error> { ) -> Result<ReceiveResult<'a, 'b, Application>, Error> {
debug_assert!(fragments.len() >= 1); debug_assert!(fragments.len() >= 1);
debug_assert!(incoming.is_none() || session.is_none());
// Generate incoming message nonce for decryption and authentication. // Generate incoming message nonce for decryption and authentication.
let incoming_message_nonce = create_message_nonce(packet_type, incoming_counter); let incoming_message_nonce = create_message_nonce(packet_type, incoming_counter);
@ -790,16 +770,25 @@ impl<Application: ApplicationLayer> Context<Application> {
.map_err(|_| Error::FailedAuthentication) .map_err(|_| Error::FailedAuthentication)
.map(|(ct, hk)| (ct, Secret(hk)))?; .map(|(ct, hk)| (ct, Secret(hk)))?;
let mut sessions = self.sessions.write().unwrap(); let mut incoming_sessions = self.incoming_sessions.write().unwrap();
let active_sessions = self.active_sessions.read().unwrap();
// Pick an unused session ID on this side. // Pick an unused session ID on this side.
let mut bob_session_id; let mut bob_session_id;
let mut hashed_id;
let mut bob_incoming_idx;
loop { loop {
bob_session_id = SessionId::random(); bob_session_id = SessionId::random();
if !sessions.active.contains_key(&bob_session_id) && !sessions.incoming.contains_key(&bob_session_id) { let mut hasher = self.dos_salt.build_hasher();
hasher.write_u64(bob_session_id.into());
hashed_id = hasher.finish();
let is_used;
(bob_incoming_idx, is_used) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
if !is_used && !active_sessions.contains_key(&bob_session_id) {
break; break;
} }
} }
drop(active_sessions);
// Create Bob's ephemeral counter-offer reply. // Create Bob's ephemeral counter-offer reply.
let mut ack_packet = [0u8; BobNoiseXKAck::SIZE]; let mut ack_packet = [0u8; BobNoiseXKAck::SIZE];
@ -816,44 +805,23 @@ impl<Application: ApplicationLayer> Context<Application> {
gcm.crypt_in_place(&mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]); gcm.crypt_in_place(&mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]);
ack_packet[BobNoiseXKAck::AUTH_START..BobNoiseXKAck::AUTH_START + AES_GCM_TAG_SIZE].copy_from_slice(&gcm.finish_encrypt()); ack_packet[BobNoiseXKAck::AUTH_START..BobNoiseXKAck::AUTH_START + AES_GCM_TAG_SIZE].copy_from_slice(&gcm.finish_encrypt());
// If this queue is too big, we remove the latest entry and replace it. The latest
// is used because under flood conditions this is most likely to be another bogus
// entry. If we find one that is actually timed out, that one is replaced instead.
if sessions.incoming.len() >= MAX_INCOMPLETE_SESSION_QUEUE_SIZE {
let mut newest = i64::MIN;
let mut replace_id = None;
let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
for (id, s) in sessions.incoming.iter() {
if s.timestamp <= cutoff_time {
replace_id = Some(*id);
break;
} else if s.timestamp >= newest {
newest = s.timestamp;
replace_id = Some(*id);
}
}
let _ = sessions.incoming.remove(replace_id.as_ref().unwrap());
}
// Reserve session ID on this side and record incomplete session state. // Reserve session ID on this side and record incomplete session state.
sessions.incoming.insert( incoming_sessions[bob_incoming_idx].0 = current_time;
incoming_sessions[bob_incoming_idx].1 = hashed_id;
incoming_sessions[bob_incoming_idx].2 = Some(Arc::new(IncomingIncompleteSession {
alice_session_id,
bob_session_id, bob_session_id,
Arc::new(IncomingIncompleteSession { noise_h: mix_hash(&mix_hash(&noise_h_next, &bob_noise_e), &ack_packet[HEADER_SIZE..]),
timestamp: current_time, noise_ck_es_ee,
alice_session_id, hk,
bob_session_id, bob_noise_e_secret,
noise_h: mix_hash(&mix_hash(&noise_h_next, &bob_noise_e), &ack_packet[HEADER_SIZE..]), header_protection_key: Secret(pkt.header_protection_key),
noise_ck_es_ee, defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
hk, }));
bob_noise_e_secret, self.incoming_has_pending.store(true, Ordering::Relaxed);
header_protection_key: Secret(pkt.header_protection_key),
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
}),
);
debug_assert!(!sessions.active.contains_key(&bob_session_id));
// Release lock // Release lock
drop(sessions); drop(incoming_sessions);
send_with_fragmentation( send_with_fragmentation(
|b| send(None, b), |b| send(None, b),
@ -1068,7 +1036,17 @@ impl<Application: ApplicationLayer> Context<Application> {
// Check session acceptance and fish Alice's NIST P-384 static public key out of her static public blob. // Check session acceptance and fish Alice's NIST P-384 static public key out of her static public blob.
let check_result = check_accept_session(alice_static_public_blob); let check_result = check_accept_session(alice_static_public_blob);
if check_result.is_none() { if check_result.is_none() {
self.sessions.write().unwrap().incoming.remove(&incoming.bob_session_id); let mut hasher = self.dos_salt.build_hasher();
hasher.write_u64(incoming.bob_session_id.into());
let hashed_id = hasher.finish();
let mut incoming_sessions = self.incoming_sessions.write().unwrap();
let (bob_incoming_idx, is_old) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
// Might have been removed already
if is_old {
incoming_sessions[bob_incoming_idx].0 = i64::MAX;
incoming_sessions[bob_incoming_idx].1 = 0;
incoming_sessions[bob_incoming_idx].2 = None;
}
return Ok(ReceiveResult::Rejected); return Ok(ReceiveResult::Rejected);
} }
let (alice_noise_s, psk, application_data) = check_result.unwrap(); let (alice_noise_s, psk, application_data) = check_result.unwrap();
@ -1104,6 +1082,7 @@ impl<Application: ApplicationLayer> Context<Application> {
data_buf[..alice_meta_data.len()].copy_from_slice(alice_meta_data); data_buf[..alice_meta_data.len()].copy_from_slice(alice_meta_data);
let session = Arc::new(Session { let session = Arc::new(Session {
context: &self,
id: incoming.bob_session_id, id: incoming.bob_session_id,
application_data, application_data,
static_public_key: alice_noise_s, static_public_key: alice_noise_s,
@ -1125,9 +1104,17 @@ impl<Application: ApplicationLayer> Context<Application> {
// Promote incoming session to active. // Promote incoming session to active.
{ {
let mut sessions = self.sessions.write().unwrap(); let mut hasher = self.dos_salt.build_hasher();
sessions.incoming.remove(&incoming.bob_session_id); hasher.write_u64(incoming.bob_session_id.into());
sessions.active.insert(incoming.bob_session_id, Arc::downgrade(&session)); let hashed_id = hasher.finish();
let mut incoming_sessions = self.incoming_sessions.write().unwrap();
let (bob_incoming_idx, is_present) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
if is_present {
incoming_sessions[bob_incoming_idx].0 = i64::MAX;
incoming_sessions[bob_incoming_idx].1 = 0;
incoming_sessions[bob_incoming_idx].2 = None;
}
self.active_sessions.write().unwrap().insert(incoming.bob_session_id, Arc::downgrade(&session));
} }
let _ = session.send_nop(|b| send(Some(&session), b)); let _ = session.send_nop(|b| send(Some(&session), b));
@ -1195,7 +1182,7 @@ impl<Application: ApplicationLayer> Context<Application> {
reply.bob_e = *bob_e_secret.public_key_bytes(); reply.bob_e = *bob_e_secret.public_key_bytes();
reply.next_key_fingerprint = SHA512::hash(noise_ck_psk_es_ee_se.as_bytes()); reply.next_key_fingerprint = SHA512::hash(noise_ck_psk_es_ee_se.as_bytes());
let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); let counter = session.get_next_outgoing_counter()?.get();
set_packet_header( set_packet_header(
&mut reply_buf, &mut reply_buf,
1, 1,
@ -1335,7 +1322,7 @@ impl<Application: ApplicationLayer> Context<Application> {
} }
} }
impl<Application: ApplicationLayer> Session<Application> { impl<'a, Application: ApplicationLayer> Session<'a, Application> {
/// Send data over the session. /// Send data over the session.
/// ///
/// * `send` - Function to call to send physical packet(s) /// * `send` - Function to call to send physical packet(s)
@ -1346,13 +1333,13 @@ impl<Application: ApplicationLayer> Session<Application> {
debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU);
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) { if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) {
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); let counter = self.get_next_outgoing_counter()?.get();
let mut c = session_key.get_send_cipher(counter)?; let mut c = session_key.get_send_cipher(counter)?;
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter)); c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter));
let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize;
let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE; let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE;
let fragment_count = (data.len() + AES_GCM_TAG_SIZE + (fragment_max_chunk_size - 1)) / fragment_max_chunk_size;
let last_fragment_no = fragment_count - 1; let last_fragment_no = fragment_count - 1;
for fragment_no in 0..fragment_count { for fragment_no in 0..fragment_count {
@ -1396,7 +1383,7 @@ impl<Application: ApplicationLayer> Session<Application> {
pub fn send_nop<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction) -> Result<(), Error> { pub fn send_nop<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction) -> Result<(), Error> {
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) { if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) {
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); let counter = self.get_next_outgoing_counter()?.get();
let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
let mut c = session_key.get_send_cipher(counter)?; let mut c = session_key.get_send_cipher(counter)?;
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter)); c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter));
@ -1447,7 +1434,7 @@ impl<Application: ApplicationLayer> Session<Application> {
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id { if let Some(remote_session_id) = state.remote_session_id {
if let Some(key) = state.keys[state.current_key].as_ref() { if let Some(key) = state.keys[state.current_key].as_ref() {
if let Some(counter) = self.get_next_outgoing_counter() { if let Ok(counter) = self.get_next_outgoing_counter() {
if let Ok(mut gcm) = key.get_send_cipher(counter.get()) { if let Ok(mut gcm) = key.get_send_cipher(counter.get()) {
gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_INIT, counter.get())); gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_INIT, counter.get()));
gcm.crypt_in_place(&mut rekey_buf[RekeyInit::ENC_START..RekeyInit::AUTH_START]); gcm.crypt_in_place(&mut rekey_buf[RekeyInit::ENC_START..RekeyInit::AUTH_START]);
@ -1480,8 +1467,8 @@ impl<Application: ApplicationLayer> Session<Application> {
/// Get the next outgoing counter value. /// Get the next outgoing counter value.
#[inline(always)] #[inline(always)]
fn get_next_outgoing_counter(&self) -> Option<NonZeroU64> { fn get_next_outgoing_counter(&self) -> Result<NonZeroU64, Error> {
NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::Relaxed)) NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::Relaxed)).ok_or(Error::MaxKeyLifetimeExceeded)
} }
/// Check the receive window without mutating state. /// Check the receive window without mutating state.
@ -1499,6 +1486,12 @@ impl<Application: ApplicationLayer> Session<Application> {
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
} }
} }
impl<'a, App: ApplicationLayer> Drop for Session<'a, App> {
    /// Unregister this session from the context's table of active sessions
    /// so its session ID can be reused.
    fn drop(&mut self) {
        // Do not `unwrap()` the lock here: if another thread panicked while
        // holding `active_sessions` the lock is poisoned, and panicking again
        // inside `drop` during unwinding aborts the whole process. The map
        // operation itself cannot leave the table inconsistent, so it is safe
        // to recover the guard from the poison error and proceed.
        let mut sessions = match self.context.active_sessions.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        sessions.remove(&self.id);
    }
}
#[inline(always)] #[inline(always)]
fn set_packet_header( fn set_packet_header(
@ -1728,3 +1721,31 @@ fn kbkdf512<const LABEL: u8>(key: &Secret<NOISE_HASHLEN>) -> Secret<NOISE_HASHLE
fn kbkdf256<const LABEL: u8>(key: &Secret<NOISE_HASHLEN>) -> Secret<32> { fn kbkdf256<const LABEL: u8>(key: &Secret<NOISE_HASHLEN>) -> Secret<32> {
hmac_sha512_secret256(key.as_bytes(), &[1, b'Z', b'T', LABEL, 0x00, 0, 1u8, 0u8]) hmac_sha512_secret256(key.as_bytes(), &[1, b'Z', b'T', LABEL, 0x00, 0, 1u8, 0u8])
} }
/// Probes the fixed-size session table for `key` using a 2-way open-hash
/// scheme: every key maps to exactly two distinct candidate slots.
///
/// Returns `(index, true)` when `key` already occupies one of its two slots.
/// Otherwise returns `(index, false)` naming the slot the caller should
/// (over)write: slot 0 when it is empty (`i64::MAX` sentinel timestamp),
/// expired (timestamp at or before `expiry`), or holds a younger entry than
/// slot 1; slot 1 in every other case — evicting whatever is there.
///
/// Checking only 2 slots avoids a full table scan while keeping the chance
/// that two concurrent offers collide low. NOTE(review): the DOS argument —
/// that an adversary cannot target specific slots because keys come from a
/// hash salted with `dos_salt` — depends on the callers; confirm there.
#[inline]
fn lookup<T>(table: &[(i64, u64, T)], key: u64, expiry: i64) -> (usize, bool) {
    let slot0 = (key as usize) % MAX_INCOMPLETE_SESSION_QUEUE_SIZE;
    let mut slot1 = (key as usize) / MAX_INCOMPLETE_SESSION_QUEUE_SIZE % (MAX_INCOMPLETE_SESSION_QUEUE_SIZE - 1);
    if slot1 == slot0 {
        // Guarantee the two probe positions are distinct.
        slot1 = MAX_INCOMPLETE_SESSION_QUEUE_SIZE - 1;
    }
    if table[slot0].1 == key {
        return (slot0, true);
    }
    if table[slot1].1 == key {
        return (slot1, true);
    }
    // Key not present: choose which slot to hand back for insertion.
    let t0 = table[slot0].0;
    if t0 == i64::MAX || t0 <= expiry || t0 > table[slot1].0 {
        // Slot 0 is empty, expired, or the younger of the two entries.
        (slot0, false)
    } else {
        // Otherwise overwrite slot 1 to make room.
        (slot1, false)
    }
}