Miscellaneous fixes, performance improvements, and a check code added to the ZSSP header to make fragmentation attacks hard even for a MITM.
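The header change in brief: the ZSSP packet header grows from 12 to 16 bytes, and the first four bytes now carry a 32-bit check code computed by AES-encrypting the 16 bytes that immediately follow it with a key only the two peers can derive, so a receiver can cheaply reject forged or mis-addressed fragments before they ever enter a defragmentation buffer. The sketch below is illustrative only (the pack/parse helper names are invented for the example, not code from this commit); it shows the resulting 16-byte layout and how the packed fields line up with the parsing done in receive():

// Hypothetical illustration of the 16-byte ZSSP header this commit introduces.
// Little-endian layout: bytes 0..4 = 32-bit check code, bytes 4..12 = a u64 packing the
// 48-bit recipient session ID (bits 0..48), packet type (bits 48..52), fragment count - 1
// (bits 52..58) and fragment number (bits 58..64), bytes 12..16 = 32-bit send counter.
fn pack_header(session_id: u64, packet_type: u8, fragment_count: u8, fragment_no: u8, counter: u32) -> [u8; 16] {
    debug_assert!(session_id <= 0xffff_ffff_ffff); // session IDs are 48 bits
    debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
    debug_assert!(fragment_count >= 1 && fragment_count <= 64);
    let word = session_id
        | ((packet_type as u64) << 48)
        | (((fragment_count as u64) - 1) << 52)
        | ((fragment_no as u64) << 58);
    let mut h = [0u8; 16];
    // Bytes 0..4 stay zero here; the sender writes the check code there last.
    h[4..12].copy_from_slice(&word.to_le_bytes());
    h[12..16].copy_from_slice(&counter.to_le_bytes());
    h
}

fn parse_header(h: &[u8; 16]) -> (u64, u8, u8, u8, u32) {
    let word = u64::from_le_bytes(h[4..12].try_into().unwrap());
    let session_id = word & 0xffff_ffff_ffff;
    let packet_type = ((word >> 48) as u8) & 0x0f;
    let fragment_count = (((word >> 52) as u8) & 63) + 1;
    let fragment_no = ((word >> 58) as u8) & 63;
    let counter = u32::from_le_bytes(h[12..16].try_into().unwrap());
    (session_id, packet_type, fragment_count, fragment_no, counter)
}

On the wire the sender fills bytes 0..4 last, after payload encryption, by running the keyed header-check permutation over bytes 4..20 of the finished fragment (see the second sketch further down).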

Adam Ierymenko 2022-09-07 13:18:52 -04:00
parent d5943f246a
commit 42f6f016e9
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
3 changed files with 338 additions and 167 deletions
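For orientation, here is a minimal sketch of how such a 32-bit check code can be computed and verified. It assumes an AES-128 block cipher keyed from the session's key derivation (the commit derives this key with the new KBKDF_KEY_USAGE_LABEL_HEADER_MAC label from the static ECDH secret, or from the local static public key hash for initial offers); the RustCrypto aes crate and the helper names below are stand-ins for illustration, not this repository's API. Because the code is only 32 bits it is a cheap gatekeeper against fragment forgery, including by a MITM who lacks the derived key, not a replacement for the AES-GCM and HMAC authentication that still follows:

// Illustrative only. Assumes the RustCrypto `aes` crate (0.8); the real commit uses the
// repository's own Aes wrapper and pads packets shorter than 20 bytes before encrypting.
use aes::cipher::{generic_array::GenericArray, BlockEncrypt, KeyInit};
use aes::Aes128;

const HEADER_CHECK_SIZE: usize = 4;

// Hypothetical helper: build the header-check cipher from a 16-byte derived key.
fn new_header_check_cipher(key: &[u8; 16]) -> Aes128 {
    Aes128::new(GenericArray::from_slice(key))
}

// Compute the 32-bit check code over the 16 bytes that follow the check field.
fn header_check(packet: &[u8], cipher: &Aes128) -> u32 {
    assert!(packet.len() >= HEADER_CHECK_SIZE + 16);
    let mut block = GenericArray::clone_from_slice(&packet[HEADER_CHECK_SIZE..HEADER_CHECK_SIZE + 16]);
    cipher.encrypt_block(&mut block); // keyed permutation; unpredictable without the derived key
    u32::from_ne_bytes(block[..4].try_into().unwrap())
}

// Receiver side: discard the fragment early if the claimed check code does not match.
fn check_ok(packet: &[u8], cipher: &Aes128) -> bool {
    let claimed = u32::from_ne_bytes(packet[..HEADER_CHECK_SIZE].try_into().unwrap());
    claimed == header_check(packet, cipher)
}

In the commit itself this logic lives in header_check() in the first file below and is applied per fragment on both the send and receive paths.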

View file

@ -1,11 +1,12 @@
// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md.

use std::collections::LinkedList;
use std::io::{Read, Write};
use std::num::NonZeroU64;
use std::ops::Deref;
use std::sync::atomic::{AtomicU64, Ordering};

use crate::aes::{Aes, AesGcm};
use crate::hash::{hmac_sha384, hmac_sha512, SHA384};
use crate::p384::{P384KeyPair, P384PublicKey, P384_PUBLIC_KEY_SIZE};
use crate::random;
@ -53,12 +54,18 @@ const E1_TYPE_KYBER1024: u8 = 1;
const MAX_FRAGMENTS: usize = 48; // protocol max: 63
const KEY_EXCHANGE_MAX_FRAGMENTS: usize = 2; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc.
const HEADER_SIZE: usize = 16;
const HEADER_CHECK_SIZE: usize = 4;
const AES_GCM_TAG_SIZE: usize = 16;
const AES_GCM_NONCE_SIZE: usize = 12;
const AES_GCM_NONCE_START: usize = 4;
const AES_GCM_NONCE_END: usize = 16;
const HMAC_SIZE: usize = 48;
const SESSION_ID_SIZE: usize = 6;
const KEY_HISTORY_SIZE_MAX: usize = 3;

const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
const KBKDF_KEY_USAGE_LABEL_HEADER_MAC: u8 = b'H';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';
@ -243,6 +250,7 @@ pub struct Session<H: Host> {
send_counter: Counter,
psk: Secret<64>, // Arbitrary PSK provided by external code
ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer
header_check_cipher: Aes, // Cipher used for fast 32-bit header MAC
state: RwLock<MutableState>, // Mutable parts of state (other than defrag buffers)
remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob)
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
@ -251,13 +259,14 @@ pub struct Session<H: Host> {
struct MutableState {
remote_session_id: Option<SessionId>,
keys: LinkedList<SessionKey>,
offer: Option<EphemeralOffer>,
}

/// State information to associate with receiving contexts such as sockets or remote paths/endpoints.
pub struct ReceiveContext<H: Host> {
initial_offer_defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, KEY_EXCHANGE_MAX_FRAGMENTS>, 1024, 128>>,
incoming_init_header_check_cipher: Aes,
}
impl<H: Host> Session<H> {
@ -276,11 +285,13 @@ impl<H: Host> Session<H> {
) -> Result<Self, Error> {
if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) {
if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) {
let send_counter = Counter::new();
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
let remote_s_public_hash = SHA384::hash(remote_s_public);
let outgoing_init_header_check_cipher = Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
if let Ok(offer) = EphemeralOffer::create_alice_offer(
&mut send,
send_counter.next(),
local_session_id,
None,
host.get_local_s_public(),
@ -288,6 +299,7 @@ impl<H: Host> Session<H> {
&remote_s_public_p384,
&remote_s_public_hash,
&ss,
&outgoing_init_header_check_cipher,
mtu,
current_time,
jedi,
@ -295,12 +307,13 @@ impl<H: Host> Session<H> {
return Ok(Self {
id: local_session_id,
associated_object,
send_counter,
psk: psk.clone(),
ss,
header_check_cipher,
state: RwLock::new(MutableState {
remote_session_id: None,
keys: LinkedList::new(),
offer: Some(offer),
}),
remote_s_public_hash,
@ -313,10 +326,59 @@ impl<H: Host> Session<H> {
return Err(Error::InvalidParameter);
}
#[inline]
pub fn send<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction, mtu_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> {
debug_assert!(mtu_buffer.len() >= MIN_MTU);
let state = self.state.read();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(key) = state.keys.front() {
let mut packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
let counter = self.send_counter.next();
send_with_fragmentation_init_header(mtu_buffer, packet_len, mtu_buffer.len(), PACKET_TYPE_DATA, remote_session_id.into(), counter);
let mut c = key.get_send_cipher(counter)?;
c.init(&mtu_buffer[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
if packet_len > mtu_buffer.len() {
let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap();
let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE;
let last_fragment_data_mtu = mtu_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
loop {
debug_assert!(data.len() > last_fragment_data_mtu);
c.crypt(&data[..fragment_data_mtu], &mut mtu_buffer[HEADER_SIZE..]);
let hc = header_check(mtu_buffer, &self.header_check_cipher);
mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
send(mtu_buffer);
data = &data[fragment_data_mtu..];
debug_assert!(header[7].wrapping_shr(2) < 63);
header[7] += 0x04; // increment fragment number
mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].copy_from_slice(&header);
if data.len() <= last_fragment_data_mtu {
break;
}
}
packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
}
let gcm_tag_idx = data.len() + HEADER_SIZE;
c.crypt(data, &mut mtu_buffer[HEADER_SIZE..gcm_tag_idx]);
let hc = header_check(mtu_buffer, &self.header_check_cipher);
mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish());
send(&mut mtu_buffer[..packet_len]);
key.return_send_cipher(c);
return Ok(());
}
}
return Err(Error::SessionNotEstablished);
}
#[inline]
pub fn rekey_check<SendFunction: FnMut(&mut [u8])>(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) {
let state = self.state.upgradable_read();
if let Some(key) = state.keys.front() {
if force || (key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) {
if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) {
if let Ok(offer) = EphemeralOffer::create_alice_offer(
@ -329,6 +391,7 @@ impl<H: Host> Session<H> {
&remote_s_public_p384,
&self.remote_s_public_hash,
&self.ss,
&self.header_check_cipher,
mtu,
current_time,
jedi,
@ -343,9 +406,10 @@ impl<H: Host> Session<H> {
impl<H: Host> ReceiveContext<H> {
#[inline]
pub fn new(host: &H) -> Self {
Self {
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
incoming_init_header_check_cipher: Aes::new(kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()),
}
}
@ -366,57 +430,54 @@ impl<H: Host> ReceiveContext<H> {
return Err(Error::InvalidPacket);
}

let header_0_8 = memory::u64_from_le_bytes(&incoming_packet[HEADER_CHECK_SIZE..12]); // session ID, type, frag info
let counter = memory::u32_from_le_bytes(&incoming_packet[12..16]);
let local_session_id = SessionId::new_from_u64(header_0_8 & SessionId::MAX_BIT_MASK);
let packet_type = (header_0_8.wrapping_shr(48) as u8) & 15;
let fragment_count = ((header_0_8.wrapping_shr(52) as u8) & 63).wrapping_add(1);
let fragment_no = (header_0_8.wrapping_shr(58) as u8) & 63;

if let Some(local_session_id) = local_session_id {
if let Some(session) = host.session_lookup(local_session_id) {
if memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &session.header_check_cipher) {
unlikely_branch();
return Err(Error::FailedAuthentication);
}
if fragment_count > 1 {
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
let mut defrag = session.defrag.lock();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock
return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time);
}
} else {
unlikely_branch();
return Err(Error::InvalidPacket);
}
} else {
return self.receive_complete(host, &mut send, data_buf, &[incoming_packet_buf], packet_type, Some(session), mtu, jedi, current_time);
}
} else {
unlikely_branch();
return Err(Error::UnknownLocalSessionId(local_session_id));
}
} else {
unlikely_branch();
if memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &self.incoming_init_header_check_cipher) {
unlikely_branch();
return Err(Error::FailedAuthentication);
}
if fragment_count <= (KEY_EXCHANGE_MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
let mut defrag = self.initial_offer_defrag.lock();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock
return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
}
};
}

return Ok(ReceiveResult::Ok);
}
@ -440,8 +501,8 @@ impl<H: Host> ReceiveContext<H> {
if packet_type <= PACKET_TYPE_NOP {
if let Some(session) = session {
let state = session.state.read();
let key_count = state.keys.len();
for (key_index, key) in state.keys.iter().enumerate() {
let tail = fragments.last().unwrap().as_ref();
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
@ -449,7 +510,7 @@ impl<H: Host> ReceiveContext<H> {
}

let mut c = key.get_receive_cipher();
c.init(&fragments.first().unwrap().as_ref()[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);

let mut data_len = 0;
@ -479,12 +540,16 @@ impl<H: Host> ReceiveContext<H> {
key.return_receive_cipher(c);

if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) {
// Drop obsolete keys if we had to iterate past the first key to get here.
if key_index > 0 {
unlikely_branch();
drop(state);
let mut state = session.state.write();
if state.keys.len() == key_count {
for _ in 0..key_index {
let _ = state.keys.pop_front();
}
}
}

if packet_type == PACKET_TYPE_DATA {
@ -495,7 +560,6 @@ impl<H: Host> ReceiveContext<H> {
}
}
}
return Err(Error::FailedAuthentication);
} else {
unlikely_branch();
@ -541,7 +605,7 @@ impl<H: Host> ReceiveContext<H> {
let hmac1_end = incoming_packet_len - HMAC_SIZE;

// Check that the sender knows this host's identity before doing anything else.
if !hmac_sha384(host.get_local_s_public_hash(), &incoming_packet[HEADER_CHECK_SIZE..hmac1_end]).eq(&incoming_packet[hmac1_end..]) {
return Err(Error::FailedAuthentication);
}
@ -561,7 +625,7 @@ impl<H: Host> ReceiveContext<H> {
let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes()));
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), false);
c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
let c = c.finish();
if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) {
@ -582,7 +646,12 @@ impl<H: Host> ReceiveContext<H> {
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));

if !hmac_sha384(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
&original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end],
)
.eq(&incoming_packet[aes_gcm_tag_end..hmac1_end])
{
return Err(Error::FailedAuthentication);
}
@ -596,15 +665,17 @@ impl<H: Host> ReceiveContext<H> {
None
} else {
if let Some((new_session_id, psk, associated_object)) = host.accept_new_session(alice_s_public, alice_metadata) {
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
Some(Session::<H> {
id: new_session_id,
associated_object,
send_counter: Counter::new(),
psk,
ss,
header_check_cipher,
state: RwLock::new(MutableState {
remote_session_id: Some(alice_session_id),
keys: LinkedList::new(),
offer: None,
}),
remote_s_public_hash: SHA384::hash(&alice_s_public),
@ -656,13 +727,14 @@ impl<H: Host> ReceiveContext<H> {
} else {
rp.write_all(&[E1_TYPE_NONE])?;
}

(MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len()
};

send_with_fragmentation_init_header(&mut reply_buf, reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter);

let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true);
c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut reply_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..reply_len]);
let c = c.finish();
reply_buf[reply_len..(reply_len + AES_GCM_TAG_SIZE)].copy_from_slice(&c);
@ -674,18 +746,19 @@ impl<H: Host> ReceiveContext<H> {
// key derivation step.
let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes()));

let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &reply_buf[HEADER_CHECK_SIZE..reply_len]);
reply_buf[reply_len..(reply_len + HMAC_SIZE)].copy_from_slice(&hmac);
reply_len += HMAC_SIZE;

let mut state = session.state.write();
let _ = state.remote_session_id.replace(alice_session_id);
add_key(&mut state.keys, SessionKey::new(key, Role::Bob, current_time, reply_counter, jedi));
drop(state);

// Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it.
send_with_fragmentation(send, &mut reply_buf[..reply_len], mtu, &session.header_check_cipher);

if new_session.is_some() {
return Ok(ReceiveResult::OkNewSession(new_session.unwrap()));
} else {
@ -716,7 +789,7 @@ impl<H: Host> ReceiveContext<H> {
));
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), false);
c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
let c = c.finish();
if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) {
@ -739,7 +812,10 @@ impl<H: Host> ReceiveContext<H> {
let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes()));

if !hmac_sha384(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
&original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end],
)
.eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()])
{
return Err(Error::FailedAuthentication);
@ -751,20 +827,22 @@ impl<H: Host> ReceiveContext<H> {
let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi);

let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
send_with_fragmentation_init_header(&mut reply_buf, HEADER_SIZE + AES_GCM_TAG_SIZE, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter);

let mut c = key.get_send_cipher(counter)?;
c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish());
key.return_send_cipher(c);

let hc = header_check(&reply_buf, &session.header_check_cipher);
reply_buf[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
send(&mut reply_buf);

let mut state = RwLockUpgradableReadGuard::upgrade(state);
let _ = state.remote_session_id.replace(bob_session_id);
let _ = state.offer.take();
add_key(&mut state.keys, key);

return Ok(ReceiveResult::Ok);
}
@ -861,6 +939,7 @@ impl EphemeralOffer {
bob_s_public_p384: &P384PublicKey,
bob_s_public_hash: &[u8],
ss: &Secret<48>,
header_check_cipher: &Aes,
mtu: usize,
current_time: i64,
jedi: bool,
@ -900,8 +979,7 @@ impl EphemeralOffer {
PACKET_BUF_SIZE - p.len()
};

send_with_fragmentation_init_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter);

let key = Secret(hmac_sha512(
&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()),
@ -910,7 +988,7 @@ impl EphemeralOffer {
let gcm_tag = {
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true);
c.init(&packet_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]);
c.finish()
};
@ -919,15 +997,15 @@ impl EphemeralOffer {
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));

let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[HEADER_CHECK_SIZE..packet_len]);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;

let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[HEADER_CHECK_SIZE..packet_len]);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;

send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher);

Ok(EphemeralOffer {
creation_time: current_time,
@ -939,7 +1017,7 @@ impl EphemeralOffer {
}

#[inline(always)]
fn send_with_fragmentation_init_header(header: &mut [u8], packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) {
let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize;
debug_assert!(mtu >= MIN_MTU);
debug_assert!(packet_len >= HEADER_SIZE);
@ -947,27 +1025,26 @@ fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_typ
debug_assert!(fragment_count > 0);
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits
header[HEADER_CHECK_SIZE..12].copy_from_slice(&(recipient_session_id | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)).to_le_bytes());
header[12..HEADER_SIZE].copy_from_slice(&counter.to_u32().to_le_bytes());
}

/// Send a packet in fragments (used for everything but DATA which has a hand-rolled version for performance).
fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(send: &mut SendFunction, packet: &mut [u8], mtu: usize, header_check_cipher: &Aes) {
let packet_len = packet.len();
let mut fragment_start = 0;
let mut fragment_end = packet_len.min(mtu);
let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = packet[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap();
loop {
let hc = header_check(&packet[fragment_start..], header_check_cipher);
packet[fragment_start..(fragment_start + HEADER_CHECK_SIZE)].copy_from_slice(&hc.to_ne_bytes());
send(&mut packet[fragment_start..fragment_end]);
if fragment_end < packet_len {
debug_assert!(header[7].wrapping_shr(2) < 63);
header[7] += 0x04; // increment fragment number
fragment_start = fragment_end - HEADER_SIZE;
fragment_end = (fragment_start + mtu).min(packet_len);
packet[(fragment_start + HEADER_CHECK_SIZE)..(fragment_start + HEADER_SIZE)].copy_from_slice(&header);
} else {
debug_assert_eq!(fragment_end, packet_len);
break;
@ -975,6 +1052,31 @@ fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(send: &mut SendFuncti
}
}
/// Compute the 32-bit header check code for a packet on receipt or right before send.
fn header_check(packet: &[u8], header_check_cipher: &Aes) -> u32 {
debug_assert!(packet.len() >= HEADER_SIZE);
let mut header_check = 0u128.to_ne_bytes();
if packet.len() >= (16 + HEADER_CHECK_SIZE) {
header_check_cipher.encrypt_block(&packet[HEADER_CHECK_SIZE..(16 + HEADER_CHECK_SIZE)], &mut header_check);
} else {
unlikely_branch();
header_check[..(packet.len() - HEADER_CHECK_SIZE)].copy_from_slice(&packet[HEADER_CHECK_SIZE..]);
header_check_cipher.encrypt_block_in_place(&mut header_check);
}
memory::u32_from_ne_bytes(&header_check)
}
/// Add a new session key to the key list, retiring older non-active keys if necessary.
fn add_key(keys: &mut LinkedList<SessionKey>, key: SessionKey) {
debug_assert!(KEY_HISTORY_SIZE_MAX >= 2);
while keys.len() >= KEY_HISTORY_SIZE_MAX {
let current = keys.pop_front().unwrap();
let _ = keys.pop_front();
keys.push_front(current);
}
keys.push_back(key);
}
fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Result<(SessionId, &[u8], &[u8], &[u8]), Error> {
let mut p = &incoming_packet[..];
let alice_session_id = SessionId::new_from_reader(&mut p)?;
@ -1083,13 +1185,6 @@ fn kbkdf512(key: &[u8], label: u8) -> Secret<64> {
Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00]))
}

#[cfg(test)]
mod tests {
use parking_lot::Mutex;
@ -1175,7 +1270,8 @@ mod tests {
random::fill_bytes_secure(&mut psk.0);
let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob"));
let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice"));
let alice_rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new(&alice_host));
let bob_rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new(&bob_host));

let mut data_buf = [0_u8; 4096];
//println!("zssp: size of session (bytes): {}", std::mem::size_of::<Session<Box<TestHost>>>());
@ -1207,6 +1303,12 @@ mod tests {
}
};
let rc = if std::ptr::eq(host, &alice_host) {
&alice_rc
} else {
&bob_rc
};
loop {
if let Some(qi) = host.queue.lock().pop_back() {
let qi_len = qi.len();
@ -1232,7 +1334,8 @@ mod tests {
}
}
} else {
println!("zssp: {} => {} ({}): error: {}", host.other_name, host.this_name, qi_len, r.err().unwrap().to_string());
panic!();
}
} else {
break;

View file

@ -39,6 +39,42 @@ mod fast_int_memory_access {
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn u64_from_ne_bytes(b: &[u8]) -> u64 {
assert!(b.len() >= 8);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn u32_from_ne_bytes(b: &[u8]) -> u32 {
assert!(b.len() >= 4);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn u16_from_ne_bytes(b: &[u8]) -> u16 {
assert!(b.len() >= 2);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn i64_from_ne_bytes(b: &[u8]) -> i64 {
assert!(b.len() >= 8);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn i32_from_ne_bytes(b: &[u8]) -> i32 {
assert!(b.len() >= 4);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn i16_from_ne_bytes(b: &[u8]) -> i16 {
assert!(b.len() >= 2);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
assert!(b.len() >= 8);
@ -109,6 +145,36 @@ mod fast_int_memory_access {
i16::from_le_bytes(b[..2].try_into().unwrap())
}
#[inline(always)]
pub fn u64_from_ne_bytes(b: &[u8]) -> u64 {
u64::from_ne_bytes(b[..8].try_into().unwrap())
}
#[inline(always)]
pub fn u32_from_ne_bytes(b: &[u8]) -> u32 {
u32::from_ne_bytes(b[..4].try_into().unwrap())
}
#[inline(always)]
pub fn u16_from_ne_bytes(b: &[u8]) -> u16 {
u16::from_ne_bytes(b[..2].try_into().unwrap())
}
#[inline(always)]
pub fn i64_from_ne_bytes(b: &[u8]) -> i64 {
i64::from_ne_bytes(b[..8].try_into().unwrap())
}
#[inline(always)]
pub fn i32_from_ne_bytes(b: &[u8]) -> i32 {
i32::from_ne_bytes(b[..4].try_into().unwrap())
}
#[inline(always)]
pub fn i16_from_ne_bytes(b: &[u8]) -> i16 {
i16::from_ne_bytes(b[..2].try_into().unwrap())
}
#[inline(always)]
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
u64::from_be_bytes(b[..8].try_into().unwrap())

View file

@ -4,6 +4,8 @@ use std::hash::{Hash, Hasher};
use std::mem::MaybeUninit;

const EMPTY: u16 = 0xffff;

#[inline(always)]
fn xorshift64(mut x: u64) -> u64 {
x ^= x.wrapping_shl(13);
@ -78,9 +80,9 @@ impl Hasher for XorShiftHasher {
struct Entry<K: Eq + PartialEq + Hash + Clone, V> {
key: MaybeUninit<K>,
value: MaybeUninit<V>,
bucket: u16, // which bucket is this in? EMPTY for none
next: u16, // next item in bucket's linked list, EMPTY for none
prev: u16, // previous entry to permit deletion of old entries from bucket lists
}

/// A hybrid between a circular buffer and a map.
@ -90,35 +92,35 @@ struct Entry<K: Eq + PartialEq + Hash + Clone, V> {
/// with a HashMap but that would be less efficient. This requires no memory allocations unless
/// the K or V types allocate memory and occupies a fixed amount of memory.
///
/// There is no explicit remove since that would require more complex logic to maintain FIFO
/// ordering for replacement of entries. Old entries just roll off the end.
///
/// This is used for things like defragmenting incoming packets to support multiple fragmented
/// packets in flight. Having no allocations is good to reduce the potential for memory
/// exhaustion attacks.
///
/// The C template parameter is the total capacity while the B parameter is the number of
/// buckets in the hash table. The maximum for both these parameters is 65535. This could be
/// increased by making the index variables larger (e.g. u32 instead of u16).
pub struct RingBufferMap<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> {
salt: u32,
entries: [Entry<K, V>; C],
buckets: [u16; B],
entry_ptr: u16,
}

impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBufferMap<K, V, C, B> {
/// Create a new map with the supplied random salt to perturb the hashing function.
#[inline]
pub fn new(salt: u32) -> Self {
debug_assert!(C <= EMPTY as usize);
debug_assert!(B <= EMPTY as usize);
let mut tmp: Self = unsafe { MaybeUninit::uninit().assume_init() };
// EMPTY is the maximum value of the indices, which is all 0xff, so this sets all indices to EMPTY.
unsafe { std::ptr::write_bytes(&mut tmp, 0xff, 1) };
tmp.salt = salt;
tmp.entry_ptr = 0;
tmp
}

#[inline]
@ -126,9 +128,9 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
let mut h = XorShiftHasher::new(self.salt);
key.hash(&mut h);
let mut e = self.buckets[(h.finish() as usize) % B];
while e != EMPTY {
let ee = &self.entries[e as usize];
debug_assert!(ee.bucket != EMPTY);
if unsafe { ee.key.assume_init_ref().eq(key) } {
return Some(unsafe { &ee.value.assume_init_ref() });
}
@ -137,7 +139,7 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
return None;
}

/// Get an entry, creating if not present, and return a mutable reference to it.
#[inline]
pub fn get_or_create_mut<CF: FnOnce() -> V>(&mut self, key: &K, create: CF) -> &mut V {
let mut h = XorShiftHasher::new(self.salt);
@ -145,10 +147,10 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
let bucket = (h.finish() as usize) % B;
let mut e = self.buckets[bucket];
while e != EMPTY {
unsafe {
let e_ptr = &mut *self.entries.as_mut_ptr().add(e as usize);
debug_assert!(e_ptr.bucket != EMPTY);
if e_ptr.key.assume_init_ref().eq(key) {
return e_ptr.value.assume_init_mut();
}
@ -167,9 +169,9 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
let bucket = (h.finish() as usize) % B;
let mut e = self.buckets[bucket];
while e != EMPTY {
let e_ptr = &mut self.entries[e as usize];
debug_assert!(e_ptr.bucket != EMPTY);
if unsafe { e_ptr.key.assume_init_ref().eq(&key) } {
unsafe { *e_ptr.value.assume_init_mut() = value };
return;
@ -177,7 +179,7 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
e = e_ptr.next;
}
let _ = self.internal_add(bucket, key, value);
}
#[inline]
@ -186,8 +188,8 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
self.entry_ptr = self.entry_ptr.wrapping_add(1);
let e_ptr = unsafe { &mut *self.entries.as_mut_ptr().add(e) };
if e_ptr.bucket != EMPTY {
if e_ptr.prev != EMPTY {
self.entries[e_ptr.prev as usize].next = e_ptr.next;
} else {
self.buckets[e_ptr.bucket as usize] = e_ptr.next;
@ -200,13 +202,13 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
e_ptr.key.write(key);
e_ptr.value.write(value);
e_ptr.bucket = bucket as u16;
e_ptr.next = self.buckets[bucket];
if e_ptr.next != EMPTY {
self.entries[e_ptr.next as usize].prev = e as u16;
}
self.buckets[bucket] = e as u16;
e_ptr.prev = EMPTY;
unsafe { e_ptr.value.assume_init_mut() }
}
}
@ -215,7 +217,7 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> Drop f
#[inline]
fn drop(&mut self) {
for e in self.entries.iter_mut() {
if e.bucket != EMPTY {
unsafe {
e.key.assume_init_drop();
e.value.assume_init_drop();