mirror of
https://github.com/zerotier/ZeroTierOne.git
synced 2025-06-07 21:13:44 +02:00
finished implementation of counter starting at 1
This commit is contained in:
parent
402cf69b72
commit
52556d0d89
4 changed files with 138 additions and 193 deletions
|
@ -74,9 +74,6 @@ pub(crate) const HMAC_SIZE: usize = 48;
|
|||
/// This is large since some ZeroTier nodes handle huge numbers of links, like roots and controllers.
|
||||
pub(crate) const SESSION_ID_SIZE: usize = 6;
|
||||
|
||||
/// Number of session keys to hold at a given time (current, previous, next).
|
||||
pub(crate) const KEY_HISTORY_SIZE: usize = 3;
|
||||
|
||||
/// Maximum difference between out-of-order incoming packet counters, and size of deduplication buffer.
|
||||
pub(crate) const COUNTER_MAX_ALLOWED_OOO: usize = 16;
|
||||
|
||||
|
|
|
@ -1,8 +1,4 @@
|
|||
use std::{sync::{
|
||||
atomic::{AtomicU64, Ordering, AtomicU32, AtomicI32, AtomicBool}
|
||||
}, mem};
|
||||
|
||||
use zerotier_crypto::random;
|
||||
use std::sync::atomic::{Ordering, AtomicU32};
|
||||
|
||||
use crate::constants::COUNTER_MAX_ALLOWED_OOO;
|
||||
|
||||
|
@ -12,20 +8,24 @@ use crate::constants::COUNTER_MAX_ALLOWED_OOO;
|
|||
/// lets us more safely implement key lifetime limits without confusing logic to handle 32-bit
|
||||
/// wrap-around.
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct Counter(AtomicU64);
|
||||
pub(crate) struct Counter(AtomicU32);
|
||||
|
||||
impl Counter {
|
||||
#[inline(always)]
|
||||
pub fn new() -> Self {
|
||||
// Using a random value has no security implication. Zero would be fine. This just
|
||||
// helps randomize packet contents a bit.
|
||||
Self(AtomicU64::new((random::next_u32_secure()/2) as u64))
|
||||
Self(AtomicU32::new(1u32))
|
||||
}
|
||||
#[inline(always)]
|
||||
pub fn reset_after_initial_offer(&self) {
|
||||
self.0.store(2u32, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Get the value most recently used to send a packet.
|
||||
#[inline(always)]
|
||||
pub fn previous(&self) -> CounterValue {
|
||||
CounterValue(self.0.load(Ordering::SeqCst).wrapping_sub(1))
|
||||
pub fn current(&self) -> CounterValue {
|
||||
CounterValue(self.0.load(Ordering::SeqCst))
|
||||
}
|
||||
|
||||
/// Get a counter value for the next packet being sent.
|
||||
|
@ -38,7 +38,7 @@ impl Counter {
|
|||
/// A value of the outgoing packet counter.
|
||||
#[repr(transparent)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub(crate) struct CounterValue(u64);
|
||||
pub(crate) struct CounterValue(u32);
|
||||
|
||||
impl CounterValue {
|
||||
/// Get the 32-bit counter value used to build packets.
|
||||
|
@ -46,51 +46,35 @@ impl CounterValue {
|
|||
pub fn to_u32(&self) -> u32 {
|
||||
self.0 as u32
|
||||
}
|
||||
|
||||
/// Get the counter value after N more uses of the parent counter.
|
||||
///
|
||||
/// This checks for u64 overflow for the sake of correctness. Be careful if using ZSSP in a
|
||||
/// generational starship where sessions may last for millions of years.
|
||||
#[inline(always)]
|
||||
pub fn counter_value_after_uses(&self, uses: u64) -> Self {
|
||||
Self(self.0.checked_add(uses).unwrap())
|
||||
pub fn get_initial_offer_counter() -> CounterValue {
|
||||
return CounterValue(1u32);
|
||||
}
|
||||
}
|
||||
|
||||
/// Incoming packet deduplication and replay protection window.
|
||||
pub(crate) struct CounterWindow(AtomicBool, AtomicBool, [AtomicU32; COUNTER_MAX_ALLOWED_OOO]);
|
||||
pub(crate) struct CounterWindow([AtomicU32; COUNTER_MAX_ALLOWED_OOO]);
|
||||
|
||||
impl CounterWindow {
|
||||
#[inline(always)]
|
||||
pub fn new(initial: u32) -> Self {
|
||||
Self(AtomicBool::new(true), AtomicBool::new(false), std::array::from_fn(|_| AtomicU32::new(initial)))
|
||||
pub fn new() -> Self {
|
||||
Self(std::array::from_fn(|_| AtomicU32::new(0)))
|
||||
}
|
||||
#[inline(always)]
|
||||
pub fn new_uninit() -> Self {
|
||||
Self(AtomicBool::new(false), AtomicBool::new(false), std::array::from_fn(|_| AtomicU32::new(0)))
|
||||
///this creates a counter window that rejects everything
|
||||
pub fn new_invalid() -> Self {
|
||||
Self(std::array::from_fn(|_| AtomicU32::new(u32::MAX)))
|
||||
}
|
||||
#[inline(always)]
|
||||
pub fn init_authenticated(&self, received_counter_value: u32) {
|
||||
self.1.store((u32::MAX/4 < received_counter_value) & (received_counter_value <= u32::MAX/4*3), Ordering::SeqCst);
|
||||
for i in 1..COUNTER_MAX_ALLOWED_OOO {
|
||||
self.2[i].store(received_counter_value, Ordering::SeqCst);
|
||||
pub fn reset_after_initial_offer(&self) {
|
||||
for i in 0..COUNTER_MAX_ALLOWED_OOO {
|
||||
self.0[i].store(0, Ordering::SeqCst)
|
||||
}
|
||||
self.0.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn message_received(&self, received_counter_value: u32) -> bool {
|
||||
if self.0.load(Ordering::SeqCst) {
|
||||
let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize;
|
||||
let pre = self.2[idx].load(Ordering::SeqCst);
|
||||
if self.1.load(Ordering::Relaxed) {
|
||||
return pre < received_counter_value;
|
||||
} else {
|
||||
return (pre as i32) < (received_counter_value as i32);
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize;
|
||||
//it is highly likely this can be a Relaxed ordering, but I want someone else to confirm that is the case
|
||||
let pre = self.0[idx].load(Ordering::SeqCst);
|
||||
return pre < received_counter_value;
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
|
@ -98,11 +82,6 @@ impl CounterWindow {
|
|||
//if a valid message is received but one of its fragments was lost, it can technically be replayed. However since the message is incomplete, we know it still exists in the gather array, so the gather array will deduplicate the replayed message. Even if the gather array gets flushed, that flush still effectively deduplicates the replayed message.
|
||||
//eventually the counter of that kind of message will be too OOO to be accepted anymore so it can't be used to DOS.
|
||||
let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize;
|
||||
if self.1.swap((u32::MAX/4 < received_counter_value) & (received_counter_value <= u32::MAX/4*3), Ordering::SeqCst) {
|
||||
return self.2[idx].fetch_max(received_counter_value, Ordering::SeqCst) < received_counter_value;
|
||||
} else {
|
||||
let pre_as_signed: &AtomicI32 = unsafe {mem::transmute(&self.2[idx])};
|
||||
return pre_as_signed.fetch_max(received_counter_value as i32, Ordering::SeqCst) < received_counter_value as i32;
|
||||
}
|
||||
return self.0[idx].fetch_max(received_counter_value, Ordering::SeqCst) < received_counter_value;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -232,7 +232,7 @@ mod tests {
|
|||
let mut counter = 1u32;
|
||||
let mut history = Vec::new();
|
||||
|
||||
let mut w = CounterWindow::new(counter);
|
||||
let mut w = CounterWindow::new();
|
||||
for i in 0..1000000 {
|
||||
let p = xorshift64(&mut rng) as f32/(u32::MAX as f32 + 1.0);
|
||||
let c;
|
||||
|
@ -249,10 +249,15 @@ mod tests {
|
|||
assert!(!w.message_authenticated(c));
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
} else if p < 0.999 {
|
||||
c = xorshift64(&mut rng);
|
||||
w.message_received(c);
|
||||
continue;
|
||||
} else {
|
||||
w.reset_after_initial_offer();
|
||||
counter = 1u32;
|
||||
history = Vec::new();
|
||||
continue;
|
||||
}
|
||||
if history.contains(&c) {
|
||||
assert!(!w.message_authenticated(c));
|
||||
|
|
246
zssp/src/zssp.rs
246
zssp/src/zssp.rs
|
@ -113,7 +113,7 @@ pub struct Session<Application: ApplicationLayer> {
|
|||
pub application_data: Application::Data,
|
||||
|
||||
send_counter: Counter, // Outgoing packet counter and nonce state
|
||||
receive_window: CounterWindow, // Receive window for anti-replay and deduplication
|
||||
receive_window: [CounterWindow; 2], // Receive window for anti-replay and deduplication
|
||||
psk: Secret<64>, // Arbitrary PSK provided by external code
|
||||
noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key
|
||||
header_check_cipher: Aes, // Cipher used for header check codes (not Noise related)
|
||||
|
@ -126,17 +126,17 @@ pub struct Session<Application: ApplicationLayer> {
|
|||
|
||||
struct SessionMutableState {
|
||||
remote_session_id: Option<SessionId>, // The other side's 48-bit session ID
|
||||
session_keys: [Option<SessionKey>; KEY_HISTORY_SIZE], // Buffers to store current, next, and last active key
|
||||
cur_session_key_idx: usize, // Pointer used for keys[] circular buffer
|
||||
session_keys: [Option<SessionKey>; 2], // Buffers to store current, next, and last active key
|
||||
cur_session_key_id: bool, // Pointer used for keys[] circular buffer
|
||||
offer: Option<EphemeralOffer>, // Most recent ephemeral offer sent to remote
|
||||
last_remote_offer: i64, // Time of most recent ephemeral offer (ms)
|
||||
}
|
||||
|
||||
/// A shared symmetric session key.
|
||||
/// sessions always start at counter 1u32
|
||||
struct SessionKey {
|
||||
secret_fingerprint: [u8; 16], // First 128 bits of a SHA384 computed from the secret
|
||||
creation_time: i64, // Time session key was established
|
||||
creation_counter: CounterValue, // Counter value at which session was established
|
||||
lifetime: KeyLifetime, // Key expiration time and counter
|
||||
ratchet_key: Secret<64>, // Ratchet key for deriving the next session key
|
||||
receive_key: Secret<AES_KEY_SIZE>, // Receive side AES-GCM key
|
||||
|
@ -149,14 +149,14 @@ struct SessionKey {
|
|||
|
||||
/// Key lifetime state
|
||||
struct KeyLifetime {
|
||||
rekey_at_or_after_counter: CounterValue,
|
||||
hard_expire_at_counter: CounterValue,
|
||||
rekey_at_or_after_counter: u32,
|
||||
rekey_at_or_after_timestamp: i64,
|
||||
}
|
||||
|
||||
/// Alice's KEY_OFFER, remembered so Noise agreement process can resume on KEY_COUNTER_OFFER.
|
||||
struct EphemeralOffer {
|
||||
id: [u8; 16], // Arbitrary random offer ID
|
||||
key_id: bool, // The key_id bound to this offer, for handling OOO rekeying
|
||||
creation_time: i64, // Local time when offer was created
|
||||
ratchet_count: u64, // Ratchet count (starting at zero) for initial offer
|
||||
ratchet_key: Option<Secret<64>>, // Ratchet key from previous offer or None if first offer
|
||||
|
@ -235,6 +235,7 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
if send_ephemeral_offer(
|
||||
&mut send,
|
||||
send_counter.next(),
|
||||
false,
|
||||
local_session_id,
|
||||
None,
|
||||
app.get_local_s_public_blob(),
|
||||
|
@ -254,14 +255,14 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
id: local_session_id,
|
||||
application_data,
|
||||
send_counter,
|
||||
receive_window: CounterWindow::new_uninit(),//alice does not know bob's counter yet
|
||||
receive_window: [CounterWindow::new(), CounterWindow::new_invalid()],
|
||||
psk: psk.clone(),
|
||||
noise_ss,
|
||||
header_check_cipher,
|
||||
state: RwLock::new(SessionMutableState {
|
||||
remote_session_id: None,
|
||||
session_keys: [None, None, None],
|
||||
cur_session_key_idx: 0,
|
||||
session_keys: [None, None],
|
||||
cur_session_key_id: false,
|
||||
offer,
|
||||
last_remote_offer: i64::MIN,
|
||||
}),
|
||||
|
@ -290,7 +291,7 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU);
|
||||
let state = self.state.read().unwrap();
|
||||
if let Some(remote_session_id) = state.remote_session_id {
|
||||
if let Some(session_key) = state.session_keys[state.cur_session_key_idx].as_ref() {
|
||||
if let Some(session_key) = state.session_keys[state.cur_session_key_id as usize].as_ref() {
|
||||
// Total size of the armored packet we are going to send (may end up being fragmented)
|
||||
let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
|
||||
|
||||
|
@ -309,6 +310,7 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
PACKET_TYPE_DATA,
|
||||
remote_session_id.into(),
|
||||
counter,
|
||||
state.cur_session_key_id
|
||||
)?;
|
||||
|
||||
// Get an initialized AES-GCM cipher and re-initialize with a 96-bit IV built from remote session ID,
|
||||
|
@ -367,7 +369,7 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
/// Check whether this session is established.
|
||||
pub fn established(&self) -> bool {
|
||||
let state = self.state.read().unwrap();
|
||||
state.remote_session_id.is_some() && state.session_keys[state.cur_session_key_idx].is_some()
|
||||
state.remote_session_id.is_some() && state.session_keys[state.cur_session_key_id as usize].is_some()
|
||||
}
|
||||
|
||||
/// Get information about this session's security state.
|
||||
|
@ -376,7 +378,7 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
/// and whether Kyber1024 was used. None is returned if the session isn't established.
|
||||
pub fn status(&self) -> Option<([u8; 16], i64, u64, bool)> {
|
||||
let state = self.state.read().unwrap();
|
||||
if let Some(key) = state.session_keys[state.cur_session_key_idx].as_ref() {
|
||||
if let Some(key) = state.session_keys[state.cur_session_key_id as usize].as_ref() {
|
||||
Some((key.secret_fingerprint, key.creation_time, key.ratchet_count, key.jedi))
|
||||
} else {
|
||||
None
|
||||
|
@ -402,9 +404,9 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
) {
|
||||
let state = self.state.read().unwrap();
|
||||
if (force_rekey
|
||||
|| state.session_keys[state.cur_session_key_idx]
|
||||
|| state.session_keys[state.cur_session_key_id as usize]
|
||||
.as_ref()
|
||||
.map_or(true, |key| key.lifetime.should_rekey(self.send_counter.previous(), current_time)))
|
||||
.map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time)))
|
||||
&& state
|
||||
.offer
|
||||
.as_ref()
|
||||
|
@ -414,7 +416,8 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
let mut offer = None;
|
||||
if send_ephemeral_offer(
|
||||
&mut send,
|
||||
self.send_counter.next(),
|
||||
CounterValue::get_initial_offer_counter(),
|
||||
!state.cur_session_key_id,
|
||||
self.id,
|
||||
state.remote_session_id,
|
||||
app.get_local_s_public_blob(),
|
||||
|
@ -422,7 +425,7 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
&remote_s_public,
|
||||
&self.remote_s_public_blob_hash,
|
||||
&self.noise_ss,
|
||||
state.session_keys[state.cur_session_key_idx].as_ref(),
|
||||
state.session_keys[state.cur_session_key_id as usize].as_ref(),
|
||||
if state.remote_session_id.is_some() {
|
||||
Some(&self.header_check_cipher)
|
||||
} else {
|
||||
|
@ -435,7 +438,7 @@ impl<Application: ApplicationLayer> Session<Application> {
|
|||
.is_ok()
|
||||
{
|
||||
drop(state);
|
||||
let _ = self.state.write().unwrap().offer.replace(offer.unwrap());
|
||||
self.state.write().unwrap().offer = offer;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -477,7 +480,9 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
return Err(Error::InvalidPacket);
|
||||
}
|
||||
|
||||
let counter = u32::from_le(memory::load_raw(incoming_packet));
|
||||
let mut counter = u32::from_le(memory::load_raw(incoming_packet));
|
||||
let key_id = (counter & 1) > 0;
|
||||
counter = counter.wrapping_shr(1);
|
||||
let packet_type_fragment_info = u16::from_le(memory::load_raw(&incoming_packet[14..16]));
|
||||
let packet_type = (packet_type_fragment_info & 0x0f) as u8;
|
||||
let fragment_count = ((packet_type_fragment_info.wrapping_shr(4) + 1) as u8) & 63;
|
||||
|
@ -487,7 +492,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
{
|
||||
if let Some(session) = app.lookup_session(local_session_id) {
|
||||
if verify_header_check_code(incoming_packet, &session.header_check_cipher) {
|
||||
if session.receive_window.message_received(counter) {
|
||||
if session.receive_window[key_id as usize].message_received(counter) {
|
||||
let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter);
|
||||
if fragment_count > 1 {
|
||||
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
|
||||
|
@ -501,6 +506,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
&mut send,
|
||||
data_buf,
|
||||
counter,
|
||||
key_id,
|
||||
canonical_header.as_bytes(),
|
||||
assembled_packet.as_ref(),
|
||||
packet_type,
|
||||
|
@ -520,6 +526,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
&mut send,
|
||||
data_buf,
|
||||
counter,
|
||||
key_id,
|
||||
canonical_header.as_bytes(),
|
||||
&[incoming_packet_buf],
|
||||
packet_type,
|
||||
|
@ -556,6 +563,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
&mut send,
|
||||
data_buf,
|
||||
counter,
|
||||
key_id,
|
||||
canonical_header.as_bytes(),
|
||||
assembled_packet.as_ref(),
|
||||
packet_type,
|
||||
|
@ -571,6 +579,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
&mut send,
|
||||
data_buf,
|
||||
counter,
|
||||
key_id,
|
||||
canonical_header.as_bytes(),
|
||||
&[incoming_packet_buf],
|
||||
packet_type,
|
||||
|
@ -599,6 +608,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
send: &mut SendFunction,
|
||||
data_buf: &'a mut [u8],
|
||||
counter: u32,
|
||||
key_id: bool,
|
||||
canonical_header_bytes: &[u8; 12],
|
||||
fragments: &[Application::IncomingPacketBuffer],
|
||||
packet_type: u8,
|
||||
|
@ -615,82 +625,59 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
if packet_type <= PACKET_TYPE_NOP {
|
||||
if let Some(session) = session {
|
||||
let state = session.state.read().unwrap();
|
||||
for p in 0..KEY_HISTORY_SIZE {
|
||||
let key_idx = (state.cur_session_key_idx + p) % KEY_HISTORY_SIZE;
|
||||
if let Some(session_key) = state.session_keys[key_idx].as_ref() {
|
||||
let mut c = session_key.get_receive_cipher();
|
||||
c.reset_init_gcm(canonical_header_bytes);
|
||||
////////////////////////////////////////////////////////////////
|
||||
// packet decoding for post-noise transport
|
||||
////////////////////////////////////////////////////////////////
|
||||
if let Some(session_key) = state.session_keys[key_id as usize].as_ref() {
|
||||
let mut c = session_key.get_receive_cipher();
|
||||
c.reset_init_gcm(canonical_header_bytes);
|
||||
////////////////////////////////////////////////////////////////
|
||||
// packet decoding for post-noise transport
|
||||
////////////////////////////////////////////////////////////////
|
||||
|
||||
let mut data_len = 0;
|
||||
let mut data_len = 0;
|
||||
|
||||
// Decrypt fragments 0..N-1 where N is the number of fragments.
|
||||
for f in fragments[..(fragments.len() - 1)].iter() {
|
||||
let f = f.as_ref();
|
||||
debug_assert!(f.len() >= HEADER_SIZE);
|
||||
let current_frag_data_start = data_len;
|
||||
data_len += f.len() - HEADER_SIZE;
|
||||
if data_len > data_buf.len() {
|
||||
unlikely_branch();
|
||||
session_key.return_receive_cipher(c);
|
||||
return Err(Error::DataBufferTooSmall);
|
||||
}
|
||||
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
|
||||
}
|
||||
|
||||
// Decrypt final fragment (or only fragment if not fragmented)
|
||||
// Decrypt fragments 0..N-1 where N is the number of fragments.
|
||||
for f in fragments[..(fragments.len() - 1)].iter() {
|
||||
let f = f.as_ref();
|
||||
debug_assert!(f.len() >= HEADER_SIZE);
|
||||
let current_frag_data_start = data_len;
|
||||
let last_fragment = fragments.last().unwrap().as_ref();
|
||||
if last_fragment.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
|
||||
unlikely_branch();
|
||||
return Err(Error::InvalidPacket);
|
||||
}
|
||||
data_len += last_fragment.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
|
||||
data_len += f.len() - HEADER_SIZE;
|
||||
if data_len > data_buf.len() {
|
||||
unlikely_branch();
|
||||
session_key.return_receive_cipher(c);
|
||||
return Err(Error::DataBufferTooSmall);
|
||||
}
|
||||
let payload_end = last_fragment.len() - AES_GCM_TAG_SIZE;
|
||||
c.crypt(
|
||||
&last_fragment[HEADER_SIZE..payload_end],
|
||||
&mut data_buf[current_frag_data_start..data_len],
|
||||
);
|
||||
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
|
||||
}
|
||||
|
||||
let gcm_tag = &last_fragment[payload_end..];
|
||||
let aead_authentication_ok = c.finish_decrypt(gcm_tag);
|
||||
// Decrypt final fragment (or only fragment if not fragmented)
|
||||
let current_frag_data_start = data_len;
|
||||
let last_fragment = fragments.last().unwrap().as_ref();
|
||||
if last_fragment.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
|
||||
unlikely_branch();
|
||||
return Err(Error::InvalidPacket);
|
||||
}
|
||||
data_len += last_fragment.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
|
||||
if data_len > data_buf.len() {
|
||||
unlikely_branch();
|
||||
session_key.return_receive_cipher(c);
|
||||
return Err(Error::DataBufferTooSmall);
|
||||
}
|
||||
let payload_end = last_fragment.len() - AES_GCM_TAG_SIZE;
|
||||
c.crypt(
|
||||
&last_fragment[HEADER_SIZE..payload_end],
|
||||
&mut data_buf[current_frag_data_start..data_len],
|
||||
);
|
||||
|
||||
if aead_authentication_ok {
|
||||
if session.receive_window.message_authenticated(counter) {
|
||||
// Select this key as the new default if it's newer than the current key.
|
||||
if p > 0
|
||||
&& state.session_keys[state.cur_session_key_idx]
|
||||
.as_ref()
|
||||
.map_or(true, |old| old.creation_counter < session_key.creation_counter)
|
||||
{
|
||||
drop(state);
|
||||
let mut state = session.state.write().unwrap();
|
||||
state.cur_session_key_idx = key_idx;
|
||||
for i in 0..KEY_HISTORY_SIZE {
|
||||
if i != key_idx {
|
||||
if let Some(old_key) = state.session_keys[key_idx].as_ref() {
|
||||
// Release pooled cipher memory from old keys.
|
||||
old_key.receive_cipher_pool.lock().unwrap().clear();
|
||||
old_key.send_cipher_pool.lock().unwrap().clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let gcm_tag = &last_fragment[payload_end..];
|
||||
let aead_authentication_ok = c.finish_decrypt(gcm_tag);
|
||||
session_key.return_receive_cipher(c);
|
||||
|
||||
if packet_type == PACKET_TYPE_DATA {
|
||||
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
|
||||
} else {
|
||||
unlikely_branch();
|
||||
return Ok(ReceiveResult::Ok);
|
||||
}
|
||||
if aead_authentication_ok {
|
||||
if session.receive_window[key_id as usize].message_authenticated(counter) {
|
||||
if packet_type == PACKET_TYPE_DATA {
|
||||
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
|
||||
} else {
|
||||
unlikely_branch();
|
||||
return Ok(ReceiveResult::Ok);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -856,13 +843,13 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
let mut ratchet_key = None;
|
||||
let mut last_ratchet_count = 0;
|
||||
let state = session.state.read().unwrap();
|
||||
for k in state.session_keys.iter() {
|
||||
if let Some(k) = k.as_ref() {
|
||||
if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) {
|
||||
ratchet_key = Some(k.ratchet_key.clone());
|
||||
last_ratchet_count = k.ratchet_count;
|
||||
break;
|
||||
}
|
||||
if state.cur_session_key_id == key_id {
|
||||
return Ok(ReceiveResult::Ignored); // alice is requesting to overwrite the current key, reject it
|
||||
}
|
||||
if let Some(k) = state.session_keys[state.cur_session_key_id as usize].as_ref() {
|
||||
if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) {
|
||||
ratchet_key = Some(k.ratchet_key.clone());
|
||||
last_ratchet_count = k.ratchet_count;
|
||||
}
|
||||
}
|
||||
if ratchet_key.is_none() {
|
||||
|
@ -882,14 +869,14 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
id: new_session_id,
|
||||
application_data: associated_object,
|
||||
send_counter: Counter::new(),
|
||||
receive_window: CounterWindow::new(counter),
|
||||
receive_window: [CounterWindow::new_invalid(), CounterWindow::new_invalid()],
|
||||
psk,
|
||||
noise_ss,
|
||||
header_check_cipher,
|
||||
state: RwLock::new(SessionMutableState {
|
||||
remote_session_id: Some(alice_session_id),
|
||||
session_keys: [None, None, None],
|
||||
cur_session_key_idx: 0,
|
||||
session_keys: [None, None],
|
||||
cur_session_key_id: key_id,
|
||||
offer: None,
|
||||
last_remote_offer: current_time,
|
||||
}),
|
||||
|
@ -959,7 +946,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
////////////////////////////////////////////////////////////////
|
||||
|
||||
let mut reply_buf = [0_u8; KEX_BUF_LEN];
|
||||
let reply_counter = session.send_counter.next();
|
||||
let reply_counter = CounterValue::get_initial_offer_counter();
|
||||
let mut idx = HEADER_SIZE;
|
||||
|
||||
idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?;
|
||||
|
@ -991,6 +978,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
PACKET_TYPE_KEY_COUNTER_OFFER,
|
||||
alice_session_id.into(),
|
||||
reply_counter,
|
||||
key_id,
|
||||
)?;
|
||||
let reply_canonical_header =
|
||||
CanonicalHeader::make(alice_session_id.into(), PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter.to_u32());
|
||||
|
@ -1033,23 +1021,25 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
session_key,
|
||||
Role::Bob,
|
||||
current_time,
|
||||
reply_counter,
|
||||
last_ratchet_count + 1,
|
||||
hybrid_kk.is_some(),
|
||||
);
|
||||
|
||||
let mut state = session.state.write().unwrap();
|
||||
let _ = state.remote_session_id.replace(alice_session_id);
|
||||
let next_key_ptr = (state.cur_session_key_idx + 1) % KEY_HISTORY_SIZE;
|
||||
let _ = state.session_keys[next_key_ptr].replace(session_key);
|
||||
let _ = state.session_keys[key_id as usize].replace(session_key);
|
||||
session.send_counter.reset_after_initial_offer();
|
||||
state.cur_session_key_id = key_id;
|
||||
session.receive_window[key_id as usize].reset_after_initial_offer();
|
||||
session.receive_window[key_id as usize].message_authenticated(counter);
|
||||
drop(state);
|
||||
|
||||
// Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it.
|
||||
|
||||
send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher);
|
||||
|
||||
if new_session.is_some() {
|
||||
return Ok(ReceiveResult::OkNewSession(new_session.unwrap()));
|
||||
if let Some(new_session) = new_session {
|
||||
return Ok(ReceiveResult::OkNewSession(new_session));
|
||||
} else {
|
||||
return Ok(ReceiveResult::Ok);
|
||||
}
|
||||
|
@ -1106,7 +1096,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?;
|
||||
|
||||
// Check that this is a counter offer to the original offer we sent.
|
||||
if !offer.id.eq(offer_id) {
|
||||
if !(offer.id.eq(offer_id) & (offer.key_id == key_id)) {
|
||||
return Ok(ReceiveResult::Ignored);
|
||||
}
|
||||
|
||||
|
@ -1151,47 +1141,21 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
|
||||
// Alice has now completed and validated the full hybrid exchange.
|
||||
|
||||
let reply_counter = session.send_counter.next();
|
||||
let session_key = SessionKey::new(
|
||||
session_key,
|
||||
Role::Alice,
|
||||
current_time,
|
||||
reply_counter,
|
||||
last_ratchet_count + 1,
|
||||
hybrid_kk.is_some(),
|
||||
);
|
||||
session.receive_window.init_authenticated(counter);
|
||||
|
||||
////////////////////////////////////////////////////////////////
|
||||
// packet encoding for post-noise session start ack
|
||||
////////////////////////////////////////////////////////////////
|
||||
|
||||
let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
|
||||
create_packet_header(
|
||||
&mut reply_buf,
|
||||
HEADER_SIZE + AES_GCM_TAG_SIZE,
|
||||
mtu,
|
||||
PACKET_TYPE_NOP,
|
||||
bob_session_id.into(),
|
||||
reply_counter,
|
||||
)?;
|
||||
|
||||
let mut c = session_key.get_send_cipher(reply_counter)?;
|
||||
c.reset_init_gcm(
|
||||
CanonicalHeader::make(bob_session_id.into(), PACKET_TYPE_NOP, reply_counter.to_u32()).as_bytes(),
|
||||
);
|
||||
let gcm_tag = c.finish_encrypt();
|
||||
safe_write_all(&mut reply_buf, HEADER_SIZE, &gcm_tag)?;
|
||||
session_key.return_send_cipher(c);
|
||||
|
||||
set_header_check_code(&mut reply_buf, &session.header_check_cipher);
|
||||
send(&mut reply_buf);
|
||||
|
||||
drop(state);
|
||||
let mut state = session.state.write().unwrap();
|
||||
let _ = state.remote_session_id.replace(bob_session_id);
|
||||
let next_key_idx = (state.cur_session_key_idx + 1) % KEY_HISTORY_SIZE;
|
||||
let _ = state.session_keys[next_key_idx].replace(session_key);
|
||||
let _ = state.session_keys[key_id as usize].replace(session_key);
|
||||
session.send_counter.reset_after_initial_offer();
|
||||
state.cur_session_key_id = key_id;
|
||||
session.receive_window[key_id as usize].message_authenticated(counter);
|
||||
let _ = state.offer.take();
|
||||
|
||||
return Ok(ReceiveResult::Ok);
|
||||
|
@ -1213,6 +1177,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
|
|||
fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
|
||||
send: &mut SendFunction,
|
||||
counter: CounterValue,
|
||||
key_id: bool,
|
||||
alice_session_id: SessionId,
|
||||
bob_session_id: Option<SessionId>,
|
||||
alice_s_public_blob: &[u8],
|
||||
|
@ -1298,6 +1263,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
|
|||
PACKET_TYPE_INITIAL_KEY_OFFER,
|
||||
bob_session_id,
|
||||
counter,
|
||||
key_id,
|
||||
)?;
|
||||
|
||||
let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32());
|
||||
|
@ -1351,6 +1317,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
|
|||
|
||||
*ret_ephemeral_offer = Some(EphemeralOffer {
|
||||
id,
|
||||
key_id,
|
||||
creation_time: current_time,
|
||||
ratchet_count,
|
||||
ratchet_key,
|
||||
|
@ -1370,6 +1337,7 @@ fn create_packet_header(
|
|||
packet_type: u8,
|
||||
recipient_session_id: SessionId,
|
||||
counter: CounterValue,
|
||||
key_id: bool
|
||||
) -> Result<(), Error> {
|
||||
let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize;
|
||||
|
||||
|
@ -1379,6 +1347,7 @@ fn create_packet_header(
|
|||
debug_assert!(fragment_count > 0);
|
||||
debug_assert!(fragment_count <= MAX_FRAGMENTS);
|
||||
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
|
||||
let counter = counter.to_u32().wrapping_shl(1) | (key_id as u32);
|
||||
|
||||
if fragment_count <= MAX_FRAGMENTS {
|
||||
// Header indexed by bit/byte:
|
||||
|
@ -1388,7 +1357,7 @@ fn create_packet_header(
|
|||
// [112-115]/[14-] packet type (0-15)
|
||||
// [116-121]/[-] number of fragments (0..63 for 1..64 fragments total)
|
||||
// [122-127]/[-15] fragment number (0, 1, 2, ...)
|
||||
memory::store_raw((counter.to_u32() as u64).to_le(), header_destination_buffer);
|
||||
memory::store_raw((counter as u64).to_le(), header_destination_buffer);
|
||||
memory::store_raw(
|
||||
(u64::from(recipient_session_id) | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52))
|
||||
.to_le(),
|
||||
|
@ -1499,7 +1468,6 @@ impl SessionKey {
|
|||
key: Secret<64>,
|
||||
role: Role,
|
||||
current_time: i64,
|
||||
current_counter: CounterValue,
|
||||
ratchet_count: u64,
|
||||
jedi: bool,
|
||||
) -> Self {
|
||||
|
@ -1512,8 +1480,7 @@ impl SessionKey {
|
|||
Self {
|
||||
secret_fingerprint: public_fingerprint_of_secret(key.as_bytes())[..16].try_into().unwrap(),
|
||||
creation_time: current_time,
|
||||
creation_counter: current_counter,
|
||||
lifetime: KeyLifetime::new(current_counter, current_time),
|
||||
lifetime: KeyLifetime::new(current_time),
|
||||
ratchet_key: kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_RATCHETING),
|
||||
receive_key,
|
||||
send_key,
|
||||
|
@ -1560,12 +1527,9 @@ impl SessionKey {
|
|||
}
|
||||
|
||||
impl KeyLifetime {
|
||||
fn new(current_counter: CounterValue, current_time: i64) -> Self {
|
||||
fn new(current_time: i64) -> Self {
|
||||
Self {
|
||||
rekey_at_or_after_counter: current_counter
|
||||
.counter_value_after_uses(REKEY_AFTER_USES)
|
||||
.counter_value_after_uses((random::next_u32_secure() % REKEY_AFTER_USES_MAX_JITTER) as u64),
|
||||
hard_expire_at_counter: current_counter.counter_value_after_uses(EXPIRE_AFTER_USES),
|
||||
rekey_at_or_after_counter: REKEY_AFTER_USES as u32 + (random::next_u32_secure() % REKEY_AFTER_USES_MAX_JITTER),
|
||||
rekey_at_or_after_timestamp: current_time
|
||||
+ REKEY_AFTER_TIME_MS
|
||||
+ (random::next_u32_secure() % REKEY_AFTER_TIME_MS_MAX_JITTER) as i64,
|
||||
|
@ -1573,11 +1537,11 @@ impl KeyLifetime {
|
|||
}
|
||||
|
||||
fn should_rekey(&self, counter: CounterValue, current_time: i64) -> bool {
|
||||
counter >= self.rekey_at_or_after_counter || current_time >= self.rekey_at_or_after_timestamp
|
||||
counter.to_u32() >= self.rekey_at_or_after_counter || current_time >= self.rekey_at_or_after_timestamp
|
||||
}
|
||||
|
||||
fn expired(&self, counter: CounterValue) -> bool {
|
||||
counter >= self.hard_expire_at_counter
|
||||
counter.to_u32() >= EXPIRE_AFTER_USES as u32
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue