mirror of
https://github.com/zerotier/ZeroTierOne.git
synced 2025-04-22 06:56:54 +02:00
A bunch of ZSSP cleanup and optimization. Runs a bit faster now.
This commit is contained in:
parent
7072338037
commit
f83bf41427
2 changed files with 292 additions and 282 deletions
|
@ -10,7 +10,6 @@ use std::fmt::Display;
|
|||
use std::num::NonZeroU64;
|
||||
|
||||
use zerotier_crypto::random;
|
||||
use zerotier_utils::memory::{array_range, as_byte_array};
|
||||
|
||||
/// 48-bit session ID (most significant 16 bits of u64 are unused)
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
|
@ -25,6 +24,7 @@ impl SessionId {
|
|||
pub const MAX: u64 = 0xffffffffffff;
|
||||
|
||||
/// Create a new session ID, panicing if 'i' is zero or exceeds MAX.
|
||||
#[inline(always)]
|
||||
pub fn new(i: u64) -> SessionId {
|
||||
assert!(i <= Self::MAX);
|
||||
Self(NonZeroU64::new(i.to_le()).unwrap())
|
||||
|
@ -35,22 +35,23 @@ impl SessionId {
|
|||
Self(NonZeroU64::new(((random::xorshift64_random() % (Self::MAX - 1)) + 1).to_le()).unwrap())
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(b: &[u8; Self::SIZE]) -> Option<SessionId> {
|
||||
let mut tmp = [0u8; 8];
|
||||
#[inline(always)]
|
||||
pub fn to_bytes(&self) -> [u8; Self::SIZE] {
|
||||
self.0.get().to_ne_bytes()[..Self::SIZE].try_into().unwrap()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn new_from_bytes(b: &[u8]) -> Option<SessionId> {
|
||||
let mut tmp = 0u64.to_ne_bytes();
|
||||
tmp[..SESSION_ID_SIZE_BYTES].copy_from_slice(b);
|
||||
Self::new_from_u64_le(u64::from_ne_bytes(tmp))
|
||||
NonZeroU64::new(u64::from_ne_bytes(tmp)).map(|i| Self(i))
|
||||
}
|
||||
|
||||
/// Create from a u64 that is already in little-endian byte order.
|
||||
#[inline(always)]
|
||||
pub(crate) fn new_from_u64_le(i: u64) -> Option<SessionId> {
|
||||
NonZeroU64::new(i & Self::MAX.to_le()).map(|i| Self(i))
|
||||
}
|
||||
|
||||
/// Get this session ID as a little-endian byte array.
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_bytes(&self) -> &[u8; Self::SIZE] {
|
||||
array_range::<u8, 8, 0, SESSION_ID_SIZE_BYTES>(as_byte_array(&self.0))
|
||||
pub fn new_from_array(b: &[u8; Self::SIZE]) -> Option<SessionId> {
|
||||
let mut tmp = 0u64.to_ne_bytes();
|
||||
tmp[..SESSION_ID_SIZE_BYTES].copy_from_slice(b);
|
||||
NonZeroU64::new(u64::from_ne_bytes(tmp)).map(|i| Self(i))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
547
zssp/src/zssp.rs
547
zssp/src/zssp.rs
|
@ -56,7 +56,7 @@ struct SessionsById<Application: ApplicationLayer> {
|
|||
active: HashMap<SessionId, Weak<Session<Application>>>,
|
||||
|
||||
// Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout.
|
||||
incoming: HashMap<SessionId, Arc<IncomingIncompleteSession>>,
|
||||
incoming: HashMap<SessionId, Arc<IncomingIncompleteSession<Application>>>,
|
||||
}
|
||||
|
||||
/// Result generated by the context packet receive function, with possible payloads.
|
||||
|
@ -97,10 +97,10 @@ struct State {
|
|||
remote_session_id: Option<SessionId>,
|
||||
keys: [Option<SessionKey>; 2],
|
||||
current_key: usize,
|
||||
current_offer: Offer,
|
||||
outgoing_offer: Offer,
|
||||
}
|
||||
|
||||
struct IncomingIncompleteSession {
|
||||
struct IncomingIncompleteSession<Application: ApplicationLayer> {
|
||||
timestamp: i64,
|
||||
alice_session_id: SessionId,
|
||||
bob_session_id: SessionId,
|
||||
|
@ -109,6 +109,7 @@ struct IncomingIncompleteSession {
|
|||
hk: Secret<KYBER_SSBYTES>,
|
||||
header_protection_key: Secret<AES_HEADER_PROTECTION_KEY_SIZE>,
|
||||
bob_noise_e_secret: P384KeyPair,
|
||||
defrag: [Mutex<Fragged<Application::IncomingPacketBuffer, MAX_FRAGMENTS>>; MAX_NOISE_HANDSHAKE_FRAGMENTS],
|
||||
}
|
||||
|
||||
struct OutgoingSessionOffer {
|
||||
|
@ -184,7 +185,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
for (id, s) in sessions.active.iter() {
|
||||
if let Some(session) = s.upgrade() {
|
||||
let state = session.state.read().unwrap();
|
||||
if match &state.current_offer {
|
||||
if match &state.outgoing_offer {
|
||||
Offer::None => true,
|
||||
Offer::NoiseXKInit(offer) => {
|
||||
// If there's an outstanding attempt to open a session, retransmit this periodically
|
||||
|
@ -324,7 +325,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
remote_session_id: None,
|
||||
keys: [None, None],
|
||||
current_key: 0,
|
||||
current_offer: Offer::NoiseXKInit(Box::new(OutgoingSessionOffer {
|
||||
outgoing_offer: Offer::NoiseXKInit(Box::new(OutgoingSessionOffer {
|
||||
last_retry_time: AtomicI64::new(current_time),
|
||||
psk,
|
||||
noise_h: mix_hash(&mix_hash(&INITIAL_H, remote_s_public_blob), &alice_noise_e),
|
||||
|
@ -345,7 +346,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
|
||||
{
|
||||
let mut state = session.state.write().unwrap();
|
||||
let offer = if let Offer::NoiseXKInit(offer) = &mut state.current_offer {
|
||||
let offer = if let Offer::NoiseXKInit(offer) = &mut state.outgoing_offer {
|
||||
offer
|
||||
} else {
|
||||
panic!(); // should be impossible as this is what we initialized with
|
||||
|
@ -357,7 +358,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
let init: &mut AliceNoiseXKInit = byte_array_as_proto_buffer_mut(init_packet).unwrap();
|
||||
init.session_protocol_version = SESSION_PROTOCOL_VERSION;
|
||||
init.alice_noise_e = alice_noise_e;
|
||||
init.alice_session_id = *local_session_id.as_bytes();
|
||||
init.alice_session_id = local_session_id.to_bytes();
|
||||
init.alice_hk_public = alice_hk_secret.public;
|
||||
init.header_protection_key = header_protection_key.0;
|
||||
}
|
||||
|
@ -417,6 +418,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
/// * `data_buf` - Buffer to receive decrypted and authenticated object data (an error is returned if too small)
|
||||
/// * `incoming_packet_buf` - Buffer containing incoming wire packet (receive() takes ownership)
|
||||
/// * `current_time` - Current monotonic time in milliseconds
|
||||
#[inline]
|
||||
pub fn receive<
|
||||
'b,
|
||||
SendFunction: FnMut(Option<&Arc<Session<Application>>>, &mut [u8]),
|
||||
|
@ -430,112 +432,83 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
mut send: SendFunction,
|
||||
source: &Application::PhysicalPath,
|
||||
data_buf: &'b mut [u8],
|
||||
mut incoming_packet_buf: Application::IncomingPacketBuffer,
|
||||
mut incoming_physical_packet_buf: Application::IncomingPacketBuffer,
|
||||
current_time: i64,
|
||||
) -> Result<ReceiveResult<'b, Application>, Error> {
|
||||
let incoming_packet: &mut [u8] = incoming_packet_buf.as_mut();
|
||||
if incoming_packet.len() < MIN_PACKET_SIZE {
|
||||
let incoming_physical_packet: &mut [u8] = incoming_physical_packet_buf.as_mut();
|
||||
if incoming_physical_packet.len() < MIN_PACKET_SIZE {
|
||||
return Err(Error::InvalidPacket);
|
||||
}
|
||||
|
||||
let mut incoming = None;
|
||||
if let Some(local_session_id) = SessionId::new_from_u64_le(u64::from_le_bytes(incoming_packet[0..8].try_into().unwrap())) {
|
||||
if let Some(session) = self.sessions.read().unwrap().active.get(&local_session_id).and_then(|s| s.upgrade()) {
|
||||
if let Some(local_session_id) = SessionId::new_from_bytes(&incoming_physical_packet[0..SessionId::SIZE]) {
|
||||
let sessions = self.sessions.read().unwrap();
|
||||
if let Some(session) = sessions.active.get(&local_session_id).and_then(|s| s.upgrade()) {
|
||||
drop(sessions);
|
||||
debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id));
|
||||
|
||||
session
|
||||
.header_protection_cipher
|
||||
.decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet);
|
||||
.decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet);
|
||||
|
||||
if session.check_receive_window(incoming_counter) {
|
||||
if fragment_count > 1 {
|
||||
let mut fragged = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO].lock().unwrap();
|
||||
if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) {
|
||||
drop(fragged);
|
||||
return self.process_complete_incoming_packet(
|
||||
app,
|
||||
&mut send,
|
||||
&mut check_allow_incoming_session,
|
||||
&mut check_accept_session,
|
||||
data_buf,
|
||||
incoming_counter,
|
||||
assembled_packet.as_ref(),
|
||||
packet_type,
|
||||
Some(session),
|
||||
None,
|
||||
key_index,
|
||||
current_time,
|
||||
);
|
||||
let (assembled_packet, incoming_packet_buf_arr);
|
||||
let incoming_packet = if fragment_count > 1 {
|
||||
assembled_packet = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO]
|
||||
.lock()
|
||||
.unwrap()
|
||||
.assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
|
||||
if let Some(assembled_packet) = assembled_packet.as_ref() {
|
||||
assembled_packet.as_ref()
|
||||
} else {
|
||||
drop(fragged);
|
||||
return Ok(ReceiveResult::Ok(Some(session)));
|
||||
}
|
||||
} else {
|
||||
return self.process_complete_incoming_packet(
|
||||
app,
|
||||
&mut send,
|
||||
&mut check_allow_incoming_session,
|
||||
&mut check_accept_session,
|
||||
data_buf,
|
||||
incoming_counter,
|
||||
&[incoming_packet_buf],
|
||||
packet_type,
|
||||
Some(session),
|
||||
None,
|
||||
key_index,
|
||||
current_time,
|
||||
);
|
||||
}
|
||||
incoming_packet_buf_arr = [incoming_physical_packet_buf];
|
||||
&incoming_packet_buf_arr
|
||||
};
|
||||
|
||||
return self.process_complete_incoming_packet(
|
||||
app,
|
||||
&mut send,
|
||||
&mut check_allow_incoming_session,
|
||||
&mut check_accept_session,
|
||||
data_buf,
|
||||
incoming_counter,
|
||||
incoming_packet,
|
||||
packet_type,
|
||||
Some(session),
|
||||
None,
|
||||
key_index,
|
||||
current_time,
|
||||
);
|
||||
} else {
|
||||
return Err(Error::OutOfSequence);
|
||||
}
|
||||
} else {
|
||||
if let Some(i) = self.sessions.read().unwrap().incoming.get(&local_session_id).cloned() {
|
||||
Aes::new(&i.header_protection_key)
|
||||
.decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
incoming = Some(i);
|
||||
} else {
|
||||
return Err(Error::UnknownLocalSessionId);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(incoming) = sessions.incoming.get(&local_session_id).cloned() {
|
||||
drop(sessions);
|
||||
debug_assert!(!self.sessions.read().unwrap().active.contains_key(&local_session_id));
|
||||
|
||||
// If we make it here the packet is not associated with a session or is associated with an
|
||||
// incoming session (Noise_XK mid-negotiation).
|
||||
Aes::new(&incoming.header_protection_key)
|
||||
.decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet);
|
||||
|
||||
let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet);
|
||||
if fragment_count > 1 {
|
||||
let f = {
|
||||
let mut defrag = self.defrag.lock().unwrap();
|
||||
let f = defrag
|
||||
.entry((source.clone(), incoming_counter))
|
||||
.or_insert_with(|| Arc::new((Mutex::new(Fragged::new()), current_time)))
|
||||
.clone();
|
||||
|
||||
// Anti-DOS overflow purge of the incoming defragmentation queue for packets not associated with known sessions.
|
||||
if defrag.len() >= self.max_incomplete_session_queue_size {
|
||||
// First, drop all entries that are timed out or whose physical source duplicates another entry.
|
||||
let mut sources = HashSet::with_capacity(defrag.len());
|
||||
let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
|
||||
defrag.retain(|k, fragged| (fragged.1 > negotiation_timeout_cutoff && sources.insert(k.0.clone())) || Arc::ptr_eq(fragged, &f));
|
||||
|
||||
// Then, if we are still at or over the limit, drop 10% of remaining entries at random.
|
||||
if defrag.len() >= self.max_incomplete_session_queue_size {
|
||||
let mut rn = random::next_u32_secure();
|
||||
defrag.retain(|_, fragged| {
|
||||
rn = prng32(rn);
|
||||
rn > (u32::MAX / 10) || Arc::ptr_eq(fragged, &f)
|
||||
});
|
||||
let (assembled_packet, incoming_packet_buf_arr);
|
||||
let incoming_packet = if fragment_count > 1 {
|
||||
assembled_packet = incoming.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO]
|
||||
.lock()
|
||||
.unwrap()
|
||||
.assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
|
||||
if let Some(assembled_packet) = assembled_packet.as_ref() {
|
||||
assembled_packet.as_ref()
|
||||
} else {
|
||||
return Ok(ReceiveResult::Ok(None));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
incoming_packet_buf_arr = [incoming_physical_packet_buf];
|
||||
&incoming_packet_buf_arr
|
||||
};
|
||||
|
||||
f
|
||||
};
|
||||
let mut fragged = f.0.lock().unwrap();
|
||||
|
||||
if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) {
|
||||
self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter));
|
||||
return self.process_complete_incoming_packet(
|
||||
app,
|
||||
&mut send,
|
||||
|
@ -543,15 +516,63 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
&mut check_accept_session,
|
||||
data_buf,
|
||||
incoming_counter,
|
||||
assembled_packet.as_ref(),
|
||||
incoming_packet,
|
||||
packet_type,
|
||||
None,
|
||||
incoming,
|
||||
Some(incoming),
|
||||
key_index,
|
||||
current_time,
|
||||
);
|
||||
} else {
|
||||
return Err(Error::UnknownLocalSessionId);
|
||||
}
|
||||
} else {
|
||||
let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet);
|
||||
|
||||
let (assembled_packet, incoming_packet_buf_arr);
|
||||
let incoming_packet = if fragment_count > 1 {
|
||||
assembled_packet = {
|
||||
let mut defrag = self.defrag.lock().unwrap();
|
||||
let f = defrag
|
||||
.entry((source.clone(), incoming_counter))
|
||||
.or_insert_with(|| Arc::new((Mutex::new(Fragged::new()), current_time)))
|
||||
.clone();
|
||||
|
||||
// Anti-DOS overflow purge of the incoming defragmentation queue for packets not associated with known sessions.
|
||||
if defrag.len() >= self.max_incomplete_session_queue_size {
|
||||
// First, drop all entries that are timed out or whose physical source duplicates another entry.
|
||||
let mut sources = HashSet::with_capacity(defrag.len());
|
||||
let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
|
||||
defrag
|
||||
.retain(|k, fragged| (fragged.1 > negotiation_timeout_cutoff && sources.insert(k.0.clone())) || Arc::ptr_eq(fragged, &f));
|
||||
|
||||
// Then, if we are still at or over the limit, drop 10% of remaining entries at random.
|
||||
if defrag.len() >= self.max_incomplete_session_queue_size {
|
||||
let mut rn = random::next_u32_secure();
|
||||
defrag.retain(|_, fragged| {
|
||||
rn = prng32(rn);
|
||||
rn > (u32::MAX / 10) || Arc::ptr_eq(fragged, &f)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
f
|
||||
}
|
||||
.0
|
||||
.lock()
|
||||
.unwrap()
|
||||
.assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count);
|
||||
if let Some(assembled_packet) = assembled_packet.as_ref() {
|
||||
self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter));
|
||||
assembled_packet.as_ref()
|
||||
} else {
|
||||
return Ok(ReceiveResult::Ok(None));
|
||||
}
|
||||
} else {
|
||||
incoming_packet_buf_arr = [incoming_physical_packet_buf];
|
||||
&incoming_packet_buf_arr
|
||||
};
|
||||
|
||||
return self.process_complete_incoming_packet(
|
||||
app,
|
||||
&mut send,
|
||||
|
@ -559,16 +580,14 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
&mut check_accept_session,
|
||||
data_buf,
|
||||
incoming_counter,
|
||||
&[incoming_packet_buf],
|
||||
incoming_packet,
|
||||
packet_type,
|
||||
None,
|
||||
incoming,
|
||||
None,
|
||||
key_index,
|
||||
current_time,
|
||||
);
|
||||
}
|
||||
|
||||
return Ok(ReceiveResult::Ok(None));
|
||||
}
|
||||
|
||||
fn process_complete_incoming_packet<
|
||||
|
@ -587,7 +606,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
fragments: &[Application::IncomingPacketBuffer],
|
||||
packet_type: u8,
|
||||
session: Option<Arc<Session<Application>>>,
|
||||
incoming: Option<Arc<IncomingIncompleteSession>>,
|
||||
incoming: Option<Arc<IncomingIncompleteSession<Application>>>,
|
||||
key_index: usize,
|
||||
current_time: i64,
|
||||
) -> Result<ReceiveResult<'b, Application>, Error> {
|
||||
|
@ -651,9 +670,9 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
|
||||
// If we got a valid data packet from Bob, this means we can cancel any offers
|
||||
// that are still oustanding for initialization.
|
||||
match &state.current_offer {
|
||||
match &state.outgoing_offer {
|
||||
Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => {
|
||||
state.current_offer = Offer::None;
|
||||
state.outgoing_offer = Offer::None;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -730,7 +749,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
}
|
||||
|
||||
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
|
||||
let alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?;
|
||||
let alice_session_id = SessionId::new_from_array(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?;
|
||||
let header_protection_key = Secret(pkt.header_protection_key);
|
||||
|
||||
// Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create
|
||||
|
@ -761,7 +780,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
let ack: &mut BobNoiseXKAck = byte_array_as_proto_buffer_mut(&mut ack_packet)?;
|
||||
ack.session_protocol_version = SESSION_PROTOCOL_VERSION;
|
||||
ack.bob_noise_e = bob_noise_e;
|
||||
ack.bob_session_id = *bob_session_id.as_bytes();
|
||||
ack.bob_session_id = bob_session_id.to_bytes();
|
||||
ack.bob_hk_ciphertext = bob_hk_ciphertext;
|
||||
|
||||
// Encrypt main section of reply and attach tag.
|
||||
|
@ -802,6 +821,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
hk,
|
||||
bob_noise_e_secret,
|
||||
header_protection_key: Secret(pkt.header_protection_key),
|
||||
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
|
||||
}),
|
||||
);
|
||||
debug_assert!(!sessions.active.contains_key(&bob_session_id));
|
||||
|
@ -847,7 +867,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
return Err(Error::OutOfSequence);
|
||||
}
|
||||
|
||||
if let Offer::NoiseXKInit(outgoing_offer) = &state.current_offer {
|
||||
if let Offer::NoiseXKInit(outgoing_offer) = &state.outgoing_offer {
|
||||
let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?;
|
||||
|
||||
// Derive noise_es_ee from Bob's ephemeral public key.
|
||||
|
@ -875,7 +895,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
|
||||
let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?;
|
||||
|
||||
if let Some(bob_session_id) = SessionId::new_from_bytes(&pkt.bob_session_id) {
|
||||
if let Some(bob_session_id) = SessionId::new_from_array(&pkt.bob_session_id) {
|
||||
// Complete Noise_XKpsk3 by mixing in noise_se followed by the PSK. The PSK as far as
|
||||
// the Noise pattern is concerned is the result of mixing the externally supplied PSK
|
||||
// with the Kyber1024 shared secret (hk). Kyber is treated as part of the PSK because
|
||||
|
@ -948,7 +968,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
));
|
||||
debug_assert!(state.keys[1].is_none());
|
||||
state.current_key = 0;
|
||||
state.current_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck {
|
||||
state.outgoing_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck {
|
||||
last_retry_time: AtomicI64::new(current_time),
|
||||
ack,
|
||||
ack_len,
|
||||
|
@ -1071,7 +1091,7 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
None,
|
||||
],
|
||||
current_key: 0,
|
||||
current_offer: Offer::None,
|
||||
outgoing_offer: Offer::None,
|
||||
}),
|
||||
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
|
||||
});
|
||||
|
@ -1108,81 +1128,74 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
|
||||
if let Some(session) = session {
|
||||
let state = session.state.read().unwrap();
|
||||
if let Some(remote_session_id) = state.remote_session_id {
|
||||
if let Some(key) = state.keys[key_index].as_ref() {
|
||||
// Only the current "Alice" accepts rekeys initiated by the current "Bob." These roles
|
||||
// flip with each rekey event.
|
||||
if !key.my_turn_to_rekey {
|
||||
let mut c = key.get_receive_cipher(incoming_counter);
|
||||
c.reset_init_gcm(&incoming_message_nonce);
|
||||
c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]);
|
||||
let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]);
|
||||
drop(c);
|
||||
if let (Some(remote_session_id), Some(key)) = (state.remote_session_id, state.keys[key_index].as_ref()) {
|
||||
if !key.my_turn_to_rekey && {
|
||||
let mut c = key.get_receive_cipher(incoming_counter);
|
||||
c.reset_init_gcm(&incoming_message_nonce);
|
||||
c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]);
|
||||
c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..])
|
||||
} {
|
||||
let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
|
||||
if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) {
|
||||
let bob_e_secret = P384KeyPair::generate();
|
||||
let next_session_key = hmac_sha512_secret(
|
||||
key.ratchet_key.as_bytes(),
|
||||
bob_e_secret.agree(&alice_e).ok_or(Error::FailedAuthentication)?.as_bytes(),
|
||||
);
|
||||
|
||||
if aead_authentication_ok {
|
||||
let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
|
||||
if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) {
|
||||
let bob_e_secret = P384KeyPair::generate();
|
||||
let next_session_key = hmac_sha512_secret(
|
||||
key.ratchet_key.as_bytes(),
|
||||
bob_e_secret.agree(&alice_e).ok_or(Error::FailedAuthentication)?.as_bytes(),
|
||||
);
|
||||
// Packet fully authenticated
|
||||
if session.update_receive_window(incoming_counter) {
|
||||
let mut reply_buf = [0u8; RekeyAck::SIZE];
|
||||
let reply: &mut RekeyAck = byte_array_as_proto_buffer_mut(&mut reply_buf).unwrap();
|
||||
reply.session_protocol_version = SESSION_PROTOCOL_VERSION;
|
||||
reply.bob_e = *bob_e_secret.public_key_bytes();
|
||||
reply.next_key_fingerprint = SHA384::hash(next_session_key.as_bytes());
|
||||
|
||||
// Packet fully authenticated
|
||||
if session.update_receive_window(incoming_counter) {
|
||||
let mut reply_buf = [0u8; RekeyAck::SIZE];
|
||||
let reply: &mut RekeyAck = byte_array_as_proto_buffer_mut(&mut reply_buf).unwrap();
|
||||
reply.session_protocol_version = SESSION_PROTOCOL_VERSION;
|
||||
reply.bob_e = *bob_e_secret.public_key_bytes();
|
||||
reply.next_key_fingerprint = SHA384::hash(next_session_key.as_bytes());
|
||||
let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
|
||||
set_packet_header(
|
||||
&mut reply_buf,
|
||||
1,
|
||||
0,
|
||||
PACKET_TYPE_REKEY_ACK,
|
||||
u64::from(remote_session_id),
|
||||
state.current_key,
|
||||
counter,
|
||||
);
|
||||
|
||||
let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
|
||||
set_packet_header(
|
||||
&mut reply_buf,
|
||||
1,
|
||||
0,
|
||||
PACKET_TYPE_REKEY_ACK,
|
||||
u64::from(remote_session_id),
|
||||
state.current_key,
|
||||
counter,
|
||||
);
|
||||
let mut c = key.get_send_cipher(counter)?;
|
||||
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_ACK, counter));
|
||||
c.crypt_in_place(&mut reply_buf[RekeyAck::ENC_START..RekeyAck::AUTH_START]);
|
||||
reply_buf[RekeyAck::AUTH_START..].copy_from_slice(&c.finish_encrypt());
|
||||
drop(c);
|
||||
|
||||
let mut c = key.get_send_cipher(counter)?;
|
||||
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_ACK, counter));
|
||||
c.crypt_in_place(&mut reply_buf[RekeyAck::ENC_START..RekeyAck::AUTH_START]);
|
||||
reply_buf[RekeyAck::AUTH_START..].copy_from_slice(&c.finish_encrypt());
|
||||
drop(c);
|
||||
session
|
||||
.header_protection_cipher
|
||||
.encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
send(Some(&session), &mut reply_buf);
|
||||
|
||||
session
|
||||
.header_protection_cipher
|
||||
.encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
send(Some(&session), &mut reply_buf);
|
||||
// The new "Bob" doesn't know yet if Alice has received the new key, so the
|
||||
// new key is recorded as the "alt" (key_index ^ 1) but the current key is
|
||||
// not advanced yet. This happens automatically the first time we receive a
|
||||
// valid packet with the new key.
|
||||
let next_ratchet_count = key.ratchet_count + 1;
|
||||
drop(state);
|
||||
let mut state = session.state.write().unwrap();
|
||||
let _ = state.keys[key_index ^ 1].replace(SessionKey::new::<Application>(
|
||||
next_session_key,
|
||||
next_ratchet_count,
|
||||
current_time,
|
||||
counter,
|
||||
false,
|
||||
false,
|
||||
));
|
||||
|
||||
// The new "Bob" doesn't know yet if Alice has received the new key, so the
|
||||
// new key is recorded as the "alt" (key_index ^ 1) but the current key is
|
||||
// not advanced yet. This happens automatically the first time we receive a
|
||||
// valid packet with the new key.
|
||||
let next_ratchet_count = key.ratchet_count + 1;
|
||||
drop(state);
|
||||
let mut state = session.state.write().unwrap();
|
||||
let _ = state.keys[key_index ^ 1].replace(SessionKey::new::<Application>(
|
||||
next_session_key,
|
||||
next_ratchet_count,
|
||||
current_time,
|
||||
counter,
|
||||
false,
|
||||
false,
|
||||
));
|
||||
|
||||
drop(state);
|
||||
return Ok(ReceiveResult::Ok(Some(session)));
|
||||
} else {
|
||||
return Err(Error::OutOfSequence);
|
||||
}
|
||||
}
|
||||
drop(state);
|
||||
return Ok(ReceiveResult::Ok(Some(session)));
|
||||
} else {
|
||||
return Err(Error::OutOfSequence);
|
||||
}
|
||||
return Err(Error::FailedAuthentication);
|
||||
}
|
||||
return Err(Error::FailedAuthentication);
|
||||
}
|
||||
}
|
||||
return Err(Error::OutOfSequence);
|
||||
|
@ -1201,59 +1214,52 @@ impl<Application: ApplicationLayer> Context<Application> {
|
|||
|
||||
if let Some(session) = session {
|
||||
let state = session.state.read().unwrap();
|
||||
if let Offer::RekeyInit(alice_e_secret, _) = &state.current_offer {
|
||||
if let Some(key) = state.keys[key_index].as_ref() {
|
||||
// Only the current "Bob" initiates rekeys and expects this ACK.
|
||||
if key.my_turn_to_rekey {
|
||||
let mut c = key.get_receive_cipher(incoming_counter);
|
||||
c.reset_init_gcm(&incoming_message_nonce);
|
||||
c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]);
|
||||
let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]);
|
||||
drop(c);
|
||||
if let (Offer::RekeyInit(alice_e_secret, _), Some(key)) = (&state.outgoing_offer, state.keys[key_index].as_ref()) {
|
||||
if key.my_turn_to_rekey && {
|
||||
let mut c = key.get_receive_cipher(incoming_counter);
|
||||
c.reset_init_gcm(&incoming_message_nonce);
|
||||
c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]);
|
||||
c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..])
|
||||
} {
|
||||
let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
|
||||
if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) {
|
||||
let next_session_key = hmac_sha512_secret(
|
||||
key.ratchet_key.as_bytes(),
|
||||
alice_e_secret.agree(&bob_e).ok_or(Error::FailedAuthentication)?.as_bytes(),
|
||||
);
|
||||
|
||||
if aead_authentication_ok {
|
||||
// Packet fully authenticated
|
||||
if secure_eq(&pkt.next_key_fingerprint, &SHA384::hash(next_session_key.as_bytes())) {
|
||||
if session.update_receive_window(incoming_counter) {
|
||||
// The new "Alice" knows Bob has the key since this is an ACK, so she can go
|
||||
// ahead and set current_key to the new key. Then when she sends something
|
||||
// to Bob the other side will automatically advance to the new key as well.
|
||||
let next_ratchet_count = key.ratchet_count + 1;
|
||||
drop(state);
|
||||
let next_key_index = key_index ^ 1;
|
||||
let mut state = session.state.write().unwrap();
|
||||
let _ = state.keys[next_key_index].replace(SessionKey::new::<Application>(
|
||||
next_session_key,
|
||||
next_ratchet_count,
|
||||
current_time,
|
||||
session.send_counter.load(Ordering::Relaxed),
|
||||
true,
|
||||
true,
|
||||
));
|
||||
state.current_key = next_key_index; // this is an ACK so it's confirmed
|
||||
state.outgoing_offer = Offer::None;
|
||||
|
||||
let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap();
|
||||
if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) {
|
||||
let next_session_key = hmac_sha512_secret(
|
||||
key.ratchet_key.as_bytes(),
|
||||
alice_e_secret.agree(&bob_e).ok_or(Error::FailedAuthentication)?.as_bytes(),
|
||||
);
|
||||
|
||||
if secure_eq(&pkt.next_key_fingerprint, &SHA384::hash(next_session_key.as_bytes())) {
|
||||
if session.update_receive_window(incoming_counter) {
|
||||
// The new "Alice" knows Bob has the key since this is an ACK, so she can go
|
||||
// ahead and set current_key to the new key. Then when she sends something
|
||||
// to Bob the other side will automatically advance to the new key as well.
|
||||
let next_ratchet_count = key.ratchet_count + 1;
|
||||
drop(state);
|
||||
let next_key_index = key_index ^ 1;
|
||||
let mut state = session.state.write().unwrap();
|
||||
let _ = state.keys[next_key_index].replace(SessionKey::new::<Application>(
|
||||
next_session_key,
|
||||
next_ratchet_count,
|
||||
current_time,
|
||||
session.send_counter.load(Ordering::Relaxed),
|
||||
true,
|
||||
true,
|
||||
));
|
||||
state.current_key = next_key_index; // this is an ACK so it's confirmed
|
||||
state.current_offer = Offer::None;
|
||||
|
||||
drop(state);
|
||||
return Ok(ReceiveResult::Ok(Some(session)));
|
||||
} else {
|
||||
return Err(Error::OutOfSequence);
|
||||
}
|
||||
}
|
||||
drop(state);
|
||||
return Ok(ReceiveResult::Ok(Some(session)));
|
||||
} else {
|
||||
return Err(Error::OutOfSequence);
|
||||
}
|
||||
}
|
||||
return Err(Error::FailedAuthentication);
|
||||
}
|
||||
}
|
||||
return Err(Error::FailedAuthentication);
|
||||
} else {
|
||||
return Err(Error::OutOfSequence);
|
||||
}
|
||||
return Err(Error::OutOfSequence);
|
||||
} else {
|
||||
return Err(Error::UnknownLocalSessionId);
|
||||
}
|
||||
|
@ -1277,51 +1283,49 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
pub fn send<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction, mtu_sized_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> {
|
||||
debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU);
|
||||
let state = self.state.read().unwrap();
|
||||
if let Some(remote_session_id) = state.remote_session_id {
|
||||
if let Some(session_key) = state.keys[state.current_key].as_ref() {
|
||||
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
|
||||
if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) {
|
||||
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
|
||||
|
||||
let mut c = session_key.get_send_cipher(counter)?;
|
||||
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter));
|
||||
let mut c = session_key.get_send_cipher(counter)?;
|
||||
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter));
|
||||
|
||||
let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize;
|
||||
let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE;
|
||||
let last_fragment_no = fragment_count - 1;
|
||||
let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize;
|
||||
let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE;
|
||||
let last_fragment_no = fragment_count - 1;
|
||||
|
||||
for fragment_no in 0..fragment_count {
|
||||
let chunk_size = fragment_max_chunk_size.min(data.len());
|
||||
let mut fragment_size = chunk_size + HEADER_SIZE;
|
||||
for fragment_no in 0..fragment_count {
|
||||
let chunk_size = fragment_max_chunk_size.min(data.len());
|
||||
let mut fragment_size = chunk_size + HEADER_SIZE;
|
||||
|
||||
set_packet_header(
|
||||
mtu_sized_buffer,
|
||||
fragment_count,
|
||||
fragment_no,
|
||||
PACKET_TYPE_DATA,
|
||||
u64::from(remote_session_id),
|
||||
state.current_key,
|
||||
counter,
|
||||
);
|
||||
set_packet_header(
|
||||
mtu_sized_buffer,
|
||||
fragment_count,
|
||||
fragment_no,
|
||||
PACKET_TYPE_DATA,
|
||||
u64::from(remote_session_id),
|
||||
state.current_key,
|
||||
counter,
|
||||
);
|
||||
|
||||
c.crypt(&data[..chunk_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]);
|
||||
data = &data[chunk_size..];
|
||||
c.crypt(&data[..chunk_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]);
|
||||
data = &data[chunk_size..];
|
||||
|
||||
if fragment_no == last_fragment_no {
|
||||
debug_assert!(data.is_empty());
|
||||
let tagged_fragment_size = fragment_size + AES_GCM_TAG_SIZE;
|
||||
mtu_sized_buffer[fragment_size..tagged_fragment_size].copy_from_slice(&c.finish_encrypt());
|
||||
fragment_size = tagged_fragment_size;
|
||||
}
|
||||
|
||||
self.header_protection_cipher
|
||||
.encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
send(&mut mtu_sized_buffer[..fragment_size]);
|
||||
if fragment_no == last_fragment_no {
|
||||
debug_assert!(data.is_empty());
|
||||
let tagged_fragment_size = fragment_size + AES_GCM_TAG_SIZE;
|
||||
mtu_sized_buffer[fragment_size..tagged_fragment_size].copy_from_slice(&c.finish_encrypt());
|
||||
fragment_size = tagged_fragment_size;
|
||||
}
|
||||
debug_assert!(data.is_empty());
|
||||
|
||||
drop(c);
|
||||
|
||||
return Ok(());
|
||||
self.header_protection_cipher
|
||||
.encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
send(&mut mtu_sized_buffer[..fragment_size]);
|
||||
}
|
||||
debug_assert!(data.is_empty());
|
||||
|
||||
drop(c);
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
return Err(Error::SessionNotEstablished);
|
||||
}
|
||||
|
@ -1329,23 +1333,26 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
/// Send a NOP to the other side (e.g. for keep alive).
|
||||
pub fn send_nop<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction) -> Result<(), Error> {
|
||||
let state = self.state.read().unwrap();
|
||||
if let Some(remote_session_id) = state.remote_session_id {
|
||||
if let Some(session_key) = state.keys[state.current_key].as_ref() {
|
||||
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
|
||||
let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
|
||||
let mut c = session_key.get_send_cipher(counter)?;
|
||||
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter));
|
||||
nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt());
|
||||
drop(c);
|
||||
set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter);
|
||||
self.header_protection_cipher
|
||||
.encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
send(&mut nop);
|
||||
}
|
||||
if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) {
|
||||
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
|
||||
let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
|
||||
let mut c = session_key.get_send_cipher(counter)?;
|
||||
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter));
|
||||
nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt());
|
||||
drop(c);
|
||||
set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter);
|
||||
self.header_protection_cipher
|
||||
.encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
|
||||
send(&mut nop);
|
||||
}
|
||||
return Err(Error::SessionNotEstablished);
|
||||
}
|
||||
|
||||
/// Set the current physical MTU that this session should use to send packets.
|
||||
pub fn set_physical_mtu(&self, mtu: usize) {
|
||||
self.state.write().unwrap().physical_mtu = mtu;
|
||||
}
|
||||
|
||||
/// Check whether this session is established.
|
||||
pub fn established(&self) -> bool {
|
||||
let state = self.state.read().unwrap();
|
||||
|
@ -1403,7 +1410,7 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
send(&mut rekey_buf);
|
||||
|
||||
drop(state);
|
||||
self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, current_time);
|
||||
self.state.write().unwrap().outgoing_offer = Offer::RekeyInit(rekey_e, current_time);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1578,6 +1585,7 @@ impl SessionKey {
|
|||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_send_cipher<'a>(&'a self, counter: u64) -> Result<MutexGuard<'a, AesGcm<true>>, Error> {
|
||||
if counter < self.expire_at_counter {
|
||||
Ok(self.send_cipher_pool[(counter as usize) % GCM_CIPHER_POOL_SIZE].lock().unwrap())
|
||||
|
@ -1586,6 +1594,7 @@ impl SessionKey {
|
|||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_receive_cipher<'a>(&'a self, counter: u64) -> MutexGuard<'a, AesGcm<false>> {
|
||||
self.receive_cipher_pool[(counter as usize) % GCM_CIPHER_POOL_SIZE].lock().unwrap()
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue