Fix several issues, improve performance, and add a check code to the ZSSP header to make fragmentation attacks hard even for a MITM.

This commit is contained in:
Adam Ierymenko 2022-09-07 13:18:52 -04:00
parent d5943f246a
commit 42f6f016e9
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
3 changed files with 338 additions and 167 deletions

View file

@ -1,11 +1,12 @@
// (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md.
use std::collections::LinkedList;
use std::io::{Read, Write};
use std::num::NonZeroU64;
use std::ops::Deref;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::aes::AesGcm;
use crate::aes::{Aes, AesGcm};
use crate::hash::{hmac_sha384, hmac_sha512, SHA384};
use crate::p384::{P384KeyPair, P384PublicKey, P384_PUBLIC_KEY_SIZE};
use crate::random;
@ -53,12 +54,18 @@ const E1_TYPE_KYBER1024: u8 = 1;
const MAX_FRAGMENTS: usize = 48; // protocol max: 63
const KEY_EXCHANGE_MAX_FRAGMENTS: usize = 2; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc.
const HEADER_SIZE: usize = 12;
const HEADER_SIZE: usize = 16;
const HEADER_CHECK_SIZE: usize = 4;
const AES_GCM_TAG_SIZE: usize = 16;
const AES_GCM_NONCE_SIZE: usize = 12;
const AES_GCM_NONCE_START: usize = 4;
const AES_GCM_NONCE_END: usize = 16;
const HMAC_SIZE: usize = 48;
const SESSION_ID_SIZE: usize = 6;
const KEY_HISTORY_SIZE_MAX: usize = 3;
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
const KBKDF_KEY_USAGE_LABEL_HEADER_MAC: u8 = b'H';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';
@ -243,6 +250,7 @@ pub struct Session<H: Host> {
send_counter: Counter,
psk: Secret<64>, // Arbitrary PSK provided by external code
ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer
header_check_cipher: Aes, // Cipher used for fast 32-bit header MAC
state: RwLock<MutableState>, // Mutable parts of state (other than defrag buffers)
remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob)
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
@ -251,13 +259,14 @@ pub struct Session<H: Host> {
struct MutableState {
remote_session_id: Option<SessionId>,
keys: [Option<SessionKey>; 2], // current, next (promoted to current on successful decrypt)
keys: LinkedList<SessionKey>,
offer: Option<EphemeralOffer>,
}
/// State information to associate with receiving contexts such as sockets or remote paths/endpoints.
pub struct ReceiveContext<H: Host> {
initial_offer_defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, KEY_EXCHANGE_MAX_FRAGMENTS>, 1024, 128>>,
incoming_init_header_check_cipher: Aes,
}
impl<H: Host> Session<H> {
@ -276,11 +285,13 @@ impl<H: Host> Session<H> {
) -> Result<Self, Error> {
if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) {
if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) {
let counter = Counter::new();
let send_counter = Counter::new();
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
let remote_s_public_hash = SHA384::hash(remote_s_public);
let outgoing_init_header_check_cipher = Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
if let Ok(offer) = EphemeralOffer::create_alice_offer(
&mut send,
counter.next(),
send_counter.next(),
local_session_id,
None,
host.get_local_s_public(),
@ -288,6 +299,7 @@ impl<H: Host> Session<H> {
&remote_s_public_p384,
&remote_s_public_hash,
&ss,
&outgoing_init_header_check_cipher,
mtu,
current_time,
jedi,
@ -295,12 +307,13 @@ impl<H: Host> Session<H> {
return Ok(Self {
id: local_session_id,
associated_object,
send_counter: counter,
send_counter,
psk: psk.clone(),
ss,
header_check_cipher,
state: RwLock::new(MutableState {
remote_session_id: None,
keys: [None, None],
keys: LinkedList::new(),
offer: Some(offer),
}),
remote_s_public_hash,
@ -313,10 +326,59 @@ impl<H: Host> Session<H> {
return Err(Error::InvalidParameter);
}
#[inline]
pub fn send<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction, mtu_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> {
debug_assert!(mtu_buffer.len() >= MIN_MTU);
let state = self.state.read();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(key) = state.keys.front() {
let mut packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
let counter = self.send_counter.next();
send_with_fragmentation_init_header(mtu_buffer, packet_len, mtu_buffer.len(), PACKET_TYPE_DATA, remote_session_id.into(), counter);
let mut c = key.get_send_cipher(counter)?;
c.init(mtu_buffer);
if packet_len > mtu_buffer.len() {
let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap();
let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE;
let last_fragment_data_mtu = mtu_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
loop {
debug_assert!(data.len() > last_fragment_data_mtu);
c.crypt(&data[HEADER_CHECK_SIZE..fragment_data_mtu], &mut mtu_buffer[HEADER_SIZE..]);
let hc = header_check(mtu_buffer, &self.header_check_cipher);
mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
send(mtu_buffer);
data = &data[fragment_data_mtu..];
debug_assert!(header[7].wrapping_shr(2) < 63);
header[7] += 0x04; // increment fragment number
mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].copy_from_slice(&header);
if data.len() <= last_fragment_data_mtu {
break;
}
}
packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
}
let gcm_tag_idx = data.len() + HEADER_SIZE;
c.crypt(data, &mut mtu_buffer[HEADER_SIZE..gcm_tag_idx]);
let hc = header_check(mtu_buffer, &self.header_check_cipher);
mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish());
send(&mut mtu_buffer[..packet_len]);
key.return_send_cipher(c);
return Ok(());
}
}
return Err(Error::SessionNotEstablished);
}
#[inline]
pub fn rekey_check<SendFunction: FnMut(&mut [u8])>(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) {
let state = self.state.upgradable_read();
if let Some(key) = state.keys[0].as_ref() {
if let Some(key) = state.keys.front() {
if force || (key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) {
if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) {
if let Ok(offer) = EphemeralOffer::create_alice_offer(
@ -329,6 +391,7 @@ impl<H: Host> Session<H> {
&remote_s_public_p384,
&self.remote_s_public_hash,
&self.ss,
&self.header_check_cipher,
mtu,
current_time,
jedi,
@ -343,9 +406,10 @@ impl<H: Host> Session<H> {
impl<H: Host> ReceiveContext<H> {
#[inline]
pub fn new() -> Self {
pub fn new(host: &H) -> Self {
Self {
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
incoming_init_header_check_cipher: Aes::new(kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()),
}
}
@ -366,17 +430,22 @@ impl<H: Host> ReceiveContext<H> {
return Err(Error::InvalidPacket);
}
let header_0_8 = memory::u64_from_le_bytes(incoming_packet); // type, frag info, session ID
let counter = memory::u32_from_le_bytes(&incoming_packet[8..]);
let local_session_id = SessionId::new_from_u64(header_0_8.wrapping_shr(16));
let packet_type = (header_0_8 as u8) & 15;
let fragment_count = ((header_0_8.wrapping_shr(4) as u8) & 63).wrapping_add(1);
let fragment_no = (header_0_8.wrapping_shr(10) as u8) & 63;
let header_0_8 = memory::u64_from_le_bytes(&incoming_packet[HEADER_CHECK_SIZE..12]); // session ID, type, frag info
let counter = memory::u32_from_le_bytes(&incoming_packet[12..16]);
let local_session_id = SessionId::new_from_u64(header_0_8 & SessionId::MAX_BIT_MASK);
let packet_type = (header_0_8.wrapping_shr(48) as u8) & 15;
let fragment_count = ((header_0_8.wrapping_shr(52) as u8) & 63).wrapping_add(1);
let fragment_no = (header_0_8.wrapping_shr(58) as u8) & 63;
if fragment_count > 1 {
if let Some(local_session_id) = local_session_id {
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
if let Some(session) = host.session_lookup(local_session_id) {
if let Some(local_session_id) = local_session_id {
if let Some(session) = host.session_lookup(local_session_id) {
if memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &session.header_check_cipher) {
unlikely_branch();
return Err(Error::FailedAuthentication);
}
if fragment_count > 1 {
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
let mut defrag = session.defrag.lock();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
@ -385,38 +454,30 @@ impl<H: Host> ReceiveContext<H> {
}
} else {
unlikely_branch();
return Err(Error::UnknownLocalSessionId(local_session_id));
return Err(Error::InvalidPacket);
}
} else {
unlikely_branch();
return Err(Error::InvalidPacket);
return self.receive_complete(host, &mut send, data_buf, &[incoming_packet_buf], packet_type, Some(session), mtu, jedi, current_time);
}
} else {
if fragment_count <= (KEY_EXCHANGE_MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
let mut defrag = self.initial_offer_defrag.lock();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock
return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
}
} else {
unlikely_branch();
return Err(Error::InvalidPacket);
}
unlikely_branch();
return Err(Error::UnknownLocalSessionId(local_session_id));
}
} else {
return self.receive_complete(
host,
&mut send,
data_buf,
&[incoming_packet_buf],
packet_type,
local_session_id.and_then(|lsid| host.session_lookup(lsid)),
mtu,
jedi,
current_time,
);
}
unlikely_branch();
if memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &self.incoming_init_header_check_cipher) {
unlikely_branch();
return Err(Error::FailedAuthentication);
}
let mut defrag = self.initial_offer_defrag.lock();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock
return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
}
};
return Ok(ReceiveResult::Ok);
}
@ -440,60 +501,63 @@ impl<H: Host> ReceiveContext<H> {
if packet_type <= PACKET_TYPE_NOP {
if let Some(session) = session {
let state = session.state.read();
for ki in 0..2 {
if let Some(key) = state.keys[ki].as_ref() {
let tail = fragments.last().unwrap().as_ref();
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
return Err(Error::InvalidPacket);
}
let key_count = state.keys.len();
for (key_index, key) in state.keys.iter().enumerate() {
let tail = fragments.last().unwrap().as_ref();
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
return Err(Error::InvalidPacket);
}
let mut c = key.get_receive_cipher();
c.init(&get_aes_gcm_nonce(fragments.first().unwrap().as_ref()));
let mut c = key.get_receive_cipher();
c.init(&fragments.first().unwrap().as_ref()[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
let mut data_len = 0;
for f in fragments[..(fragments.len() - 1)].iter() {
let f = f.as_ref();
debug_assert!(f.len() >= HEADER_SIZE);
let current_frag_data_start = data_len;
data_len += f.len() - HEADER_SIZE;
if data_len > data_buf.len() {
unlikely_branch();
key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
}
let mut data_len = 0;
for f in fragments[..(fragments.len() - 1)].iter() {
let f = f.as_ref();
debug_assert!(f.len() >= HEADER_SIZE);
let current_frag_data_start = data_len;
data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
data_len += f.len() - HEADER_SIZE;
if data_len > data_buf.len() {
unlikely_branch();
key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]);
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
}
let tag = c.finish();
let current_frag_data_start = data_len;
data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
if data_len > data_buf.len() {
unlikely_branch();
key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]);
if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) {
// If this succeeded with the "next" key, promote it to current.
if ki == 1 {
unlikely_branch();
drop(state);
let mut state = session.state.write();
state.keys[0] = state.keys[1].take();
}
let tag = c.finish();
key.return_receive_cipher(c);
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else {
unlikely_branch();
return Ok(ReceiveResult::Ok);
if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) {
// Drop obsolete keys if we had to iterate past the first key to get here.
if key_index > 0 {
unlikely_branch();
drop(state);
let mut state = session.state.write();
if state.keys.len() == key_count {
for _ in 0..key_index {
let _ = state.keys.pop_front();
}
}
}
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else {
unlikely_branch();
return Ok(ReceiveResult::Ok);
}
}
}
return Err(Error::FailedAuthentication);
@ -541,7 +605,7 @@ impl<H: Host> ReceiveContext<H> {
let hmac1_end = incoming_packet_len - HMAC_SIZE;
// Check that the sender knows this host's identity before doing anything else.
if !hmac_sha384(host.get_local_s_public_hash(), &incoming_packet[..hmac1_end]).eq(&incoming_packet[hmac1_end..]) {
if !hmac_sha384(host.get_local_s_public_hash(), &incoming_packet[HEADER_CHECK_SIZE..hmac1_end]).eq(&incoming_packet[hmac1_end..]) {
return Err(Error::FailedAuthentication);
}
@ -561,7 +625,7 @@ impl<H: Host> ReceiveContext<H> {
let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes()));
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), false);
c.init(&get_aes_gcm_nonce(incoming_packet));
c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
let c = c.finish();
if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) {
@ -582,7 +646,12 @@ impl<H: Host> ReceiveContext<H> {
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
if !hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[..aes_gcm_tag_end]).eq(&incoming_packet[aes_gcm_tag_end..hmac1_end]) {
if !hmac_sha384(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
&original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end],
)
.eq(&incoming_packet[aes_gcm_tag_end..hmac1_end])
{
return Err(Error::FailedAuthentication);
}
@ -596,15 +665,17 @@ impl<H: Host> ReceiveContext<H> {
None
} else {
if let Some((new_session_id, psk, associated_object)) = host.accept_new_session(alice_s_public, alice_metadata) {
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
Some(Session::<H> {
id: new_session_id,
associated_object,
send_counter: Counter::new(),
psk,
ss,
header_check_cipher,
state: RwLock::new(MutableState {
remote_session_id: Some(alice_session_id),
keys: [None, None],
keys: LinkedList::new(),
offer: None,
}),
remote_s_public_hash: SHA384::hash(&alice_s_public),
@ -656,13 +727,14 @@ impl<H: Host> ReceiveContext<H> {
} else {
rp.write_all(&[E1_TYPE_NONE])?;
}
(MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len()
};
let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter);
reply_buf[..HEADER_SIZE].copy_from_slice(&header);
send_with_fragmentation_init_header(&mut reply_buf, reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter);
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true);
c.init(&get_aes_gcm_nonce(&header));
c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut reply_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..reply_len]);
let c = c.finish();
reply_buf[reply_len..(reply_len + AES_GCM_TAG_SIZE)].copy_from_slice(&c);
@ -674,18 +746,19 @@ impl<H: Host> ReceiveContext<H> {
// key derivation step.
let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes()));
let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &reply_buf[..reply_len]);
let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &reply_buf[HEADER_CHECK_SIZE..reply_len]);
reply_buf[reply_len..(reply_len + HMAC_SIZE)].copy_from_slice(&hmac);
reply_len += HMAC_SIZE;
let mut state = session.state.write();
let _ = state.remote_session_id.replace(alice_session_id);
state.keys[1].replace(SessionKey::new(key, Role::Bob, current_time, reply_counter, jedi));
add_key(&mut state.keys, SessionKey::new(key, Role::Bob, current_time, reply_counter, jedi));
drop(state);
// Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it.
send_with_fragmentation(send, &mut reply_buf[..reply_len], mtu, &mut header);
send_with_fragmentation(send, &mut reply_buf[..reply_len], mtu, &session.header_check_cipher);
if new_session.is_some() {
return Ok(ReceiveResult::OkNewSession(new_session.unwrap()));
} else {
@ -716,7 +789,7 @@ impl<H: Host> ReceiveContext<H> {
));
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), false);
c.init(&get_aes_gcm_nonce(incoming_packet));
c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
let c = c.finish();
if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) {
@ -739,8 +812,11 @@ impl<H: Host> ReceiveContext<H> {
let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes()));
if !hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[..aes_gcm_tag_end])
.eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()])
if !hmac_sha384(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
&original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end],
)
.eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()])
{
return Err(Error::FailedAuthentication);
}
@ -751,20 +827,22 @@ impl<H: Host> ReceiveContext<H> {
let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi);
let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
let header = send_with_fragmentation_init_header(HEADER_SIZE + AES_GCM_TAG_SIZE, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter);
reply_buf[..HEADER_SIZE].copy_from_slice(&header);
send_with_fragmentation_init_header(&mut reply_buf, HEADER_SIZE + AES_GCM_TAG_SIZE, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter);
let mut c = key.get_send_cipher(counter)?;
c.init(&get_aes_gcm_nonce(&reply_buf));
c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish());
key.return_send_cipher(c);
let hc = header_check(&reply_buf, &session.header_check_cipher);
reply_buf[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
send(&mut reply_buf);
let mut state = RwLockUpgradableReadGuard::upgrade(state);
let _ = state.remote_session_id.replace(bob_session_id);
let _ = state.offer.take();
let _ = state.keys[0].insert(key);
add_key(&mut state.keys, key);
return Ok(ReceiveResult::Ok);
}
@ -861,6 +939,7 @@ impl EphemeralOffer {
bob_s_public_p384: &P384PublicKey,
bob_s_public_hash: &[u8],
ss: &Secret<48>,
header_check_cipher: &Aes,
mtu: usize,
current_time: i64,
jedi: bool,
@ -900,8 +979,7 @@ impl EphemeralOffer {
PACKET_BUF_SIZE - p.len()
};
let mut header = send_with_fragmentation_init_header(packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter);
packet_buf[..HEADER_SIZE].copy_from_slice(&header);
send_with_fragmentation_init_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter);
let key = Secret(hmac_sha512(
&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()),
@ -910,7 +988,7 @@ impl EphemeralOffer {
let gcm_tag = {
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true);
c.init(&get_aes_gcm_nonce(&packet_buf));
c.init(&packet_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]);
c.finish()
};
@ -919,15 +997,15 @@ impl EphemeralOffer {
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[..packet_len]);
let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[HEADER_CHECK_SIZE..packet_len]);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;
let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[..packet_len]);
let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[HEADER_CHECK_SIZE..packet_len]);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;
send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, &mut header);
send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher);
Ok(EphemeralOffer {
creation_time: current_time,
@ -939,7 +1017,7 @@ impl EphemeralOffer {
}
#[inline(always)]
fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) -> [u8; 12] {
fn send_with_fragmentation_init_header(header: &mut [u8], packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) {
let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize;
debug_assert!(mtu >= MIN_MTU);
debug_assert!(packet_len >= HEADER_SIZE);
@ -947,27 +1025,26 @@ fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_typ
debug_assert!(fragment_count > 0);
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits
// Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter
((((((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u64) | recipient_session_id.wrapping_shl(16)) as u128) | (counter.to_u32() as u128).wrapping_shl(64)).to_le_bytes()
[..HEADER_SIZE]
.try_into()
.unwrap()
header[HEADER_CHECK_SIZE..12].copy_from_slice(&(recipient_session_id | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)).to_le_bytes());
header[12..HEADER_SIZE].copy_from_slice(&counter.to_u32().to_le_bytes());
}
#[inline(always)]
fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(send: &mut SendFunction, packet: &mut [u8], mtu: usize, header: &mut [u8; HEADER_SIZE]) {
/// Send a packet in fragments (used for everything but DATA which has a hand-rolled version for performance).
fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(send: &mut SendFunction, packet: &mut [u8], mtu: usize, header_check_cipher: &Aes) {
let packet_len = packet.len();
let mut fragment_start = 0;
let mut fragment_end = packet_len.min(mtu);
let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = packet[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap();
loop {
let hc = header_check(&packet[fragment_start..], header_check_cipher);
packet[fragment_start..(fragment_start + HEADER_CHECK_SIZE)].copy_from_slice(&hc.to_ne_bytes());
send(&mut packet[fragment_start..fragment_end]);
if fragment_end < packet_len {
debug_assert!(header[7].wrapping_shr(2) < 63);
header[7] += 0x04; // increment fragment number
fragment_start = fragment_end - HEADER_SIZE;
fragment_end = (fragment_start + mtu).min(packet_len);
debug_assert!(header[1].wrapping_shr(2) < 63);
header[1] += 0x04; // increment fragment number in least significant 6 bits of byte 1
packet[fragment_start..(fragment_start + HEADER_SIZE)].copy_from_slice(header);
packet[(fragment_start + HEADER_CHECK_SIZE)..(fragment_start + HEADER_SIZE)].copy_from_slice(&header);
} else {
debug_assert_eq!(fragment_end, packet_len);
break;
@ -975,6 +1052,31 @@ fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(send: &mut SendFuncti
}
}
/// Compute the 32-bit header check code for a packet on receipt or right before send.
fn header_check(packet: &[u8], header_check_cipher: &Aes) -> u32 {
    debug_assert!(packet.len() >= HEADER_SIZE);
    let mut block = 0u128.to_ne_bytes();
    // The check code covers the bytes immediately after the check field itself.
    let covered = &packet[HEADER_CHECK_SIZE..];
    if covered.len() >= 16 {
        // Common case: a full AES block of material follows the check field.
        header_check_cipher.encrypt_block(&covered[..16], &mut block);
    } else {
        // Runt packet (shorter than check field + one AES block): zero-pad
        // the remaining bytes into a full block and encrypt in place.
        unlikely_branch();
        block[..covered.len()].copy_from_slice(covered);
        header_check_cipher.encrypt_block_in_place(&mut block);
    }
    // The check code is the first four bytes of the ciphertext block.
    memory::u32_from_ne_bytes(&block)
}
/// Add a new session key to the key list, retiring older non-active keys if necessary.
fn add_key(keys: &mut LinkedList<SessionKey>, key: SessionKey) {
    debug_assert!(KEY_HISTORY_SIZE_MAX >= 2);
    // Make room for the new key: keep the active (front) key but discard the
    // oldest retired key sitting right behind it until the history fits.
    while keys.len() >= KEY_HISTORY_SIZE_MAX {
        let active = keys.pop_front().unwrap();
        let _ = keys.pop_front(); // drop the oldest retired key
        keys.push_front(active);
    }
    keys.push_back(key);
}
fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Result<(SessionId, &[u8], &[u8], &[u8]), Error> {
let mut p = &incoming_packet[..];
let alice_session_id = SessionId::new_from_reader(&mut p)?;
@ -1083,13 +1185,6 @@ fn kbkdf512(key: &[u8], label: u8) -> Secret<64> {
Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00]))
}
/// Copy the packet header into a zero-initialized 16-byte array for use as the AES-GCM nonce.
#[inline(always)]
fn get_aes_gcm_nonce(packet: &[u8]) -> [u8; 16] {
    let mut nonce = [0u8; 16];
    nonce[..HEADER_SIZE].copy_from_slice(&packet[..HEADER_SIZE]);
    nonce
}
#[cfg(test)]
mod tests {
use parking_lot::Mutex;
@ -1175,7 +1270,8 @@ mod tests {
random::fill_bytes_secure(&mut psk.0);
let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob"));
let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice"));
let rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new());
let alice_rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new(&alice_host));
let bob_rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new(&bob_host));
let mut data_buf = [0_u8; 4096];
//println!("zssp: size of session (bytes): {}", std::mem::size_of::<Session<Box<TestHost>>>());
@ -1207,6 +1303,12 @@ mod tests {
}
};
let rc = if std::ptr::eq(host, &alice_host) {
&alice_rc
} else {
&bob_rc
};
loop {
if let Some(qi) = host.queue.lock().pop_back() {
let qi_len = qi.len();
@ -1232,7 +1334,8 @@ mod tests {
}
}
} else {
println!("zssp: {} => {}: error: {}", host.other_name, host.this_name, r.err().unwrap().to_string());
println!("zssp: {} => {} ({}): error: {}", host.other_name, host.this_name, qi_len, r.err().unwrap().to_string());
panic!();
}
} else {
break;

View file

@ -39,6 +39,42 @@ mod fast_int_memory_access {
unsafe { *b.as_ptr().cast() }
}
/// Read a native-endian u64 from the first 8 bytes of `b`.
///
/// Panics if `b` is shorter than 8 bytes.
#[inline(always)]
pub fn u64_from_ne_bytes(b: &[u8]) -> u64 {
    assert!(b.len() >= 8);
    // SAFETY: length checked above. read_unaligned is required because an
    // arbitrary &[u8] carries no alignment guarantee, and dereferencing an
    // unaligned *const u64 (as `*ptr.cast()` did) is undefined behavior.
    unsafe { std::ptr::read_unaligned(b.as_ptr().cast::<u64>()) }
}
/// Read a native-endian u32 from the first 4 bytes of `b`.
///
/// Panics if `b` is shorter than 4 bytes.
#[inline(always)]
pub fn u32_from_ne_bytes(b: &[u8]) -> u32 {
    assert!(b.len() >= 4);
    // SAFETY: length checked above. read_unaligned avoids the undefined
    // behavior of dereferencing a potentially unaligned *const u32.
    unsafe { std::ptr::read_unaligned(b.as_ptr().cast::<u32>()) }
}
/// Read a native-endian u16 from the first 2 bytes of `b`.
///
/// Panics if `b` is shorter than 2 bytes.
#[inline(always)]
pub fn u16_from_ne_bytes(b: &[u8]) -> u16 {
    assert!(b.len() >= 2);
    // SAFETY: length checked above. read_unaligned avoids the undefined
    // behavior of dereferencing a potentially unaligned *const u16.
    unsafe { std::ptr::read_unaligned(b.as_ptr().cast::<u16>()) }
}
/// Read a native-endian i64 from the first 8 bytes of `b`.
///
/// Panics if `b` is shorter than 8 bytes.
#[inline(always)]
pub fn i64_from_ne_bytes(b: &[u8]) -> i64 {
    assert!(b.len() >= 8);
    // SAFETY: length checked above. read_unaligned avoids the undefined
    // behavior of dereferencing a potentially unaligned *const i64.
    unsafe { std::ptr::read_unaligned(b.as_ptr().cast::<i64>()) }
}
/// Read a native-endian i32 from the first 4 bytes of `b`.
///
/// Panics if `b` is shorter than 4 bytes.
#[inline(always)]
pub fn i32_from_ne_bytes(b: &[u8]) -> i32 {
    assert!(b.len() >= 4);
    // SAFETY: length checked above. read_unaligned avoids the undefined
    // behavior of dereferencing a potentially unaligned *const i32.
    unsafe { std::ptr::read_unaligned(b.as_ptr().cast::<i32>()) }
}
/// Read a native-endian i16 from the first 2 bytes of `b`.
///
/// Panics if `b` is shorter than 2 bytes.
#[inline(always)]
pub fn i16_from_ne_bytes(b: &[u8]) -> i16 {
    assert!(b.len() >= 2);
    // SAFETY: length checked above. read_unaligned avoids the undefined
    // behavior of dereferencing a potentially unaligned *const i16.
    unsafe { std::ptr::read_unaligned(b.as_ptr().cast::<i16>()) }
}
#[inline(always)]
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
assert!(b.len() >= 8);
@ -109,6 +145,36 @@ mod fast_int_memory_access {
i16::from_le_bytes(b[..2].try_into().unwrap())
}
/// Read a native-endian u64 from the first 8 bytes of `b` (panics if shorter).
#[inline(always)]
pub fn u64_from_ne_bytes(b: &[u8]) -> u64 {
    let head: [u8; 8] = b[..8].try_into().unwrap();
    u64::from_ne_bytes(head)
}
/// Read a native-endian u32 from the first 4 bytes of `b` (panics if shorter).
#[inline(always)]
pub fn u32_from_ne_bytes(b: &[u8]) -> u32 {
    let head: [u8; 4] = b[..4].try_into().unwrap();
    u32::from_ne_bytes(head)
}
/// Read a native-endian u16 from the first 2 bytes of `b` (panics if shorter).
#[inline(always)]
pub fn u16_from_ne_bytes(b: &[u8]) -> u16 {
    let head: [u8; 2] = b[..2].try_into().unwrap();
    u16::from_ne_bytes(head)
}
/// Read a native-endian i64 from the first 8 bytes of `b` (panics if shorter).
#[inline(always)]
pub fn i64_from_ne_bytes(b: &[u8]) -> i64 {
    let head: [u8; 8] = b[..8].try_into().unwrap();
    i64::from_ne_bytes(head)
}
/// Read a native-endian i32 from the first 4 bytes of `b` (panics if shorter).
#[inline(always)]
pub fn i32_from_ne_bytes(b: &[u8]) -> i32 {
    let head: [u8; 4] = b[..4].try_into().unwrap();
    i32::from_ne_bytes(head)
}
/// Read a native-endian i16 from the first 2 bytes of `b` (panics if shorter).
#[inline(always)]
pub fn i16_from_ne_bytes(b: &[u8]) -> i16 {
    let head: [u8; 2] = b[..2].try_into().unwrap();
    i16::from_ne_bytes(head)
}
#[inline(always)]
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
u64::from_be_bytes(b[..8].try_into().unwrap())

View file

@ -4,6 +4,8 @@ use std::hash::{Hash, Hasher};
use std::mem::MaybeUninit;
const EMPTY: u16 = 0xffff;
#[inline(always)]
fn xorshift64(mut x: u64) -> u64 {
x ^= x.wrapping_shl(13);
@ -78,9 +80,9 @@ impl Hasher for XorShiftHasher {
struct Entry<K: Eq + PartialEq + Hash + Clone, V> {
key: MaybeUninit<K>,
value: MaybeUninit<V>,
bucket: i32, // which bucket is this in? -1 for none
next: i32, // next item in bucket's linked list, -1 for none
prev: i32, // previous entry to permit deletion of old entries from bucket lists
bucket: u16, // which bucket is this in? EMPTY for none
next: u16, // next item in bucket's linked list, EMPTY for none
prev: u16, // previous entry to permit deletion of old entries from bucket lists
}
/// A hybrid between a circular buffer and a map.
@ -90,35 +92,35 @@ struct Entry<K: Eq + PartialEq + Hash + Clone, V> {
/// with a HashMap but that would be less efficient. This requires no memory allocations unless
/// the K or V types allocate memory and occupies a fixed amount of memory.
///
/// This is pretty basic and doesn't have a remove function. Old entries just roll off. This
/// only contains what is needed elsewhere in the project.
/// There is no explicit remove since that would require more complex logic to maintain FIFO
/// ordering for replacement of entries. Old entries just roll off the end.
///
/// This is used for things like defragmenting incoming packets to support multiple fragmented
/// packets in flight. Having no allocations is good to reduce the potential for memory
/// exhaustion attacks.
///
/// The C template parameter is the total capacity while the B parameter is the number of
/// buckets in the hash table.
/// buckets in the hash table. The maximum for both these parameters is 65535. This could be
/// increased by making the index variables larger (e.g. u32 instead of u16).
pub struct RingBufferMap<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> {
entries: [Entry<K, V>; C],
buckets: [i32; B],
entry_ptr: u32,
salt: u32,
entries: [Entry<K, V>; C],
buckets: [u16; B],
entry_ptr: u16,
}
impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBufferMap<K, V, C, B> {
/// Create a new map with the supplied random salt to perturb the hashing function.
#[inline]
pub fn new(salt: u32) -> Self {
Self {
entries: {
let mut entries: [Entry<K, V>; C] = unsafe { MaybeUninit::uninit().assume_init() };
for e in entries.iter_mut() {
e.bucket = -1;
e.next = -1;
e.prev = -1;
}
entries
},
buckets: [-1; B],
entry_ptr: 0,
salt,
}
debug_assert!(C <= EMPTY as usize);
debug_assert!(B <= EMPTY as usize);
let mut tmp: Self = unsafe { MaybeUninit::uninit().assume_init() };
// EMPTY is the maximum value of the indices, which is all 0xff, so this sets all indices to EMPTY.
unsafe { std::ptr::write_bytes(&mut tmp, 0xff, 1) };
tmp.salt = salt;
tmp.entry_ptr = 0;
tmp
}
#[inline]
@ -126,9 +128,9 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
let mut h = XorShiftHasher::new(self.salt);
key.hash(&mut h);
let mut e = self.buckets[(h.finish() as usize) % B];
while e >= 0 {
while e != EMPTY {
let ee = &self.entries[e as usize];
debug_assert!(ee.bucket >= 0);
debug_assert!(ee.bucket != EMPTY);
if unsafe { ee.key.assume_init_ref().eq(key) } {
return Some(unsafe { &ee.value.assume_init_ref() });
}
@ -137,7 +139,7 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
return None;
}
/// Get an entry, creating if not present.
/// Get an entry, creating if not present, and return a mutable reference to it.
#[inline]
pub fn get_or_create_mut<CF: FnOnce() -> V>(&mut self, key: &K, create: CF) -> &mut V {
let mut h = XorShiftHasher::new(self.salt);
@ -145,10 +147,10 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
let bucket = (h.finish() as usize) % B;
let mut e = self.buckets[bucket];
while e >= 0 {
while e != EMPTY {
unsafe {
let e_ptr = &mut *self.entries.as_mut_ptr().add(e as usize);
debug_assert!(e_ptr.bucket >= 0);
debug_assert!(e_ptr.bucket != EMPTY);
if e_ptr.key.assume_init_ref().eq(key) {
return e_ptr.value.assume_init_mut();
}
@ -167,9 +169,9 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
let bucket = (h.finish() as usize) % B;
let mut e = self.buckets[bucket];
while e >= 0 {
while e != EMPTY {
let e_ptr = &mut self.entries[e as usize];
debug_assert!(e_ptr.bucket >= 0);
debug_assert!(e_ptr.bucket != EMPTY);
if unsafe { e_ptr.key.assume_init_ref().eq(&key) } {
unsafe { *e_ptr.value.assume_init_mut() = value };
return;
@ -177,7 +179,7 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
e = e_ptr.next;
}
self.internal_add(bucket, key, value);
let _ = self.internal_add(bucket, key, value);
}
#[inline]
@ -186,8 +188,8 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
self.entry_ptr = self.entry_ptr.wrapping_add(1);
let e_ptr = unsafe { &mut *self.entries.as_mut_ptr().add(e) };
if e_ptr.bucket >= 0 {
if e_ptr.prev >= 0 {
if e_ptr.bucket != EMPTY {
if e_ptr.prev != EMPTY {
self.entries[e_ptr.prev as usize].next = e_ptr.next;
} else {
self.buckets[e_ptr.bucket as usize] = e_ptr.next;
@ -200,13 +202,13 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
e_ptr.key.write(key);
e_ptr.value.write(value);
e_ptr.bucket = bucket as i32;
e_ptr.bucket = bucket as u16;
e_ptr.next = self.buckets[bucket];
if e_ptr.next >= 0 {
self.entries[e_ptr.next as usize].prev = e as i32;
if e_ptr.next != EMPTY {
self.entries[e_ptr.next as usize].prev = e as u16;
}
self.buckets[bucket] = e as i32;
e_ptr.prev = -1;
self.buckets[bucket] = e as u16;
e_ptr.prev = EMPTY;
unsafe { e_ptr.value.assume_init_mut() }
}
}
@ -215,7 +217,7 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> Drop f
#[inline]
fn drop(&mut self) {
for e in self.entries.iter_mut() {
if e.bucket >= 0 {
if e.bucket != EMPTY {
unsafe {
e.key.assume_init_drop();
e.value.assume_init_drop();