finished implementation of counter starting at 1

This commit is contained in:
mamoniot 2022-12-27 14:25:20 -05:00
parent 402cf69b72
commit 52556d0d89
4 changed files with 138 additions and 193 deletions

View file

@ -74,9 +74,6 @@ pub(crate) const HMAC_SIZE: usize = 48;
/// This is large since some ZeroTier nodes handle huge numbers of links, like roots and controllers. /// This is large since some ZeroTier nodes handle huge numbers of links, like roots and controllers.
pub(crate) const SESSION_ID_SIZE: usize = 6; pub(crate) const SESSION_ID_SIZE: usize = 6;
/// Number of session keys to hold at a given time (current, previous, next).
pub(crate) const KEY_HISTORY_SIZE: usize = 3;
/// Maximum difference between out-of-order incoming packet counters, and size of deduplication buffer. /// Maximum difference between out-of-order incoming packet counters, and size of deduplication buffer.
pub(crate) const COUNTER_MAX_ALLOWED_OOO: usize = 16; pub(crate) const COUNTER_MAX_ALLOWED_OOO: usize = 16;

View file

@ -1,8 +1,4 @@
use std::{sync::{ use std::sync::atomic::{Ordering, AtomicU32};
atomic::{AtomicU64, Ordering, AtomicU32, AtomicI32, AtomicBool}
}, mem};
use zerotier_crypto::random;
use crate::constants::COUNTER_MAX_ALLOWED_OOO; use crate::constants::COUNTER_MAX_ALLOWED_OOO;
@ -12,20 +8,24 @@ use crate::constants::COUNTER_MAX_ALLOWED_OOO;
/// lets us more safely implement key lifetime limits without confusing logic to handle 32-bit /// lets us more safely implement key lifetime limits without confusing logic to handle 32-bit
/// wrap-around. /// wrap-around.
#[repr(transparent)] #[repr(transparent)]
pub(crate) struct Counter(AtomicU64); pub(crate) struct Counter(AtomicU32);
impl Counter { impl Counter {
#[inline(always)] #[inline(always)]
pub fn new() -> Self { pub fn new() -> Self {
// Using a random value has no security implication. Zero would be fine. This just // Using a random value has no security implication. Zero would be fine. This just
// helps randomize packet contents a bit. // helps randomize packet contents a bit.
Self(AtomicU64::new((random::next_u32_secure()/2) as u64)) Self(AtomicU32::new(1u32))
}
#[inline(always)]
pub fn reset_after_initial_offer(&self) {
self.0.store(2u32, Ordering::SeqCst);
} }
/// Get the value most recently used to send a packet. /// Get the value most recently used to send a packet.
#[inline(always)] #[inline(always)]
pub fn previous(&self) -> CounterValue { pub fn current(&self) -> CounterValue {
CounterValue(self.0.load(Ordering::SeqCst).wrapping_sub(1)) CounterValue(self.0.load(Ordering::SeqCst))
} }
/// Get a counter value for the next packet being sent. /// Get a counter value for the next packet being sent.
@ -38,7 +38,7 @@ impl Counter {
/// A value of the outgoing packet counter. /// A value of the outgoing packet counter.
#[repr(transparent)] #[repr(transparent)]
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct CounterValue(u64); pub(crate) struct CounterValue(u32);
impl CounterValue { impl CounterValue {
/// Get the 32-bit counter value used to build packets. /// Get the 32-bit counter value used to build packets.
@ -46,51 +46,35 @@ impl CounterValue {
pub fn to_u32(&self) -> u32 { pub fn to_u32(&self) -> u32 {
self.0 as u32 self.0 as u32
} }
pub fn get_initial_offer_counter() -> CounterValue {
/// Get the counter value after N more uses of the parent counter. return CounterValue(1u32);
///
/// This checks for u64 overflow for the sake of correctness. Be careful if using ZSSP in a
/// generational starship where sessions may last for millions of years.
#[inline(always)]
pub fn counter_value_after_uses(&self, uses: u64) -> Self {
Self(self.0.checked_add(uses).unwrap())
} }
} }
/// Incoming packet deduplication and replay protection window. /// Incoming packet deduplication and replay protection window.
pub(crate) struct CounterWindow(AtomicBool, AtomicBool, [AtomicU32; COUNTER_MAX_ALLOWED_OOO]); pub(crate) struct CounterWindow([AtomicU32; COUNTER_MAX_ALLOWED_OOO]);
impl CounterWindow { impl CounterWindow {
#[inline(always)] #[inline(always)]
pub fn new(initial: u32) -> Self { pub fn new() -> Self {
Self(AtomicBool::new(true), AtomicBool::new(false), std::array::from_fn(|_| AtomicU32::new(initial))) Self(std::array::from_fn(|_| AtomicU32::new(0)))
} }
#[inline(always)] ///this creates a counter window that rejects everything
pub fn new_uninit() -> Self { pub fn new_invalid() -> Self {
Self(AtomicBool::new(false), AtomicBool::new(false), std::array::from_fn(|_| AtomicU32::new(0))) Self(std::array::from_fn(|_| AtomicU32::new(u32::MAX)))
} }
#[inline(always)] pub fn reset_after_initial_offer(&self) {
pub fn init_authenticated(&self, received_counter_value: u32) { for i in 0..COUNTER_MAX_ALLOWED_OOO {
self.1.store((u32::MAX/4 < received_counter_value) & (received_counter_value <= u32::MAX/4*3), Ordering::SeqCst); self.0[i].store(0, Ordering::SeqCst)
for i in 1..COUNTER_MAX_ALLOWED_OOO {
self.2[i].store(received_counter_value, Ordering::SeqCst);
} }
self.0.store(true, Ordering::SeqCst);
} }
#[inline(always)] #[inline(always)]
pub fn message_received(&self, received_counter_value: u32) -> bool { pub fn message_received(&self, received_counter_value: u32) -> bool {
if self.0.load(Ordering::SeqCst) { let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize;
let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize; //it is highly likely this can be a Relaxed ordering, but I want someone else to confirm that is the case
let pre = self.2[idx].load(Ordering::SeqCst); let pre = self.0[idx].load(Ordering::SeqCst);
if self.1.load(Ordering::Relaxed) { return pre < received_counter_value;
return pre < received_counter_value;
} else {
return (pre as i32) < (received_counter_value as i32);
}
} else {
return true;
}
} }
#[inline(always)] #[inline(always)]
@ -98,11 +82,6 @@ impl CounterWindow {
//if a valid message is received but one of its fragments was lost, it can technically be replayed. However since the message is incomplete, we know it still exists in the gather array, so the gather array will deduplicate the replayed message. Even if the gather array gets flushed, that flush still effectively deduplicates the replayed message. //if a valid message is received but one of its fragments was lost, it can technically be replayed. However since the message is incomplete, we know it still exists in the gather array, so the gather array will deduplicate the replayed message. Even if the gather array gets flushed, that flush still effectively deduplicates the replayed message.
//eventually the counter of that kind of message will be too OOO to be accepted anymore so it can't be used to DOS. //eventually the counter of that kind of message will be too OOO to be accepted anymore so it can't be used to DOS.
let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize; let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize;
if self.1.swap((u32::MAX/4 < received_counter_value) & (received_counter_value <= u32::MAX/4*3), Ordering::SeqCst) { return self.0[idx].fetch_max(received_counter_value, Ordering::SeqCst) < received_counter_value;
return self.2[idx].fetch_max(received_counter_value, Ordering::SeqCst) < received_counter_value;
} else {
let pre_as_signed: &AtomicI32 = unsafe {mem::transmute(&self.2[idx])};
return pre_as_signed.fetch_max(received_counter_value as i32, Ordering::SeqCst) < received_counter_value as i32;
}
} }
} }

View file

@ -232,7 +232,7 @@ mod tests {
let mut counter = 1u32; let mut counter = 1u32;
let mut history = Vec::new(); let mut history = Vec::new();
let mut w = CounterWindow::new(counter); let mut w = CounterWindow::new();
for i in 0..1000000 { for i in 0..1000000 {
let p = xorshift64(&mut rng) as f32/(u32::MAX as f32 + 1.0); let p = xorshift64(&mut rng) as f32/(u32::MAX as f32 + 1.0);
let c; let c;
@ -249,10 +249,15 @@ mod tests {
assert!(!w.message_authenticated(c)); assert!(!w.message_authenticated(c));
} }
continue; continue;
} else { } else if p < 0.999 {
c = xorshift64(&mut rng); c = xorshift64(&mut rng);
w.message_received(c); w.message_received(c);
continue; continue;
} else {
w.reset_after_initial_offer();
counter = 1u32;
history = Vec::new();
continue;
} }
if history.contains(&c) { if history.contains(&c) {
assert!(!w.message_authenticated(c)); assert!(!w.message_authenticated(c));

View file

@ -113,7 +113,7 @@ pub struct Session<Application: ApplicationLayer> {
pub application_data: Application::Data, pub application_data: Application::Data,
send_counter: Counter, // Outgoing packet counter and nonce state send_counter: Counter, // Outgoing packet counter and nonce state
receive_window: CounterWindow, // Receive window for anti-replay and deduplication receive_window: [CounterWindow; 2], // Receive window for anti-replay and deduplication
psk: Secret<64>, // Arbitrary PSK provided by external code psk: Secret<64>, // Arbitrary PSK provided by external code
noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key
header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) header_check_cipher: Aes, // Cipher used for header check codes (not Noise related)
@ -126,17 +126,17 @@ pub struct Session<Application: ApplicationLayer> {
struct SessionMutableState { struct SessionMutableState {
remote_session_id: Option<SessionId>, // The other side's 48-bit session ID remote_session_id: Option<SessionId>, // The other side's 48-bit session ID
session_keys: [Option<SessionKey>; KEY_HISTORY_SIZE], // Buffers to store current, next, and last active key session_keys: [Option<SessionKey>; 2], // Buffers to store current, next, and last active key
cur_session_key_idx: usize, // Pointer used for keys[] circular buffer cur_session_key_id: bool, // Pointer used for keys[] circular buffer
offer: Option<EphemeralOffer>, // Most recent ephemeral offer sent to remote offer: Option<EphemeralOffer>, // Most recent ephemeral offer sent to remote
last_remote_offer: i64, // Time of most recent ephemeral offer (ms) last_remote_offer: i64, // Time of most recent ephemeral offer (ms)
} }
/// A shared symmetric session key. /// A shared symmetric session key.
/// sessions always start at counter 1u32
struct SessionKey { struct SessionKey {
secret_fingerprint: [u8; 16], // First 128 bits of a SHA384 computed from the secret secret_fingerprint: [u8; 16], // First 128 bits of a SHA384 computed from the secret
creation_time: i64, // Time session key was established creation_time: i64, // Time session key was established
creation_counter: CounterValue, // Counter value at which session was established
lifetime: KeyLifetime, // Key expiration time and counter lifetime: KeyLifetime, // Key expiration time and counter
ratchet_key: Secret<64>, // Ratchet key for deriving the next session key ratchet_key: Secret<64>, // Ratchet key for deriving the next session key
receive_key: Secret<AES_KEY_SIZE>, // Receive side AES-GCM key receive_key: Secret<AES_KEY_SIZE>, // Receive side AES-GCM key
@ -149,14 +149,14 @@ struct SessionKey {
/// Key lifetime state /// Key lifetime state
struct KeyLifetime { struct KeyLifetime {
rekey_at_or_after_counter: CounterValue, rekey_at_or_after_counter: u32,
hard_expire_at_counter: CounterValue,
rekey_at_or_after_timestamp: i64, rekey_at_or_after_timestamp: i64,
} }
/// Alice's KEY_OFFER, remembered so Noise agreement process can resume on KEY_COUNTER_OFFER. /// Alice's KEY_OFFER, remembered so Noise agreement process can resume on KEY_COUNTER_OFFER.
struct EphemeralOffer { struct EphemeralOffer {
id: [u8; 16], // Arbitrary random offer ID id: [u8; 16], // Arbitrary random offer ID
key_id: bool, // The key_id bound to this offer, for handling OOO rekeying
creation_time: i64, // Local time when offer was created creation_time: i64, // Local time when offer was created
ratchet_count: u64, // Ratchet count (starting at zero) for initial offer ratchet_count: u64, // Ratchet count (starting at zero) for initial offer
ratchet_key: Option<Secret<64>>, // Ratchet key from previous offer or None if first offer ratchet_key: Option<Secret<64>>, // Ratchet key from previous offer or None if first offer
@ -235,6 +235,7 @@ impl<Application: ApplicationLayer> Session<Application> {
if send_ephemeral_offer( if send_ephemeral_offer(
&mut send, &mut send,
send_counter.next(), send_counter.next(),
false,
local_session_id, local_session_id,
None, None,
app.get_local_s_public_blob(), app.get_local_s_public_blob(),
@ -254,14 +255,14 @@ impl<Application: ApplicationLayer> Session<Application> {
id: local_session_id, id: local_session_id,
application_data, application_data,
send_counter, send_counter,
receive_window: CounterWindow::new_uninit(),//alice does not know bob's counter yet receive_window: [CounterWindow::new(), CounterWindow::new_invalid()],
psk: psk.clone(), psk: psk.clone(),
noise_ss, noise_ss,
header_check_cipher, header_check_cipher,
state: RwLock::new(SessionMutableState { state: RwLock::new(SessionMutableState {
remote_session_id: None, remote_session_id: None,
session_keys: [None, None, None], session_keys: [None, None],
cur_session_key_idx: 0, cur_session_key_id: false,
offer, offer,
last_remote_offer: i64::MIN, last_remote_offer: i64::MIN,
}), }),
@ -290,7 +291,7 @@ impl<Application: ApplicationLayer> Session<Application> {
debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU);
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id { if let Some(remote_session_id) = state.remote_session_id {
if let Some(session_key) = state.session_keys[state.cur_session_key_idx].as_ref() { if let Some(session_key) = state.session_keys[state.cur_session_key_id as usize].as_ref() {
// Total size of the armored packet we are going to send (may end up being fragmented) // Total size of the armored packet we are going to send (may end up being fragmented)
let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
@ -309,6 +310,7 @@ impl<Application: ApplicationLayer> Session<Application> {
PACKET_TYPE_DATA, PACKET_TYPE_DATA,
remote_session_id.into(), remote_session_id.into(),
counter, counter,
state.cur_session_key_id
)?; )?;
// Get an initialized AES-GCM cipher and re-initialize with a 96-bit IV built from remote session ID, // Get an initialized AES-GCM cipher and re-initialize with a 96-bit IV built from remote session ID,
@ -367,7 +369,7 @@ impl<Application: ApplicationLayer> Session<Application> {
/// Check whether this session is established. /// Check whether this session is established.
pub fn established(&self) -> bool { pub fn established(&self) -> bool {
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
state.remote_session_id.is_some() && state.session_keys[state.cur_session_key_idx].is_some() state.remote_session_id.is_some() && state.session_keys[state.cur_session_key_id as usize].is_some()
} }
/// Get information about this session's security state. /// Get information about this session's security state.
@ -376,7 +378,7 @@ impl<Application: ApplicationLayer> Session<Application> {
/// and whether Kyber1024 was used. None is returned if the session isn't established. /// and whether Kyber1024 was used. None is returned if the session isn't established.
pub fn status(&self) -> Option<([u8; 16], i64, u64, bool)> { pub fn status(&self) -> Option<([u8; 16], i64, u64, bool)> {
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
if let Some(key) = state.session_keys[state.cur_session_key_idx].as_ref() { if let Some(key) = state.session_keys[state.cur_session_key_id as usize].as_ref() {
Some((key.secret_fingerprint, key.creation_time, key.ratchet_count, key.jedi)) Some((key.secret_fingerprint, key.creation_time, key.ratchet_count, key.jedi))
} else { } else {
None None
@ -402,9 +404,9 @@ impl<Application: ApplicationLayer> Session<Application> {
) { ) {
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
if (force_rekey if (force_rekey
|| state.session_keys[state.cur_session_key_idx] || state.session_keys[state.cur_session_key_id as usize]
.as_ref() .as_ref()
.map_or(true, |key| key.lifetime.should_rekey(self.send_counter.previous(), current_time))) .map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time)))
&& state && state
.offer .offer
.as_ref() .as_ref()
@ -414,7 +416,8 @@ impl<Application: ApplicationLayer> Session<Application> {
let mut offer = None; let mut offer = None;
if send_ephemeral_offer( if send_ephemeral_offer(
&mut send, &mut send,
self.send_counter.next(), CounterValue::get_initial_offer_counter(),
!state.cur_session_key_id,
self.id, self.id,
state.remote_session_id, state.remote_session_id,
app.get_local_s_public_blob(), app.get_local_s_public_blob(),
@ -422,7 +425,7 @@ impl<Application: ApplicationLayer> Session<Application> {
&remote_s_public, &remote_s_public,
&self.remote_s_public_blob_hash, &self.remote_s_public_blob_hash,
&self.noise_ss, &self.noise_ss,
state.session_keys[state.cur_session_key_idx].as_ref(), state.session_keys[state.cur_session_key_id as usize].as_ref(),
if state.remote_session_id.is_some() { if state.remote_session_id.is_some() {
Some(&self.header_check_cipher) Some(&self.header_check_cipher)
} else { } else {
@ -435,7 +438,7 @@ impl<Application: ApplicationLayer> Session<Application> {
.is_ok() .is_ok()
{ {
drop(state); drop(state);
let _ = self.state.write().unwrap().offer.replace(offer.unwrap()); self.state.write().unwrap().offer = offer;
} }
} }
} }
@ -477,7 +480,9 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
let counter = u32::from_le(memory::load_raw(incoming_packet)); let mut counter = u32::from_le(memory::load_raw(incoming_packet));
let key_id = (counter & 1) > 0;
counter = counter.wrapping_shr(1);
let packet_type_fragment_info = u16::from_le(memory::load_raw(&incoming_packet[14..16])); let packet_type_fragment_info = u16::from_le(memory::load_raw(&incoming_packet[14..16]));
let packet_type = (packet_type_fragment_info & 0x0f) as u8; let packet_type = (packet_type_fragment_info & 0x0f) as u8;
let fragment_count = ((packet_type_fragment_info.wrapping_shr(4) + 1) as u8) & 63; let fragment_count = ((packet_type_fragment_info.wrapping_shr(4) + 1) as u8) & 63;
@ -487,7 +492,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
{ {
if let Some(session) = app.lookup_session(local_session_id) { if let Some(session) = app.lookup_session(local_session_id) {
if verify_header_check_code(incoming_packet, &session.header_check_cipher) { if verify_header_check_code(incoming_packet, &session.header_check_cipher) {
if session.receive_window.message_received(counter) { if session.receive_window[key_id as usize].message_received(counter) {
let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter);
if fragment_count > 1 { if fragment_count > 1 {
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
@ -501,6 +506,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
&mut send, &mut send,
data_buf, data_buf,
counter, counter,
key_id,
canonical_header.as_bytes(), canonical_header.as_bytes(),
assembled_packet.as_ref(), assembled_packet.as_ref(),
packet_type, packet_type,
@ -520,6 +526,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
&mut send, &mut send,
data_buf, data_buf,
counter, counter,
key_id,
canonical_header.as_bytes(), canonical_header.as_bytes(),
&[incoming_packet_buf], &[incoming_packet_buf],
packet_type, packet_type,
@ -556,6 +563,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
&mut send, &mut send,
data_buf, data_buf,
counter, counter,
key_id,
canonical_header.as_bytes(), canonical_header.as_bytes(),
assembled_packet.as_ref(), assembled_packet.as_ref(),
packet_type, packet_type,
@ -571,6 +579,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
&mut send, &mut send,
data_buf, data_buf,
counter, counter,
key_id,
canonical_header.as_bytes(), canonical_header.as_bytes(),
&[incoming_packet_buf], &[incoming_packet_buf],
packet_type, packet_type,
@ -599,6 +608,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
send: &mut SendFunction, send: &mut SendFunction,
data_buf: &'a mut [u8], data_buf: &'a mut [u8],
counter: u32, counter: u32,
key_id: bool,
canonical_header_bytes: &[u8; 12], canonical_header_bytes: &[u8; 12],
fragments: &[Application::IncomingPacketBuffer], fragments: &[Application::IncomingPacketBuffer],
packet_type: u8, packet_type: u8,
@ -615,82 +625,59 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
if packet_type <= PACKET_TYPE_NOP { if packet_type <= PACKET_TYPE_NOP {
if let Some(session) = session { if let Some(session) = session {
let state = session.state.read().unwrap(); let state = session.state.read().unwrap();
for p in 0..KEY_HISTORY_SIZE { if let Some(session_key) = state.session_keys[key_id as usize].as_ref() {
let key_idx = (state.cur_session_key_idx + p) % KEY_HISTORY_SIZE; let mut c = session_key.get_receive_cipher();
if let Some(session_key) = state.session_keys[key_idx].as_ref() { c.reset_init_gcm(canonical_header_bytes);
let mut c = session_key.get_receive_cipher(); ////////////////////////////////////////////////////////////////
c.reset_init_gcm(canonical_header_bytes); // packet decoding for post-noise transport
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
// packet decoding for post-noise transport
////////////////////////////////////////////////////////////////
let mut data_len = 0; let mut data_len = 0;
// Decrypt fragments 0..N-1 where N is the number of fragments. // Decrypt fragments 0..N-1 where N is the number of fragments.
for f in fragments[..(fragments.len() - 1)].iter() { for f in fragments[..(fragments.len() - 1)].iter() {
let f = f.as_ref(); let f = f.as_ref();
debug_assert!(f.len() >= HEADER_SIZE); debug_assert!(f.len() >= HEADER_SIZE);
let current_frag_data_start = data_len;
data_len += f.len() - HEADER_SIZE;
if data_len > data_buf.len() {
unlikely_branch();
session_key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
}
// Decrypt final fragment (or only fragment if not fragmented)
let current_frag_data_start = data_len; let current_frag_data_start = data_len;
let last_fragment = fragments.last().unwrap().as_ref(); data_len += f.len() - HEADER_SIZE;
if last_fragment.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
return Err(Error::InvalidPacket);
}
data_len += last_fragment.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
if data_len > data_buf.len() { if data_len > data_buf.len() {
unlikely_branch(); unlikely_branch();
session_key.return_receive_cipher(c); session_key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall); return Err(Error::DataBufferTooSmall);
} }
let payload_end = last_fragment.len() - AES_GCM_TAG_SIZE; c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
c.crypt( }
&last_fragment[HEADER_SIZE..payload_end],
&mut data_buf[current_frag_data_start..data_len],
);
let gcm_tag = &last_fragment[payload_end..]; // Decrypt final fragment (or only fragment if not fragmented)
let aead_authentication_ok = c.finish_decrypt(gcm_tag); let current_frag_data_start = data_len;
let last_fragment = fragments.last().unwrap().as_ref();
if last_fragment.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
return Err(Error::InvalidPacket);
}
data_len += last_fragment.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
if data_len > data_buf.len() {
unlikely_branch();
session_key.return_receive_cipher(c); session_key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
let payload_end = last_fragment.len() - AES_GCM_TAG_SIZE;
c.crypt(
&last_fragment[HEADER_SIZE..payload_end],
&mut data_buf[current_frag_data_start..data_len],
);
if aead_authentication_ok { let gcm_tag = &last_fragment[payload_end..];
if session.receive_window.message_authenticated(counter) { let aead_authentication_ok = c.finish_decrypt(gcm_tag);
// Select this key as the new default if it's newer than the current key. session_key.return_receive_cipher(c);
if p > 0
&& state.session_keys[state.cur_session_key_idx]
.as_ref()
.map_or(true, |old| old.creation_counter < session_key.creation_counter)
{
drop(state);
let mut state = session.state.write().unwrap();
state.cur_session_key_idx = key_idx;
for i in 0..KEY_HISTORY_SIZE {
if i != key_idx {
if let Some(old_key) = state.session_keys[key_idx].as_ref() {
// Release pooled cipher memory from old keys.
old_key.receive_cipher_pool.lock().unwrap().clear();
old_key.send_cipher_pool.lock().unwrap().clear();
}
}
}
}
if packet_type == PACKET_TYPE_DATA { if aead_authentication_ok {
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); if session.receive_window[key_id as usize].message_authenticated(counter) {
} else { if packet_type == PACKET_TYPE_DATA {
unlikely_branch(); return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
return Ok(ReceiveResult::Ok); } else {
} unlikely_branch();
return Ok(ReceiveResult::Ok);
} }
} }
} }
@ -856,13 +843,13 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
let mut ratchet_key = None; let mut ratchet_key = None;
let mut last_ratchet_count = 0; let mut last_ratchet_count = 0;
let state = session.state.read().unwrap(); let state = session.state.read().unwrap();
for k in state.session_keys.iter() { if state.cur_session_key_id == key_id {
if let Some(k) = k.as_ref() { return Ok(ReceiveResult::Ignored); // alice is requesting to overwrite the current key, reject it
if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { }
ratchet_key = Some(k.ratchet_key.clone()); if let Some(k) = state.session_keys[state.cur_session_key_id as usize].as_ref() {
last_ratchet_count = k.ratchet_count; if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) {
break; ratchet_key = Some(k.ratchet_key.clone());
} last_ratchet_count = k.ratchet_count;
} }
} }
if ratchet_key.is_none() { if ratchet_key.is_none() {
@ -882,14 +869,14 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
id: new_session_id, id: new_session_id,
application_data: associated_object, application_data: associated_object,
send_counter: Counter::new(), send_counter: Counter::new(),
receive_window: CounterWindow::new(counter), receive_window: [CounterWindow::new_invalid(), CounterWindow::new_invalid()],
psk, psk,
noise_ss, noise_ss,
header_check_cipher, header_check_cipher,
state: RwLock::new(SessionMutableState { state: RwLock::new(SessionMutableState {
remote_session_id: Some(alice_session_id), remote_session_id: Some(alice_session_id),
session_keys: [None, None, None], session_keys: [None, None],
cur_session_key_idx: 0, cur_session_key_id: key_id,
offer: None, offer: None,
last_remote_offer: current_time, last_remote_offer: current_time,
}), }),
@ -959,7 +946,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
let mut reply_buf = [0_u8; KEX_BUF_LEN]; let mut reply_buf = [0_u8; KEX_BUF_LEN];
let reply_counter = session.send_counter.next(); let reply_counter = CounterValue::get_initial_offer_counter();
let mut idx = HEADER_SIZE; let mut idx = HEADER_SIZE;
idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?; idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?;
@ -991,6 +978,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
PACKET_TYPE_KEY_COUNTER_OFFER, PACKET_TYPE_KEY_COUNTER_OFFER,
alice_session_id.into(), alice_session_id.into(),
reply_counter, reply_counter,
key_id,
)?; )?;
let reply_canonical_header = let reply_canonical_header =
CanonicalHeader::make(alice_session_id.into(), PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter.to_u32()); CanonicalHeader::make(alice_session_id.into(), PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter.to_u32());
@ -1033,23 +1021,25 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
session_key, session_key,
Role::Bob, Role::Bob,
current_time, current_time,
reply_counter,
last_ratchet_count + 1, last_ratchet_count + 1,
hybrid_kk.is_some(), hybrid_kk.is_some(),
); );
let mut state = session.state.write().unwrap(); let mut state = session.state.write().unwrap();
let _ = state.remote_session_id.replace(alice_session_id); let _ = state.remote_session_id.replace(alice_session_id);
let next_key_ptr = (state.cur_session_key_idx + 1) % KEY_HISTORY_SIZE; let _ = state.session_keys[key_id as usize].replace(session_key);
let _ = state.session_keys[next_key_ptr].replace(session_key); session.send_counter.reset_after_initial_offer();
state.cur_session_key_id = key_id;
session.receive_window[key_id as usize].reset_after_initial_offer();
session.receive_window[key_id as usize].message_authenticated(counter);
drop(state); drop(state);
// Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it.
send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher); send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher);
if new_session.is_some() { if let Some(new_session) = new_session {
return Ok(ReceiveResult::OkNewSession(new_session.unwrap())); return Ok(ReceiveResult::OkNewSession(new_session));
} else { } else {
return Ok(ReceiveResult::Ok); return Ok(ReceiveResult::Ok);
} }
@ -1106,7 +1096,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?; parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?;
// Check that this is a counter offer to the original offer we sent. // Check that this is a counter offer to the original offer we sent.
if !offer.id.eq(offer_id) { if !(offer.id.eq(offer_id) & (offer.key_id == key_id)) {
return Ok(ReceiveResult::Ignored); return Ok(ReceiveResult::Ignored);
} }
@ -1151,47 +1141,21 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
// Alice has now completed and validated the full hybrid exchange. // Alice has now completed and validated the full hybrid exchange.
let reply_counter = session.send_counter.next();
let session_key = SessionKey::new( let session_key = SessionKey::new(
session_key, session_key,
Role::Alice, Role::Alice,
current_time, current_time,
reply_counter,
last_ratchet_count + 1, last_ratchet_count + 1,
hybrid_kk.is_some(), hybrid_kk.is_some(),
); );
session.receive_window.init_authenticated(counter);
////////////////////////////////////////////////////////////////
// packet encoding for post-noise session start ack
////////////////////////////////////////////////////////////////
let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
create_packet_header(
&mut reply_buf,
HEADER_SIZE + AES_GCM_TAG_SIZE,
mtu,
PACKET_TYPE_NOP,
bob_session_id.into(),
reply_counter,
)?;
let mut c = session_key.get_send_cipher(reply_counter)?;
c.reset_init_gcm(
CanonicalHeader::make(bob_session_id.into(), PACKET_TYPE_NOP, reply_counter.to_u32()).as_bytes(),
);
let gcm_tag = c.finish_encrypt();
safe_write_all(&mut reply_buf, HEADER_SIZE, &gcm_tag)?;
session_key.return_send_cipher(c);
set_header_check_code(&mut reply_buf, &session.header_check_cipher);
send(&mut reply_buf);
drop(state); drop(state);
let mut state = session.state.write().unwrap(); let mut state = session.state.write().unwrap();
let _ = state.remote_session_id.replace(bob_session_id); let _ = state.remote_session_id.replace(bob_session_id);
let next_key_idx = (state.cur_session_key_idx + 1) % KEY_HISTORY_SIZE; let _ = state.session_keys[key_id as usize].replace(session_key);
let _ = state.session_keys[next_key_idx].replace(session_key); session.send_counter.reset_after_initial_offer();
state.cur_session_key_id = key_id;
session.receive_window[key_id as usize].message_authenticated(counter);
let _ = state.offer.take(); let _ = state.offer.take();
return Ok(ReceiveResult::Ok); return Ok(ReceiveResult::Ok);
@ -1213,6 +1177,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>( fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
send: &mut SendFunction, send: &mut SendFunction,
counter: CounterValue, counter: CounterValue,
key_id: bool,
alice_session_id: SessionId, alice_session_id: SessionId,
bob_session_id: Option<SessionId>, bob_session_id: Option<SessionId>,
alice_s_public_blob: &[u8], alice_s_public_blob: &[u8],
@ -1298,6 +1263,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
PACKET_TYPE_INITIAL_KEY_OFFER, PACKET_TYPE_INITIAL_KEY_OFFER,
bob_session_id, bob_session_id,
counter, counter,
key_id,
)?; )?;
let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32()); let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32());
@ -1351,6 +1317,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
*ret_ephemeral_offer = Some(EphemeralOffer { *ret_ephemeral_offer = Some(EphemeralOffer {
id, id,
key_id,
creation_time: current_time, creation_time: current_time,
ratchet_count, ratchet_count,
ratchet_key, ratchet_key,
@ -1370,6 +1337,7 @@ fn create_packet_header(
packet_type: u8, packet_type: u8,
recipient_session_id: SessionId, recipient_session_id: SessionId,
counter: CounterValue, counter: CounterValue,
key_id: bool
) -> Result<(), Error> { ) -> Result<(), Error> {
let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize; let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize;
@ -1379,6 +1347,7 @@ fn create_packet_header(
debug_assert!(fragment_count > 0); debug_assert!(fragment_count > 0);
debug_assert!(fragment_count <= MAX_FRAGMENTS); debug_assert!(fragment_count <= MAX_FRAGMENTS);
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
let counter = counter.to_u32().wrapping_shl(1) | (key_id as u32);
if fragment_count <= MAX_FRAGMENTS { if fragment_count <= MAX_FRAGMENTS {
// Header indexed by bit/byte: // Header indexed by bit/byte:
@ -1388,7 +1357,7 @@ fn create_packet_header(
// [112-115]/[14-] packet type (0-15) // [112-115]/[14-] packet type (0-15)
// [116-121]/[-] number of fragments (0..63 for 1..64 fragments total) // [116-121]/[-] number of fragments (0..63 for 1..64 fragments total)
// [122-127]/[-15] fragment number (0, 1, 2, ...) // [122-127]/[-15] fragment number (0, 1, 2, ...)
memory::store_raw((counter.to_u32() as u64).to_le(), header_destination_buffer); memory::store_raw((counter as u64).to_le(), header_destination_buffer);
memory::store_raw( memory::store_raw(
(u64::from(recipient_session_id) | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)) (u64::from(recipient_session_id) | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52))
.to_le(), .to_le(),
@ -1499,7 +1468,6 @@ impl SessionKey {
key: Secret<64>, key: Secret<64>,
role: Role, role: Role,
current_time: i64, current_time: i64,
current_counter: CounterValue,
ratchet_count: u64, ratchet_count: u64,
jedi: bool, jedi: bool,
) -> Self { ) -> Self {
@ -1512,8 +1480,7 @@ impl SessionKey {
Self { Self {
secret_fingerprint: public_fingerprint_of_secret(key.as_bytes())[..16].try_into().unwrap(), secret_fingerprint: public_fingerprint_of_secret(key.as_bytes())[..16].try_into().unwrap(),
creation_time: current_time, creation_time: current_time,
creation_counter: current_counter, lifetime: KeyLifetime::new(current_time),
lifetime: KeyLifetime::new(current_counter, current_time),
ratchet_key: kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_RATCHETING), ratchet_key: kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_RATCHETING),
receive_key, receive_key,
send_key, send_key,
@ -1560,12 +1527,9 @@ impl SessionKey {
} }
impl KeyLifetime { impl KeyLifetime {
fn new(current_counter: CounterValue, current_time: i64) -> Self { fn new(current_time: i64) -> Self {
Self { Self {
rekey_at_or_after_counter: current_counter rekey_at_or_after_counter: REKEY_AFTER_USES as u32 + (random::next_u32_secure() % REKEY_AFTER_USES_MAX_JITTER),
.counter_value_after_uses(REKEY_AFTER_USES)
.counter_value_after_uses((random::next_u32_secure() % REKEY_AFTER_USES_MAX_JITTER) as u64),
hard_expire_at_counter: current_counter.counter_value_after_uses(EXPIRE_AFTER_USES),
rekey_at_or_after_timestamp: current_time rekey_at_or_after_timestamp: current_time
+ REKEY_AFTER_TIME_MS + REKEY_AFTER_TIME_MS
+ (random::next_u32_secure() % REKEY_AFTER_TIME_MS_MAX_JITTER) as i64, + (random::next_u32_secure() % REKEY_AFTER_TIME_MS_MAX_JITTER) as i64,
@ -1573,11 +1537,11 @@ impl KeyLifetime {
} }
fn should_rekey(&self, counter: CounterValue, current_time: i64) -> bool { fn should_rekey(&self, counter: CounterValue, current_time: i64) -> bool {
counter >= self.rekey_at_or_after_counter || current_time >= self.rekey_at_or_after_timestamp counter.to_u32() >= self.rekey_at_or_after_counter || current_time >= self.rekey_at_or_after_timestamp
} }
fn expired(&self, counter: CounterValue) -> bool { fn expired(&self, counter: CounterValue) -> bool {
counter >= self.hard_expire_at_counter counter.to_u32() >= EXPIRE_AFTER_USES as u32
} }
} }