passed tests

This commit is contained in:
mamoniot 2023-03-25 11:50:27 -04:00
parent fd055561ea
commit 951273f51c
No known key found for this signature in database
GPG key ID: ADCCDBBE0E3D3B3B
4 changed files with 278 additions and 210 deletions

View file

@ -45,12 +45,6 @@ impl<Fragment, const MAX_FRAGMENTS: usize> Fragged<Fragment, MAX_FRAGMENTS> {
unsafe { zeroed() }
}
/// Returns the counter value associated with the packet currently being assembled.
/// If no packet is currently being assembled it returns 0.
#[inline(always)]
pub fn counter(&self) -> u64 {
self.counter
}
/// Add a fragment and return an assembled packet container if all fragments have been received.
///
/// When a fully assembled packet is returned the internal state is reset and this object can

View file

@ -9,6 +9,7 @@
mod applicationlayer;
mod error;
mod fragged;
mod priority_queue;
mod proto;
mod sessionid;
mod zssp;

View file

@ -0,0 +1,52 @@
/// Is a Donald Knuth minheap, which is extremely fast and memory efficient.
pub struct EventQueue {
heap: Vec<(i64, u64)>
}
impl EventQueue {
pub fn new() -> Self {
Self {
heap: Vec::new(),
}
}
/// Pops a single event from the queue if one exists to be run past the current time
pub fn pump(&mut self, current_time: i64) -> Option<(i64, u64)> {
if self.heap.len() > 0 {
if self.heap[0].0 <= current_time {
let ret = self.heap.swap_remove(0);
let mut parent = 0;
while 2*parent < self.heap.len() {
let child0 = 2*parent;
let child1 = child0 + 1;
let child_min = if child1 < self.heap.len() && self.heap[child1].0 < self.heap[child0].0 {
child1
} else {
child0
};
if self.heap[child_min].0 < self.heap[parent].0 {
self.heap.swap(parent, child_min);
parent = child_min;
} else {
break;
}
}
return Some(ret);
}
}
None
}
/// Pushes an event onto the queue with the given timestamp
pub fn push(&mut self, timestamp: i64, id: u64) {
let mut idx = self.heap.len();
self.heap.push((timestamp, id));
while idx > 0 {
let parent = idx/2;
if self.heap[parent].0 > self.heap[idx].0 {
self.heap.swap(parent, idx);
}
idx = parent;
}
}
}

View file

@ -10,7 +10,6 @@
// FIPS compliant Noise_XK with Jedi powers (Kyber1024) and built-in attack-resistant large payload (fragmentation) support.
use std::collections::hash_map::RandomState;
//use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash, Hasher};
use std::num::NonZeroU64;
@ -38,33 +37,26 @@ const GCM_CIPHER_POOL_SIZE: usize = 4;
///
/// Each application using ZSSP must create an instance of this to own sessions and
/// defragment incoming packets that are not yet associated with a session.
pub struct Context<Application: ApplicationLayer> {
pub struct Context<'a, Application: ApplicationLayer> {
default_physical_mtu: AtomicUsize,
defrag_salt: RandomState,
dos_salt: RandomState,
defrag_has_pending: AtomicBool, // Allowed to be falsely positive
defrag: [Mutex<(Fragged<Application::IncomingPacketBuffer, MAX_NOISE_HANDSHAKE_FRAGMENTS>, i64)>; MAX_INCOMPLETE_SESSION_QUEUE_SIZE],
sessions: RwLock<SessionsById<Application>>,
}
/// Lookup maps for sessions within a session context.
struct SessionsById<Application: ApplicationLayer> {
// Active sessions, automatically closed if the application no longer holds their Arc<>.
active: HashMap<SessionId, Weak<Session<Application>>>,
// Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout.
incoming: HashMap<SessionId, Arc<IncomingIncompleteSession<Application>>>,
incoming_has_pending: AtomicBool, // Allowed to be falsely positive
defrag: Mutex<[(i64, u64, Fragged<Application::IncomingPacketBuffer, MAX_NOISE_HANDSHAKE_FRAGMENTS>); MAX_INCOMPLETE_SESSION_QUEUE_SIZE]>,
active_sessions: RwLock<HashMap<SessionId, Weak<Session<'a, Application>>>>,
incoming_sessions: RwLock<[(i64, u64, Option<Arc<IncomingIncompleteSession<Application>>>); MAX_INCOMPLETE_SESSION_QUEUE_SIZE]>,
}
/// Result generated by the context packet receive function, with possible payloads.
pub enum ReceiveResult<'b, Application: ApplicationLayer> {
pub enum ReceiveResult<'a, 'b, Application: ApplicationLayer> {
/// Packet was valid, but no action needs to be taken and no payload was delivered.
Ok(Option<Arc<Session<Application>>>),
Ok(Option<Arc<Session<'a, Application>>>),
/// Packet was valid and a data payload was decoded and authenticated.
OkData(Arc<Session<Application>>, &'b mut [u8]),
OkData(Arc<Session<'a, Application>>, &'b mut [u8]),
/// Packet was valid and a new session was created, with optional attached meta-data.
OkNewSession(Arc<Session<Application>>, Option<&'b mut [u8]>),
OkNewSession(Arc<Session<'a, Application>>, Option<&'b mut [u8]>),
/// Packet appears valid but was rejected by the application layer, e.g. a rejected new session attempt.
Rejected,
@ -73,7 +65,8 @@ pub enum ReceiveResult<'b, Application: ApplicationLayer> {
/// ZeroTier Secure Session Protocol (ZSSP) Session
///
/// A FIPS/NIST compliant variant of Noise_XK with hybrid Kyber1024 PQ data forward secrecy.
pub struct Session<Application: ApplicationLayer> {
pub struct Session<'a, Application: ApplicationLayer> {
pub context: &'a Context<'a, Application>,
/// This side's locally unique session ID
pub id: SessionId,
@ -99,7 +92,6 @@ struct State {
}
struct IncomingIncompleteSession<Application: ApplicationLayer> {
timestamp: i64,
alice_session_id: SessionId,
bob_session_id: SessionId,
noise_h: [u8; NOISE_HASHLEN],
@ -146,20 +138,19 @@ struct SessionKey {
confirmed: bool, // Is this key confirmed by the other side yet?
}
impl<Application: ApplicationLayer> Context<Application> {
impl<'a, Application: ApplicationLayer> Context<'a, Application> {
/// Create a new session context.
///
/// * `max_incomplete_session_queue_size` - Maximum number of incomplete sessions in negotiation phase
pub fn new(default_physical_mtu: usize) -> Self {
Self {
default_physical_mtu: AtomicUsize::new(default_physical_mtu),
defrag_salt: RandomState::new(),
dos_salt: RandomState::new(),
defrag_has_pending: AtomicBool::new(false),
defrag: std::array::from_fn(|_| Mutex::new((Fragged::new(), i64::MAX))),
sessions: RwLock::new(SessionsById {
active: HashMap::with_capacity(64),
incoming: HashMap::with_capacity(64),
}),
incoming_has_pending: AtomicBool::new(false),
defrag: Mutex::new(std::array::from_fn(|_| (i64::MAX, 0, Fragged::new()))),
active_sessions: RwLock::new(HashMap::with_capacity(64)),
incoming_sessions: RwLock::new(std::array::from_fn(|_| (i64::MAX, 0, None))),
}
}
@ -173,91 +164,83 @@ impl<Application: ApplicationLayer> Context<Application> {
/// * `current_time` - Current monotonic time in milliseconds
pub fn service<SendFunction: FnMut(&Arc<Session<Application>>, &mut [u8])>(&self, mut send: SendFunction, current_time: i64) -> i64 {
let mut dead_active = Vec::new();
let mut dead_pending = Vec::new();
let retry_cutoff = current_time - Application::RETRY_INTERVAL;
let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
// Scan sessions in read lock mode, then lock more briefly in write mode to delete any dead entries that we found.
{
let sessions = self.sessions.read().unwrap();
for (id, s) in sessions.active.iter() {
if let Some(session) = s.upgrade() {
let state = session.state.read().unwrap();
if match &state.outgoing_offer {
Offer::None => true,
Offer::NoiseXKInit(offer) => {
// If there's an outstanding attempt to open a session, retransmit this periodically
// in case the initial packet doesn't make it. Note that we currently don't have
// retransmission for the intermediate steps, so a new session may still fail if the
// packet loss rate is huge. The application layer has its own logic to keep trying
// under those conditions.
if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
offer.last_retry_time.store(current_time, Ordering::Relaxed);
let _ = send_with_fragmentation(
|b| send(&session, b),
&mut (offer.init_packet.clone()),
state.physical_mtu,
PACKET_TYPE_ALICE_NOISE_XK_INIT,
None,
0,
random::next_u64_secure(),
None,
);
}
false
let active_sessions = self.active_sessions.read().unwrap();
for (id, s) in active_sessions.iter() {
if let Some(session) = s.upgrade() {
let state = session.state.read().unwrap();
if match &state.outgoing_offer {
Offer::None => true,
Offer::NoiseXKInit(offer) => {
// If there's an outstanding attempt to open a session, retransmit this periodically
// in case the initial packet doesn't make it. Note that we currently don't have
// retransmission for the intermediate steps, so a new session may still fail if the
// packet loss rate is huge. The application layer has its own logic to keep trying
// under those conditions.
if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
offer.last_retry_time.store(current_time, Ordering::Relaxed);
let _ = send_with_fragmentation(
|b| send(&session, b),
&mut (offer.init_packet.clone()),
state.physical_mtu,
PACKET_TYPE_ALICE_NOISE_XK_INIT,
None,
0,
random::next_u64_secure(),
None,
);
}
Offer::NoiseXKAck(ack) => {
// We also keep retransmitting the final ACK until we get a valid DATA or NOP packet
// from Bob, otherwise we could get a half open session.
if ack.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
ack.last_retry_time.store(current_time, Ordering::Relaxed);
let _ = send_with_fragmentation(
|b| send(&session, b),
&mut (ack.ack.clone())[..ack.ack_len],
state.physical_mtu,
PACKET_TYPE_ALICE_NOISE_XK_ACK,
state.remote_session_id,
0,
2,
Some(&session.header_protection_cipher),
);
}
false
false
}
Offer::NoiseXKAck(ack) => {
// We also keep retransmitting the final ACK until we get a valid DATA or NOP packet
// from Bob, otherwise we could get a half open session.
if ack.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
ack.last_retry_time.store(current_time, Ordering::Relaxed);
let _ = send_with_fragmentation(
|b| send(&session, b),
&mut (ack.ack.clone())[..ack.ack_len],
state.physical_mtu,
PACKET_TYPE_ALICE_NOISE_XK_ACK,
state.remote_session_id,
0,
2,
Some(&session.header_protection_cipher),
);
}
Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff,
} {
// Check whether we need to rekey if there is no pending offer or if the last rekey
// offer was before retry_cutoff (checked in the 'match' above).
if let Some(key) = state.keys[state.current_key].as_ref() {
if key.my_turn_to_rekey
&& (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter)
{
drop(state);
session.initiate_rekey(|b| send(&session, b), current_time);
}
false
}
Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff,
} {
// Check whether we need to rekey if there is no pending offer or if the last rekey
// offer was before retry_cutoff (checked in the 'match' above).
if let Some(key) = state.keys[state.current_key].as_ref() {
if key.my_turn_to_rekey
&& (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter)
{
drop(state);
session.initiate_rekey(|b| send(&session, b), current_time);
}
}
} else {
dead_active.push(*id);
}
}
for (id, incoming) in sessions.incoming.iter() {
if incoming.timestamp <= negotiation_timeout_cutoff {
dead_pending.push(*id);
}
} else {
dead_active.push(*id);
}
}
drop(active_sessions);
// Only check for expiration if we have a pending packet.
// This check is allowed to have false positives for simplicity's sake.
if self.defrag_has_pending.swap(false, Ordering::Relaxed) {
let mut has_pending = false;
for m in &self.defrag {
let mut pending = m.lock().unwrap();
if pending.1 <= negotiation_timeout_cutoff {
pending.1 = i64::MAX;
pending.0.drop_in_place();
} else if pending.0.counter() != 0 {
for pending in &mut *self.defrag.lock().unwrap() {
if pending.0 <= negotiation_timeout_cutoff {
pending.0 = i64::MAX;
pending.2.drop_in_place();
} else if pending.0 != i64::MAX {
has_pending = true;
}
}
@ -265,17 +248,20 @@ impl<Application: ApplicationLayer> Context<Application> {
self.defrag_has_pending.store(true, Ordering::Relaxed);
}
}
if !dead_active.is_empty() || !dead_pending.is_empty() {
let mut sessions = self.sessions.write().unwrap();
for id in dead_active.iter() {
sessions.active.remove(id);
if self.incoming_has_pending.swap(false, Ordering::Relaxed) {
let mut has_pending = false;
for pending in self.incoming_sessions.write().unwrap().iter_mut() {
if pending.0 <= negotiation_timeout_cutoff {
pending.0 = i64::MAX;
pending.2 = None;
} else if pending.0 != i64::MAX {
has_pending = true;
}
}
for id in dead_pending.iter() {
sessions.incoming.remove(id);
if has_pending {
self.incoming_has_pending.store(true, Ordering::Relaxed);
}
}
Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS.min(Application::RETRY_INTERVAL)
}
@ -294,7 +280,7 @@ impl<Application: ApplicationLayer> Context<Application> {
/// * `application_data` - Arbitrary opaque data to include with session object
/// * `current_time` - Current monotonic time in milliseconds
pub fn open<SendFunction: FnMut(&mut [u8])>(
&self,
&'a self,
app: &Application,
mut send: SendFunction,
mtu: usize,
@ -304,7 +290,7 @@ impl<Application: ApplicationLayer> Context<Application> {
metadata: Option<Vec<u8>>,
application_data: Application::Data,
current_time: i64,
) -> Result<Arc<Session<Application>>, Error> {
) -> Result<Arc<Session<'a, Application>>, Error> {
if (metadata.as_ref().map(|md| md.len()).unwrap_or(0) + app.get_local_s_public_blob().len()) > MAX_INIT_PAYLOAD_SIZE {
return Err(Error::DataTooLarge);
}
@ -321,17 +307,25 @@ impl<Application: ApplicationLayer> Context<Application> {
let mut gcm = AesGcm::new(&kbkdf256::<KBKDF_KEY_USAGE_LABEL_KEX_ES>(&noise_ck_es));
let (local_session_id, session) = {
let mut sessions = self.sessions.write().unwrap();
let mut active_sessions = self.active_sessions.write().unwrap();
let incoming_sessions = self.incoming_sessions.read().unwrap();
// Pick an unused session ID on this side.
let mut local_session_id;
let mut hashed_id;
loop {
local_session_id = SessionId::random();
if !sessions.active.contains_key(&local_session_id) && !sessions.incoming.contains_key(&local_session_id) {
let mut hasher = self.dos_salt.build_hasher();
hasher.write_u64(local_session_id.into());
hashed_id = hasher.finish();
let (_, is_used) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
if !is_used && !active_sessions.contains_key(&local_session_id) {
break;
}
}
let session = Arc::new(Session {
context: &self,
id: local_session_id,
application_data,
static_public_key: remote_s_public_p384,
@ -357,7 +351,7 @@ impl<Application: ApplicationLayer> Context<Application> {
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
});
sessions.active.insert(local_session_id, Arc::downgrade(&session));
active_sessions.insert(local_session_id, Arc::downgrade(&session));
(local_session_id, session)
};
@ -442,7 +436,7 @@ impl<Application: ApplicationLayer> Context<Application> {
CheckAllowIncomingSession: FnMut() -> bool,
CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>,
>(
&self,
&'a self,
app: &Application,
mut check_allow_incoming_session: CheckAllowIncomingSession,
mut check_accept_session: CheckAcceptSession,
@ -451,18 +445,17 @@ impl<Application: ApplicationLayer> Context<Application> {
data_buf: &'b mut [u8],
mut incoming_physical_packet_buf: Application::IncomingPacketBuffer,
current_time: i64,
) -> Result<ReceiveResult<'b, Application>, Error> {
) -> Result<ReceiveResult<'a, 'b, Application>, Error> {
let incoming_physical_packet: &mut [u8] = incoming_physical_packet_buf.as_mut();
if incoming_physical_packet.len() < MIN_PACKET_SIZE {
return Err(Error::InvalidPacket);
}
if let Some(local_session_id) = SessionId::new_from_bytes(&incoming_physical_packet[0..SessionId::SIZE]) {
let sessions = self.sessions.read().unwrap();
if let Some(session) = sessions.active.get(&local_session_id).and_then(|s| s.upgrade()) {
drop(sessions);
debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id));
let active_sessions = self.active_sessions.read().unwrap();
let session = active_sessions.get(&local_session_id).and_then(|s| s.upgrade());
drop(active_sessions);
if let Some(session) = session {
session
.header_protection_cipher
.decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
@ -501,10 +494,19 @@ impl<Application: ApplicationLayer> Context<Application> {
} else {
return Err(Error::OutOfSequence);
}
} else if let Some(incoming) = sessions.incoming.get(&local_session_id).cloned() {
drop(sessions);
debug_assert!(!self.sessions.read().unwrap().active.contains_key(&local_session_id));
} else if let Some(incoming) = {
let incoming_sessions = self.incoming_sessions.read().unwrap();
let mut hasher = self.dos_salt.build_hasher();
hasher.write_u64(local_session_id.into());
let hashed_id = hasher.finish();
let (idx, is_old) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
if is_old {
incoming_sessions[idx].2.clone()
} else {
None
}
} {
Aes::new(&incoming.header_protection_key)
.decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet);
@ -549,43 +551,20 @@ impl<Application: ApplicationLayer> Context<Application> {
// incoming_counter is expected to be a random u64 generated by the remote peer.
// Using just incoming_counter to defragment would be good DOS resistance,
// but why not make it harder by hasing it with a random salt and the physical path as well.
let mut hasher = self.defrag_salt.build_hasher();
let mut hasher = self.dos_salt.build_hasher();
source.hash(&mut hasher);
hasher.write_u64(incoming_counter);
let hashed_counter = hasher.finish();
let idx0 = (hashed_counter as usize) % MAX_INCOMPLETE_SESSION_QUEUE_SIZE;
let idx1 = (hashed_counter as usize) / MAX_INCOMPLETE_SESSION_QUEUE_SIZE % MAX_INCOMPLETE_SESSION_QUEUE_SIZE;
// Open hash lookup of just 2 slots.
// By only checking 2 slots we avoid a full table lookup while also minimizing the chance that 2 offers collide.
// To DOS, an adversary would either need to volumetrically spam the defrag table to keep all slots full
// or replay Alice's packet header from a spoofed physical path before Alice's packet is fully assembled.
// Volumetric spam is quite difficult since without the `defrag_salt` value an adversary cannot
// control which slots their fragments index to. And since Alice's packet header has a randomly
// generated counter value replaying it in time requires extreme amounts of network control.
let (slot0, timestamp0) = &mut *self.defrag[idx0].lock().unwrap();
if slot0.counter() == hashed_counter {
assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
if assembled.is_some() {
*timestamp0 = i64::MAX;
}
} else {
let (slot1, timestamp1) = &mut *self.defrag[idx1].lock().unwrap();
if slot1.counter() == hashed_counter {
assembled = slot1.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
if assembled.is_some() {
*timestamp1 = i64::MAX;
}
} else if slot0.counter() == 0 {
*timestamp0 = current_time;
self.defrag_has_pending.store(true, Ordering::Relaxed);
assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
} else {
// slot1 is either occupied or empty so we overwrite whatever is there to make more room.
*timestamp1 = current_time;
self.defrag_has_pending.store(true, Ordering::Relaxed);
assembled = slot1.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
}
let mut defrag = self.defrag.lock().unwrap();
let (idx, is_old) = lookup(&*defrag, hashed_counter, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
assembled = defrag[idx].2.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
if assembled.is_some() {
defrag[idx].0 = i64::MAX;
} else if !is_old {
defrag[idx].0 = current_time;
defrag[idx].1 = hashed_counter;
self.defrag_has_pending.store(true, Ordering::Relaxed);
}
if let Some(assembled_packet) = &assembled {
@ -620,7 +599,7 @@ impl<Application: ApplicationLayer> Context<Application> {
CheckAllowIncomingSession: FnMut() -> bool,
CheckAcceptSession: FnMut(&[u8]) -> Option<(P384PublicKey, Secret<64>, Application::Data)>,
>(
&self,
&'a self,
app: &Application,
send: &mut SendFunction,
check_allow_incoming_session: &mut CheckAllowIncomingSession,
@ -629,12 +608,13 @@ impl<Application: ApplicationLayer> Context<Application> {
incoming_counter: u64,
fragments: &[Application::IncomingPacketBuffer],
packet_type: u8,
session: Option<Arc<Session<Application>>>,
session: Option<Arc<Session<'a, Application>>>,
incoming: Option<Arc<IncomingIncompleteSession<Application>>>,
key_index: usize,
current_time: i64,
) -> Result<ReceiveResult<'b, Application>, Error> {
) -> Result<ReceiveResult<'a, 'b, Application>, Error> {
debug_assert!(fragments.len() >= 1);
debug_assert!(incoming.is_none() || session.is_none());
// Generate incoming message nonce for decryption and authentication.
let incoming_message_nonce = create_message_nonce(packet_type, incoming_counter);
@ -790,16 +770,25 @@ impl<Application: ApplicationLayer> Context<Application> {
.map_err(|_| Error::FailedAuthentication)
.map(|(ct, hk)| (ct, Secret(hk)))?;
let mut sessions = self.sessions.write().unwrap();
let mut incoming_sessions = self.incoming_sessions.write().unwrap();
let active_sessions = self.active_sessions.read().unwrap();
// Pick an unused session ID on this side.
let mut bob_session_id;
let mut hashed_id;
let mut bob_incoming_idx;
loop {
bob_session_id = SessionId::random();
if !sessions.active.contains_key(&bob_session_id) && !sessions.incoming.contains_key(&bob_session_id) {
let mut hasher = self.dos_salt.build_hasher();
hasher.write_u64(bob_session_id.into());
hashed_id = hasher.finish();
let is_used;
(bob_incoming_idx, is_used) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
if !is_used && !active_sessions.contains_key(&bob_session_id) {
break;
}
}
drop(active_sessions);
// Create Bob's ephemeral counter-offer reply.
let mut ack_packet = [0u8; BobNoiseXKAck::SIZE];
@ -816,44 +805,23 @@ impl<Application: ApplicationLayer> Context<Application> {
gcm.crypt_in_place(&mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]);
ack_packet[BobNoiseXKAck::AUTH_START..BobNoiseXKAck::AUTH_START + AES_GCM_TAG_SIZE].copy_from_slice(&gcm.finish_encrypt());
// If this queue is too big, we remove the latest entry and replace it. The latest
// is used because under flood conditions this is most likely to be another bogus
// entry. If we find one that is actually timed out, that one is replaced instead.
if sessions.incoming.len() >= MAX_INCOMPLETE_SESSION_QUEUE_SIZE {
let mut newest = i64::MIN;
let mut replace_id = None;
let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
for (id, s) in sessions.incoming.iter() {
if s.timestamp <= cutoff_time {
replace_id = Some(*id);
break;
} else if s.timestamp >= newest {
newest = s.timestamp;
replace_id = Some(*id);
}
}
let _ = sessions.incoming.remove(replace_id.as_ref().unwrap());
}
// Reserve session ID on this side and record incomplete session state.
sessions.incoming.insert(
incoming_sessions[bob_incoming_idx].0 = current_time;
incoming_sessions[bob_incoming_idx].1 = hashed_id;
incoming_sessions[bob_incoming_idx].2 = Some(Arc::new(IncomingIncompleteSession {
alice_session_id,
bob_session_id,
Arc::new(IncomingIncompleteSession {
timestamp: current_time,
alice_session_id,
bob_session_id,
noise_h: mix_hash(&mix_hash(&noise_h_next, &bob_noise_e), &ack_packet[HEADER_SIZE..]),
noise_ck_es_ee,
hk,
bob_noise_e_secret,
header_protection_key: Secret(pkt.header_protection_key),
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
}),
);
debug_assert!(!sessions.active.contains_key(&bob_session_id));
noise_h: mix_hash(&mix_hash(&noise_h_next, &bob_noise_e), &ack_packet[HEADER_SIZE..]),
noise_ck_es_ee,
hk,
bob_noise_e_secret,
header_protection_key: Secret(pkt.header_protection_key),
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
}));
self.incoming_has_pending.store(true, Ordering::Relaxed);
// Release lock
drop(sessions);
drop(incoming_sessions);
send_with_fragmentation(
|b| send(None, b),
@ -1068,7 +1036,17 @@ impl<Application: ApplicationLayer> Context<Application> {
// Check session acceptance and fish Alice's NIST P-384 static public key out of her static public blob.
let check_result = check_accept_session(alice_static_public_blob);
if check_result.is_none() {
self.sessions.write().unwrap().incoming.remove(&incoming.bob_session_id);
let mut hasher = self.dos_salt.build_hasher();
hasher.write_u64(incoming.bob_session_id.into());
let hashed_id = hasher.finish();
let mut incoming_sessions = self.incoming_sessions.write().unwrap();
let (bob_incoming_idx, is_old) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
// Might have been removed already
if is_old {
incoming_sessions[bob_incoming_idx].0 = i64::MAX;
incoming_sessions[bob_incoming_idx].1 = 0;
incoming_sessions[bob_incoming_idx].2 = None;
}
return Ok(ReceiveResult::Rejected);
}
let (alice_noise_s, psk, application_data) = check_result.unwrap();
@ -1104,6 +1082,7 @@ impl<Application: ApplicationLayer> Context<Application> {
data_buf[..alice_meta_data.len()].copy_from_slice(alice_meta_data);
let session = Arc::new(Session {
context: &self,
id: incoming.bob_session_id,
application_data,
static_public_key: alice_noise_s,
@ -1125,9 +1104,17 @@ impl<Application: ApplicationLayer> Context<Application> {
// Promote incoming session to active.
{
let mut sessions = self.sessions.write().unwrap();
sessions.incoming.remove(&incoming.bob_session_id);
sessions.active.insert(incoming.bob_session_id, Arc::downgrade(&session));
let mut hasher = self.dos_salt.build_hasher();
hasher.write_u64(incoming.bob_session_id.into());
let hashed_id = hasher.finish();
let mut incoming_sessions = self.incoming_sessions.write().unwrap();
let (bob_incoming_idx, is_present) = lookup(&*incoming_sessions, hashed_id, current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS);
if is_present {
incoming_sessions[bob_incoming_idx].0 = i64::MAX;
incoming_sessions[bob_incoming_idx].1 = 0;
incoming_sessions[bob_incoming_idx].2 = None;
}
self.active_sessions.write().unwrap().insert(incoming.bob_session_id, Arc::downgrade(&session));
}
let _ = session.send_nop(|b| send(Some(&session), b));
@ -1195,7 +1182,7 @@ impl<Application: ApplicationLayer> Context<Application> {
reply.bob_e = *bob_e_secret.public_key_bytes();
reply.next_key_fingerprint = SHA512::hash(noise_ck_psk_es_ee_se.as_bytes());
let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
let counter = session.get_next_outgoing_counter()?.get();
set_packet_header(
&mut reply_buf,
1,
@ -1335,7 +1322,7 @@ impl<Application: ApplicationLayer> Context<Application> {
}
}
impl<Application: ApplicationLayer> Session<Application> {
impl<'a, Application: ApplicationLayer> Session<'a, Application> {
/// Send data over the session.
///
/// * `send` - Function to call to send physical packet(s)
@ -1346,13 +1333,13 @@ impl<Application: ApplicationLayer> Session<Application> {
debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU);
let state = self.state.read().unwrap();
if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) {
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
let counter = self.get_next_outgoing_counter()?.get();
let mut c = session_key.get_send_cipher(counter)?;
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter));
let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize;
let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE;
let fragment_count = (data.len() + AES_GCM_TAG_SIZE + (fragment_max_chunk_size - 1)) / fragment_max_chunk_size;
let last_fragment_no = fragment_count - 1;
for fragment_no in 0..fragment_count {
@ -1396,7 +1383,7 @@ impl<Application: ApplicationLayer> Session<Application> {
pub fn send_nop<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction) -> Result<(), Error> {
let state = self.state.read().unwrap();
if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) {
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
let counter = self.get_next_outgoing_counter()?.get();
let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
let mut c = session_key.get_send_cipher(counter)?;
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter));
@ -1447,7 +1434,7 @@ impl<Application: ApplicationLayer> Session<Application> {
let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(key) = state.keys[state.current_key].as_ref() {
if let Some(counter) = self.get_next_outgoing_counter() {
if let Ok(counter) = self.get_next_outgoing_counter() {
if let Ok(mut gcm) = key.get_send_cipher(counter.get()) {
gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_INIT, counter.get()));
gcm.crypt_in_place(&mut rekey_buf[RekeyInit::ENC_START..RekeyInit::AUTH_START]);
@ -1480,8 +1467,8 @@ impl<Application: ApplicationLayer> Session<Application> {
/// Get the next outgoing counter value.
#[inline(always)]
fn get_next_outgoing_counter(&self) -> Option<NonZeroU64> {
NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::Relaxed))
fn get_next_outgoing_counter(&self) -> Result<NonZeroU64, Error> {
NonZeroU64::new(self.send_counter.fetch_add(1, Ordering::Relaxed)).ok_or(Error::MaxKeyLifetimeExceeded)
}
/// Check the receive window without mutating state.
@ -1499,6 +1486,12 @@ impl<Application: ApplicationLayer> Session<Application> {
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
}
}
impl<'a, App: ApplicationLayer> Drop for Session<'a, App> {
fn drop(&mut self) {
let mut sessions = self.context.active_sessions.write().unwrap();
sessions.remove(&self.id);
}
}
#[inline(always)]
fn set_packet_header(
@ -1728,3 +1721,31 @@ fn kbkdf512<const LABEL: u8>(key: &Secret<NOISE_HASHLEN>) -> Secret<NOISE_HASHLE
fn kbkdf256<const LABEL: u8>(key: &Secret<NOISE_HASHLEN>) -> Secret<32> {
hmac_sha512_secret256(key.as_bytes(), &[1, b'Z', b'T', LABEL, 0x00, 0, 1u8, 0u8])
}
#[inline]
fn lookup<T>(table: &[(i64, u64, T)], key: u64, expiry: i64) -> (usize, bool) {
let idx0 = (key as usize) % MAX_INCOMPLETE_SESSION_QUEUE_SIZE;
let mut idx1 = (key as usize) / MAX_INCOMPLETE_SESSION_QUEUE_SIZE % (MAX_INCOMPLETE_SESSION_QUEUE_SIZE - 1);
if idx0 == idx1 {
idx1 = MAX_INCOMPLETE_SESSION_QUEUE_SIZE - 1;
}
// Open hash lookup of just 2 slots.
// By only checking 2 slots we avoid a full table lookup while also minimizing the chance that 2 offers collide.
// To DOS, an adversary would either need to volumetrically spam the defrag table to keep all slots full
// or replay Alice's packet header from a spoofed physical path before Alice's packet is fully assembled.
// Volumetric spam is quite difficult since without the `dos_salt` value an adversary cannot
// control which slots their fragments index to. And since Alice's packet header has a randomly
// generated counter value replaying it in time requires extreme amounts of network control.
if table[idx0].1 == key {
(idx0, true)
} else if table[idx1].1 == key {
(idx1, true)
} else if table[idx0].0 == i64::MAX || table[idx0].0 > table[idx1].0 || table[idx0].0 <= expiry {
// slot0 is either empty, expired, or the youngest of the two slots so use it.
(idx0, false)
} else {
// slot1 is either occupied or empty so we overwrite whatever is there to make more room.
(idx1, false)
}
}