diff --git a/zssp/changes.txt b/zssp/changes.txt deleted file mode 100644 index ff8a5efdc..000000000 --- a/zssp/changes.txt +++ /dev/null @@ -1,19 +0,0 @@ -zssp has been moved into it's own crate. - -zssp has been cut up into several files, only the new zssp.rs file contains the critical security path. - -Standardized the naming conventions for security variables throughout zssp. - -Implemented a safer version of write_all for zssp to use. This has 3 benefits: it completely prevents unknown io errors, making error handling easier and self-documenting; it completely prevents src from being truncated in dest, putting in an extra barrier to prevent catastrophic key truncation; and it has slightly less performance overhead than a write_all. - -Implemented a safer version of read_exact for zssp to use. This has similar benefits to the previous change. - -Refactored most buffer logic to use safe_read_exact and safe_write_all, the resulting code is less verbose and easier to analyze: Because of this refactor the buffer overrun below was caught. - -Fixed a buffer overrun panic when decoding alice_ratchet_key_fingerprint - -Renamed variables and added extra intermediate values so encoding and decoding are more obviously symmetric. - -Added multiple comments. - -Removed Box, EphemeralOffer is now passed out by reference instead of returned up the stack. diff --git a/zssp/src/applicationlayer.rs b/zssp/src/applicationlayer.rs index f72450588..68358aa95 100644 --- a/zssp/src/applicationlayer.rs +++ b/zssp/src/applicationlayer.rs @@ -67,6 +67,17 @@ pub trait ApplicationLayer: Sized { /// On success a tuple of local session ID, static secret, and associated object is returned. The /// static secret is whatever results from agreement between the local and remote static public /// keys. + /// + /// When `accept_new_session` is called, `remote_static_public` and `remote_metadata` have not yet been + /// authenticated. As such avoid mutating state until OkNewSession(Session) is returned, as the connection + /// may be adversarial. + /// + /// When `remote_static_public` and `remote_metadata` are eventually authenticated, the zssp protocol cannot + /// guarantee that they are unique, i.e. `remote_static_public` and `remote_metadata` may be duplicates from + /// an old attempt to establish a session, and may even have been replayed by an adversary. If your use-case + /// needs uniqueness for reliability or security, consider either including a timestamp in the metadata, or + /// sending the metadata as an extra transport packet after the session is fully established. + /// It is guaranteed they will be unique for at least the lifetime of the returned session. fn accept_new_session( &self, receive_context: &ReceiveContext, diff --git a/zssp/src/constants.rs b/zssp/src/constants.rs index 523b6692e..7ff145de5 100644 --- a/zssp/src/constants.rs +++ b/zssp/src/constants.rs @@ -74,11 +74,8 @@ pub(crate) const HMAC_SIZE: usize = 48; /// This is large since some ZeroTier nodes handle huge numbers of links, like roots and controllers. pub(crate) const SESSION_ID_SIZE: usize = 6; -/// Number of session keys to hold at a given time (current, previous, next). -pub(crate) const KEY_HISTORY_SIZE: usize = 3; - /// Maximum difference between out-of-order incoming packet counters, and size of deduplication buffer. -pub(crate) const COUNTER_MAX_DELTA: u32 = 16; +pub(crate) const COUNTER_MAX_ALLOWED_OOO: usize = 16; // Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use pub(crate) const PACKET_TYPE_DATA: u8 = 0; diff --git a/zssp/src/counter.rs b/zssp/src/counter.rs index 9aa147670..2eaa558cc 100644 --- a/zssp/src/counter.rs +++ b/zssp/src/counter.rs @@ -1,28 +1,29 @@ -use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::sync::atomic::{Ordering, AtomicU32}; -use zerotier_crypto::random; - -use crate::constants::COUNTER_MAX_DELTA; +use crate::constants::COUNTER_MAX_ALLOWED_OOO; /// Outgoing packet counter with strictly ordered atomic semantics. +/// Count sequence always starts at 1u32, it must never be allowed to overflow /// -/// The counter used in packets is actually 32 bits, but using a 64-bit integer internally -/// lets us more safely implement key lifetime limits without confusing logic to handle 32-bit -/// wrap-around. #[repr(transparent)] -pub(crate) struct Counter(AtomicU64); +pub(crate) struct Counter(AtomicU32); impl Counter { #[inline(always)] pub fn new() -> Self { // Using a random value has no security implication. Zero would be fine. This just // helps randomize packet contents a bit. - Self(AtomicU64::new(random::next_u32_secure() as u64)) + Self(AtomicU32::new(1u32)) + } + + #[inline(always)] + pub fn reset_for_new_key_offer(&self) { + self.0.store(1u32, Ordering::SeqCst); } /// Get the value most recently used to send a packet. #[inline(always)] - pub fn previous(&self) -> CounterValue { + pub fn current(&self) -> CounterValue { CounterValue(self.0.load(Ordering::SeqCst)) } @@ -36,7 +37,7 @@ impl Counter { /// A value of the outgoing packet counter. #[repr(transparent)] #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub(crate) struct CounterValue(u64); +pub(crate) struct CounterValue(u32); impl CounterValue { /// Get the 32-bit counter value used to build packets. @@ -44,45 +45,44 @@ impl CounterValue { pub fn to_u32(&self) -> u32 { self.0 as u32 } - - /// Get the counter value after N more uses of the parent counter. - /// - /// This checks for u64 overflow for the sake of correctness. Be careful if using ZSSP in a - /// generational starship where sessions may last for millions of years. - #[inline(always)] - pub fn counter_value_after_uses(&self, uses: u64) -> Self { - Self(self.0.checked_add(uses).unwrap()) - } } /// Incoming packet deduplication and replay protection window. -pub(crate) struct CounterWindow(AtomicU32, [AtomicU32; COUNTER_MAX_DELTA as usize]); +pub(crate) struct CounterWindow([AtomicU32; COUNTER_MAX_ALLOWED_OOO]); impl CounterWindow { #[inline(always)] - pub fn new(initial: u32) -> Self { - Self(AtomicU32::new(initial), std::array::from_fn(|_| AtomicU32::new(initial))) + pub fn new() -> Self { + Self(std::array::from_fn(|_| AtomicU32::new(0))) + } + ///this creates a counter window that rejects everything + pub fn new_invalid() -> Self { + Self(std::array::from_fn(|_| AtomicU32::new(u32::MAX))) + } + pub fn reset_for_new_key_offer(&self) { + for i in 0..COUNTER_MAX_ALLOWED_OOO { + self.0[i].store(0, Ordering::SeqCst) + } + } + pub fn invalidate(&self) { + for i in 0..COUNTER_MAX_ALLOWED_OOO { + self.0[i].store(u32::MAX, Ordering::SeqCst) + } } #[inline(always)] pub fn message_received(&self, received_counter_value: u32) -> bool { - let prev_max = self.0.fetch_max(received_counter_value, Ordering::AcqRel); - if received_counter_value >= prev_max || prev_max.wrapping_sub(received_counter_value) <= COUNTER_MAX_DELTA { - // First, the most common case: counter is higher than the previous maximum OR is no older than MAX_DELTA. - // In that case we accept the packet if it is not a duplicate. Duplicate check is this swap/compare. - self.1[(received_counter_value % COUNTER_MAX_DELTA) as usize].swap(received_counter_value, Ordering::AcqRel) - != received_counter_value - } else if received_counter_value.wrapping_sub(prev_max) <= COUNTER_MAX_DELTA { - // If the received value is lower and wraps when the previous max is subtracted, this means the - // unsigned integer counter has wrapped. In that case we write the new lower-but-actually-higher "max" - // value and then check the deduplication window. - self.0.store(received_counter_value, Ordering::Release); - self.1[(received_counter_value % COUNTER_MAX_DELTA) as usize].swap(received_counter_value, Ordering::AcqRel) - != received_counter_value - } else { - // If the received value is more than MAX_DELTA in the past and wrapping has NOT occurred, this packet - // is too old and is rejected. - false - } + let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize; + //it is highly likely this can be a Relaxed ordering, but I want someone else to confirm that is the case + let pre = self.0[idx].load(Ordering::SeqCst); + return pre < received_counter_value; + } + + #[inline(always)] + pub fn message_authenticated(&self, received_counter_value: u32) -> bool { + //if a valid message is received but one of its fragments was lost, it can technically be replayed. However since the message is incomplete, we know it still exists in the gather array, so the gather array will deduplicate the replayed message. Even if the gather array gets flushed, that flush still effectively deduplicates the replayed message. + //eventually the counter of that kind of message will be too OOO to be accepted anymore so it can't be used to DOS. + let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize; + return self.0[idx].fetch_max(received_counter_value, Ordering::SeqCst) < received_counter_value; } } diff --git a/zssp/src/tests.rs b/zssp/src/tests.rs index 45de756df..96e4e62bb 100644 --- a/zssp/src/tests.rs +++ b/zssp/src/tests.rs @@ -209,7 +209,7 @@ mod tests { ) .is_ok()); } - if (test_loop % 8) == 0 && test_loop >= 8 && host.this_name.eq("alice") { + if (test_loop % 8) == 0 && test_loop >= 8 { session.service(host, send_to_other, &[], mtu_buffer.len(), test_loop as i64, true); } } @@ -218,14 +218,53 @@ mod tests { } } + + #[inline(always)] + pub fn xorshift64(x: &mut u64) -> u32 { + *x ^= x.wrapping_shl(13); + *x ^= x.wrapping_shr(7); + *x ^= x.wrapping_shl(17); + *x as u32 + } #[test] fn counter_window() { - let w = CounterWindow::new(0xffffffff); - assert!(!w.message_received(0xffffffff)); - assert!(w.message_received(0)); - assert!(w.message_received(1)); - assert!(w.message_received(COUNTER_MAX_DELTA * 2)); - assert!(!w.message_received(0xffffffff)); - assert!(w.message_received(0xfffffffe)); + let mut rng = 844632; + let mut counter = 1u32; + let mut history = Vec::new(); + + let w = CounterWindow::new(); + for _i in 0..1000000 { + let p = xorshift64(&mut rng) as f32/(u32::MAX as f32 + 1.0); + let c; + if p < 0.5 { + let r = xorshift64(&mut rng); + c = counter + (r%(COUNTER_MAX_ALLOWED_OOO - 1) as u32 + 1); + } else if p < 0.8 { + counter = counter + (1); + c = counter; + } else if p < 0.9 { + if history.len() > 0 { + let idx = xorshift64(&mut rng) as usize%history.len(); + let c = history[idx]; + assert!(!w.message_authenticated(c)); + } + continue; + } else if p < 0.999 { + c = xorshift64(&mut rng); + w.message_received(c); + continue; + } else { + w.reset_for_new_key_offer(); + counter = 1u32; + history = Vec::new(); + continue; + } + if history.contains(&c) { + assert!(!w.message_authenticated(c)); + } else { + assert!(w.message_authenticated(c)); + history.push(c); + } + } } } diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 7ef170c55..b34184dd0 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -3,6 +3,7 @@ // ZSSP: ZeroTier Secure Session Protocol // FIPS compliant Noise_IK with Jedi powers and built-in attack-resistant large payload (fragmentation) support. +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Mutex, RwLock}; use zerotier_crypto::aes::{Aes, AesGcm}; @@ -24,39 +25,43 @@ use crate::counter::{Counter, CounterValue, CounterWindow}; use crate::sessionid::SessionId; pub enum Error { - /// The packet was addressed to an unrecognized local session (should usually be ignored) + /// The packet was addressed to an unrecognized local session (should usually be ignored). UnknownLocalSessionId(SessionId), - /// Packet was not well formed + /// Packet was not well formed. InvalidPacket, - /// An invalid parameter was supplied to the function + /// An invalid parameter was supplied to the function. InvalidParameter, - /// Packet failed one or more authentication (MAC) checks - /// IMPORTANT: Do not reply to a peer who has sent a packet that has failed authentication. Any response at all will leak to an attacker what authentication step their packet failed at (timing attack), which lowers the total authentication entropy they have to brute force. - /// There is a safe way to reply if absolutely necessary, by sending the reply back after a constant amount of time, but this is difficult to get correct. + /// Packet failed one or more authentication (MAC) checks. + /// + /// **IMPORTANT**: Do not reply to a peer who has sent a packet that has failed authentication. + /// Any response at all will leak to an attacker what authentication step their packet failed at + /// (timing attack), which lowers the total authentication entropy they have to brute force. + /// There is a safe way to reply if absolutely necessary; by sending the reply back after a constant + /// amount of time, but this is very difficult to get correct. FailedAuthentication, /// New session was rejected by the application layer. NewSessionRejected, - /// Rekeying failed and session secret has reached its hard usage count limit + /// Rekeying failed and session secret has reached its hard usage count limit. MaxKeyLifetimeExceeded, - /// Attempt to send using session without established key + /// Attempt to send using session without established key. SessionNotEstablished, /// Packet ignored by rate limiter. RateLimited, - /// The other peer specified an unrecognized protocol version + /// The other peer specified an unrecognized protocol version. UnknownProtocolVersion, - /// Caller supplied data buffer is too small to receive data + /// Caller supplied data buffer is too small to receive data. DataBufferTooSmall, - /// Data object is too large to send, even with fragmentation + /// Data object is too large to send, even with fragmentation. DataTooLarge, /// An unexpected buffer overrun occured while attempting to encode or decode a packet. @@ -81,7 +86,10 @@ pub enum ReceiveResult<'a, H: ApplicationLayer> { /// The session will have already been gated by the accept_new_session() method in ApplicationLayer. OkNewSession(Session), - /// Packet appears valid but was ignored e.g. as a duplicate. + /// Packet superficially appears valid but was ignored e.g. as a duplicate. + /// + /// **IMPORTANT**: This packet was not authenticated, so for the most part treat this the same + /// as an `Error::FailedAuthentication`. Ignored, } @@ -89,6 +97,9 @@ pub enum ReceiveResult<'a, H: ApplicationLayer> { /// /// Note that the role can switch through the course of a session. It's the side that most recently /// initiated a session or a rekey event. Initiator is Alice, responder is Bob. +/// +/// We require that after every rekeying event Alice and Bob switch roles. +#[derive(PartialEq)] pub enum Role { Alice, Bob, @@ -112,11 +123,12 @@ pub struct Session { /// An arbitrary application defined object associated with each session pub application_data: Application::Data, - send_counter: Counter, // Outgoing packet counter and nonce state + ratchet_counts: [AtomicU64; 2], // Number of session keys in ratchet, starts at 1 + header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) + receive_windows: [CounterWindow; 2], // Receive window for anti-replay and deduplication + state: RwLock, // Mutable parts of state (other than defrag buffers) psk: Secret<64>, // Arbitrary PSK provided by external code noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key - header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) - state: RwLock, // Mutable parts of state (other than defrag buffers) remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob) remote_s_public_p384_bytes: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key @@ -125,8 +137,9 @@ pub struct Session { struct SessionMutableState { remote_session_id: Option, // The other side's 48-bit session ID - session_keys: [Option; KEY_HISTORY_SIZE], // Buffers to store current, next, and last active key - cur_session_key_idx: usize, // Pointer used for keys[] circular buffer + send_counters: [Counter; 2], // Outgoing packet counter and nonce state, starts at 1 + session_keys: [Option; 2], // Buffers to store current, next, and last active key + cur_session_key_id: bool, // Pointer used for keys[] circular buffer offer: Option, // Most recent ephemeral offer sent to remote last_remote_offer: i64, // Time of most recent ephemeral offer (ms) } @@ -135,31 +148,27 @@ struct SessionMutableState { struct SessionKey { secret_fingerprint: [u8; 16], // First 128 bits of a SHA384 computed from the secret creation_time: i64, // Time session key was established - creation_counter: CounterValue, // Counter value at which session was established - receive_window: CounterWindow, // Receive window for anti-replay and deduplication lifetime: KeyLifetime, // Key expiration time and counter ratchet_key: Secret<64>, // Ratchet key for deriving the next session key receive_key: Secret, // Receive side AES-GCM key send_key: Secret, // Send side AES-GCM key receive_cipher_pool: Mutex>>, // Pool of reusable sending ciphers send_cipher_pool: Mutex>>, // Pool of reusable receiving ciphers - ratchet_count: u64, // Number of preceding session keys in ratchet jedi: bool, // True if Kyber1024 was used (both sides enabled) + role: Role, // The role of the local party that created this key } /// Key lifetime state struct KeyLifetime { - rekey_at_or_after_counter: CounterValue, - hard_expire_at_counter: CounterValue, + rekey_at_or_after_counter: u32, rekey_at_or_after_timestamp: i64, } /// Alice's KEY_OFFER, remembered so Noise agreement process can resume on KEY_COUNTER_OFFER. struct EphemeralOffer { id: [u8; 16], // Arbitrary random offer ID + key_id: bool, // The key_id bound to this offer, for handling OOO rekeying creation_time: i64, // Local time when offer was created - ratchet_count: u64, // Ratchet count (starting at zero) for initial offer - ratchet_key: Option>, // Ratchet key from previous offer or None if first offer ss_key: Secret<64>, // Noise session key "under construction" at state after offer sent alice_e_keypair: P384KeyPair, // NIST P-384 key pair (Noise ephemeral key for Alice) alice_hk_keypair: Option, // Kyber1024 key pair (PQ hybrid ephemeral key for Alice) @@ -235,6 +244,8 @@ impl Session { if send_ephemeral_offer( &mut send, send_counter.next(), + false, + false, local_session_id, None, app.get_local_s_public_blob(), @@ -243,6 +254,7 @@ impl Session { &bob_s_public_blob_hash, &noise_ss, None, + 1, None, mtu, current_time, @@ -253,17 +265,19 @@ impl Session { return Ok(Self { id: local_session_id, application_data, - send_counter, - psk: psk.clone(), - noise_ss, + ratchet_counts: [AtomicU64::new(1), AtomicU64::new(0)], header_check_cipher, + receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], state: RwLock::new(SessionMutableState { + send_counters: [send_counter, Counter::new()], remote_session_id: None, - session_keys: [None, None, None], - cur_session_key_idx: 0, + session_keys: [None, None], + cur_session_key_id: false, offer, last_remote_offer: i64::MIN, }), + psk: psk.clone(), + noise_ss, remote_s_public_blob_hash: bob_s_public_blob_hash, remote_s_public_p384_bytes: bob_s_public.as_bytes().clone(), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), @@ -289,12 +303,14 @@ impl Session { debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); let state = self.state.read().unwrap(); if let Some(remote_session_id) = state.remote_session_id { - if let Some(session_key) = state.session_keys[state.cur_session_key_idx].as_ref() { + let key_id = state.cur_session_key_id; + if let Some(session_key) = state.session_keys[key_id as usize].as_ref() { // Total size of the armored packet we are going to send (may end up being fragmented) let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; - + //key ratchet count to be used for salting + let ratchet_count = self.ratchet_counts[key_id as usize].load(Ordering::Relaxed); // This outgoing packet's nonce counter value. - let counter = self.send_counter.next(); + let counter = state.send_counters[key_id as usize].next(); //////////////////////////////////////////////////////////////// // packet encoding for post-noise transport @@ -308,6 +324,7 @@ impl Session { PACKET_TYPE_DATA, remote_session_id.into(), counter, + key_id )?; // Get an initialized AES-GCM cipher and re-initialize with a 96-bit IV built from remote session ID, @@ -326,7 +343,7 @@ impl Session { let fragment_size = fragment_data_size + HEADER_SIZE; c.crypt(&data[..fragment_data_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]); data = &data[fragment_data_size..]; - set_header_check_code(mtu_sized_buffer, &self.header_check_cipher); + set_header_check_code(mtu_sized_buffer, ratchet_count, &self.header_check_cipher); send(&mut mtu_sized_buffer[..fragment_size]); debug_assert!(header[15].wrapping_shr(2) < 63); @@ -347,7 +364,7 @@ impl Session { c.crypt(data, &mut mtu_sized_buffer[HEADER_SIZE..payload_end]); let gcm_tag = c.finish_encrypt(); mtu_sized_buffer[payload_end..last_fragment_size].copy_from_slice(&gcm_tag); - set_header_check_code(mtu_sized_buffer, &self.header_check_cipher); + set_header_check_code(mtu_sized_buffer, ratchet_count, &self.header_check_cipher); send(&mut mtu_sized_buffer[..last_fragment_size]); // Check reusable AES-GCM instance back into pool. @@ -366,7 +383,7 @@ impl Session { /// Check whether this session is established. pub fn established(&self) -> bool { let state = self.state.read().unwrap(); - state.remote_session_id.is_some() && state.session_keys[state.cur_session_key_idx].is_some() + state.remote_session_id.is_some() && state.session_keys[state.cur_session_key_id as usize].is_some() } /// Get information about this session's security state. @@ -375,8 +392,9 @@ impl Session { /// and whether Kyber1024 was used. None is returned if the session isn't established. pub fn status(&self) -> Option<([u8; 16], i64, u64, bool)> { let state = self.state.read().unwrap(); - if let Some(key) = state.session_keys[state.cur_session_key_idx].as_ref() { - Some((key.secret_fingerprint, key.creation_time, key.ratchet_count, key.jedi)) + let key_id = state.cur_session_key_id; + if let Some(key) = state.session_keys[key_id as usize].as_ref() { + Some((key.secret_fingerprint, key.creation_time, self.ratchet_counts[key_id as usize].load(Ordering::Relaxed), key.jedi)) } else { None } @@ -389,7 +407,7 @@ impl Session { /// * `offer_metadata' - Any meta-data to include with initial key offers sent. /// * `mtu` - Current physical transport MTU /// * `current_time` - Current monotonic time in milliseconds - /// * `force_rekey` - Re-key the session now regardless of key aging (still subject to rate limiting) + /// * `assume_key_is_too_old` - Re-key the session now if the protocol allows it (subject to rate limits and whether it is the local party's turn to rekey) pub fn service( &self, app: &Application, @@ -397,23 +415,32 @@ impl Session { offer_metadata: &[u8], mtu: usize, current_time: i64, - force_rekey: bool, + assume_key_is_too_old: bool, ) { let state = self.state.read().unwrap(); - if (force_rekey - || state.session_keys[state.cur_session_key_idx] + let current_key_id = state.cur_session_key_id; + if (state.session_keys[state.cur_session_key_id as usize] .as_ref() - .map_or(true, |key| key.lifetime.should_rekey(self.send_counter.previous(), current_time))) + .map_or(true, |key| key.role == Role::Bob && (assume_key_is_too_old || key.lifetime.should_rekey(state.send_counters[current_key_id as usize].current(), current_time)))) && state .offer .as_ref() .map_or(true, |o| (current_time - o.creation_time) > Application::REKEY_RATE_LIMIT_MS) { if let Some(remote_s_public) = P384PublicKey::from_bytes(&self.remote_s_public_p384_bytes) { + // This routine handles sending a rekeying packet, resending lost rekeying packets, and resending lost initial offer packets + // The protocol has been designed such that initial rekeying packets are identical to resent rekeying packets, except for the counter, so we can reuse the same code for doing both + let has_existing_session = state.remote_session_id.is_some(); + // Mark the previous key as no longer being supported because it is about to be overwritten + // It should not be possible for a session to accidentally invalidate the key currently in use solely because of the read lock + self.receive_windows[(!current_key_id) as usize].invalidate(); let mut offer = None; + // The session will keep sending ephemeral offers until rekeying is successful if send_ephemeral_offer( &mut send, - self.send_counter.next(), + state.send_counters[current_key_id as usize].next(), + current_key_id, + has_existing_session && !current_key_id, self.id, state.remote_session_id, app.get_local_s_public_blob(), @@ -421,12 +448,9 @@ impl Session { &remote_s_public, &self.remote_s_public_blob_hash, &self.noise_ss, - state.session_keys[state.cur_session_key_idx].as_ref(), - if state.remote_session_id.is_some() { - Some(&self.header_check_cipher) - } else { - None - }, + state.session_keys[current_key_id as usize].as_ref(), + self.ratchet_counts[current_key_id as usize].load(Ordering::Relaxed), + if has_existing_session { Some(&self.header_check_cipher) } else { None }, mtu, current_time, &mut offer, @@ -434,7 +458,7 @@ impl Session { .is_ok() { drop(state); - let _ = self.state.write().unwrap().offer.replace(offer.unwrap()); + self.state.write().unwrap().offer = offer; } } } @@ -476,7 +500,9 @@ impl ReceiveContext { return Err(Error::InvalidPacket); } - let counter = u32::from_le(memory::load_raw(incoming_packet)); + let raw_counter = u32::from_le(memory::load_raw(&incoming_packet[4..8])); + let key_id = (raw_counter & 1) > 0; + let counter = raw_counter.wrapping_shr(1); let packet_type_fragment_info = u16::from_le(memory::load_raw(&incoming_packet[14..16])); let packet_type = (packet_type_fragment_info & 0x0f) as u8; let fragment_count = ((packet_type_fragment_info.wrapping_shr(4) + 1) as u8) & 63; @@ -485,46 +511,57 @@ impl ReceiveContext { if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64) { if let Some(session) = app.lookup_session(local_session_id) { - if verify_header_check_code(incoming_packet, &session.header_check_cipher) { - let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); - if fragment_count > 1 { - if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { - let mut defrag = session.defrag.lock().unwrap(); - let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); - if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { - drop(defrag); // release lock - return self.receive_complete( - app, - remote_address, - &mut send, - data_buf, - counter, - canonical_header.as_bytes(), - assembled_packet.as_ref(), - packet_type, - Some(session), - mtu, - current_time, - ); + // This is the only time ratchet_counts is ever accessed outside of a lock + // As such this read can be wrong, but that is incredibly unlikely since we are tracking the last two ratchet counts, and if it's wrong it just means we drop a packet that would have been dropped anyways for being too old or too new + let ratchet_count = session.ratchet_counts[key_id as usize].load(Ordering::SeqCst); + if verify_header_check_code(incoming_packet, ratchet_count, &session.header_check_cipher) { + if session.receive_windows[key_id as usize].message_received(counter) { + let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); + if fragment_count > 1 { + if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { + let mut defrag = session.defrag.lock().unwrap(); + // by using the counter + the key_id as the key we can prevent packet collisions, this only works if defrag hashes + let fragment_gather_array = defrag.get_or_create_mut(&raw_counter, || GatherArray::new(fragment_count)); + if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { + drop(defrag); // release lock + return self.receive_complete( + app, + remote_address, + &mut send, + data_buf, + counter, + key_id, + canonical_header.as_bytes(), + assembled_packet.as_ref(), + packet_type, + Some(session), + mtu, + current_time, + ); + } + } else { + unlikely_branch(); + return Err(Error::InvalidPacket); } } else { - unlikely_branch(); - return Err(Error::InvalidPacket); + return self.receive_complete( + app, + remote_address, + &mut send, + data_buf, + counter, + key_id, + canonical_header.as_bytes(), + &[incoming_packet_buf], + packet_type, + Some(session), + mtu, + current_time, + ); } } else { - return self.receive_complete( - app, - remote_address, - &mut send, - data_buf, - counter, - canonical_header.as_bytes(), - &[incoming_packet_buf], - packet_type, - Some(session), - mtu, - current_time, - ); + unlikely_branch(); + return Ok(ReceiveResult::Ignored); } } else { unlikely_branch(); @@ -535,9 +572,11 @@ impl ReceiveContext { return Err(Error::UnknownLocalSessionId(local_session_id)); } } else { - unlikely_branch(); // we want data receive to be the priority branch, this is only occasionally used - - if verify_header_check_code(incoming_packet, &self.incoming_init_header_check_cipher) { + unlikely_branch(); // We want data receive to be the priority branch, this is only occasionally used + // Salt with a known value so new sessions can be established + // NOTE: this check is trivial to bypass by just replaying recorded packets + // This check isn't security critical so that is fine + if verify_header_check_code(incoming_packet, 1u64, &self.incoming_init_header_check_cipher) { let canonical_header = CanonicalHeader::make(SessionId::NIL, packet_type, counter); if fragment_count > 1 { let mut defrag = self.initial_offer_defrag.lock().unwrap(); @@ -550,6 +589,7 @@ impl ReceiveContext { &mut send, data_buf, counter, + key_id, canonical_header.as_bytes(), assembled_packet.as_ref(), packet_type, @@ -565,6 +605,7 @@ impl ReceiveContext { &mut send, data_buf, counter, + key_id, canonical_header.as_bytes(), &[incoming_packet_buf], packet_type, @@ -593,6 +634,7 @@ impl ReceiveContext { send: &mut SendFunction, data_buf: &'a mut [u8], counter: u32, + key_id: bool, canonical_header_bytes: &[u8; 12], fragments: &[Application::IncomingPacketBuffer], packet_type: u8, @@ -609,85 +651,59 @@ impl ReceiveContext { if packet_type <= PACKET_TYPE_NOP { if let Some(session) = session { let state = session.state.read().unwrap(); - for p in 0..KEY_HISTORY_SIZE { - let key_idx = (state.cur_session_key_idx + p) % KEY_HISTORY_SIZE; - if let Some(session_key) = state.session_keys[key_idx].as_ref() { - let mut c = session_key.get_receive_cipher(); - c.reset_init_gcm(canonical_header_bytes); - //////////////////////////////////////////////////////////////// - // packet decoding for post-noise transport - //////////////////////////////////////////////////////////////// + if let Some(session_key) = state.session_keys[key_id as usize].as_ref() { + let mut c = session_key.get_receive_cipher(); + c.reset_init_gcm(canonical_header_bytes); + //////////////////////////////////////////////////////////////// + // packet decoding for post-noise transport + //////////////////////////////////////////////////////////////// - let mut data_len = 0; + let mut data_len = 0; - // Decrypt fragments 0..N-1 where N is the number of fragments. - for f in fragments[..(fragments.len() - 1)].iter() { - let f = f.as_ref(); - debug_assert!(f.len() >= HEADER_SIZE); - let current_frag_data_start = data_len; - data_len += f.len() - HEADER_SIZE; - if data_len > data_buf.len() { - unlikely_branch(); - session_key.return_receive_cipher(c); - return Err(Error::DataBufferTooSmall); - } - c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); - } - - // Decrypt final fragment (or only fragment if not fragmented) + // Decrypt fragments 0..N-1 where N is the number of fragments. + for f in fragments[..(fragments.len() - 1)].iter() { + let f = f.as_ref(); + debug_assert!(f.len() >= HEADER_SIZE); let current_frag_data_start = data_len; - let last_fragment = fragments.last().unwrap().as_ref(); - if last_fragment.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { - unlikely_branch(); - return Err(Error::InvalidPacket); - } - data_len += last_fragment.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); + data_len += f.len() - HEADER_SIZE; if data_len > data_buf.len() { unlikely_branch(); session_key.return_receive_cipher(c); return Err(Error::DataBufferTooSmall); } - let payload_end = last_fragment.len() - AES_GCM_TAG_SIZE; - c.crypt( - &last_fragment[HEADER_SIZE..payload_end], - &mut data_buf[current_frag_data_start..data_len], - ); + c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); + } - let gcm_tag = &last_fragment[payload_end..]; - let aead_authentication_ok = c.finish_decrypt(gcm_tag); + // Decrypt final fragment (or only fragment if not fragmented) + let current_frag_data_start = data_len; + let last_fragment = fragments.last().unwrap().as_ref(); + if last_fragment.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { + unlikely_branch(); + return Err(Error::InvalidPacket); + } + data_len += last_fragment.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); + if data_len > data_buf.len() { + unlikely_branch(); session_key.return_receive_cipher(c); + return Err(Error::DataBufferTooSmall); + } + let payload_end = last_fragment.len() - AES_GCM_TAG_SIZE; + c.crypt( + &last_fragment[HEADER_SIZE..payload_end], + &mut data_buf[current_frag_data_start..data_len], + ); - if aead_authentication_ok { - if session_key.receive_window.message_received(counter) { - // Select this key as the new default if it's newer than the current key. - if p > 0 - && state.session_keys[state.cur_session_key_idx] - .as_ref() - .map_or(true, |old| old.creation_counter < session_key.creation_counter) - { - drop(state); - let mut state = session.state.write().unwrap(); - state.cur_session_key_idx = key_idx; - for i in 0..KEY_HISTORY_SIZE { - if i != key_idx { - if let Some(old_key) = state.session_keys[key_idx].as_ref() { - // Release pooled cipher memory from old keys. - old_key.receive_cipher_pool.lock().unwrap().clear(); - old_key.send_cipher_pool.lock().unwrap().clear(); - } - } - } - } + let gcm_tag = &last_fragment[payload_end..]; + let aead_authentication_ok = c.finish_decrypt(gcm_tag); + session_key.return_receive_cipher(c); - if packet_type == PACKET_TYPE_DATA { - return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); - } else { - unlikely_branch(); - return Ok(ReceiveResult::Ok); - } + if aead_authentication_ok { + if session.receive_windows[key_id as usize].message_authenticated(counter) { + if packet_type == PACKET_TYPE_DATA { + return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); } else { unlikely_branch(); - return Ok(ReceiveResult::Ignored); + return Ok(ReceiveResult::Ok); } } } @@ -761,7 +777,7 @@ impl ReceiveContext { // Check rate limits. if let Some(session) = session.as_ref() { - if (current_time - session.state.read().unwrap().last_remote_offer) < Application::REKEY_RATE_LIMIT_MS { + if current_time < session.state.read().unwrap().last_remote_offer + Application::REKEY_RATE_LIMIT_MS { return Err(Error::RateLimited); } } else { @@ -842,7 +858,7 @@ impl ReceiveContext { // Perform checks and match ratchet key if there's an existing session, or gate (via host) and // then create new sessions. - let (new_session, ratchet_key, last_ratchet_count) = if let Some(session) = session.as_ref() { + let (new_session, reply_counter, new_key_id, ratchet_key) = if let Some(session) = session.as_ref() { // Existing session identity must match the one in this offer. if !secure_eq(&session.remote_s_public_blob_hash, &SHA384::hash(&alice_s_public_blob)) { return Err(Error::FailedAuthentication); @@ -851,50 +867,60 @@ impl ReceiveContext { // Match ratchet key fingerprint and fail if no match, which likely indicates an old offer packet. let alice_ratchet_key_fingerprint = alice_ratchet_key_fingerprint.unwrap(); let mut ratchet_key = None; - let mut last_ratchet_count = 0; let state = session.state.read().unwrap(); - for k in state.session_keys.iter() { - if let Some(k) = k.as_ref() { - if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { - ratchet_key = Some(k.ratchet_key.clone()); - last_ratchet_count = k.ratchet_count; - break; - } + if let Some(k) = state.session_keys[key_id as usize].as_ref() { + if k.role == Role::Bob { + // The local party is not allowed to be Bob twice in a row + // This prevents rekeying failure from both parties attempting to rekey at the same time + return Ok(ReceiveResult::Ignored); + } + if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { + ratchet_key = Some(k.ratchet_key.clone()); } } if ratchet_key.is_none() { return Ok(ReceiveResult::Ignored); // old packet? } - (None, ratchet_key, last_ratchet_count) + (None, state.send_counters[key_id as usize].next(), !key_id, ratchet_key) } else { + if key_id != false { + // All new sessions must start with key_id 0 + // This has no security implications, it just makes programming the initial offer easier + return Ok(ReceiveResult::Ignored); + } if let Some((new_session_id, psk, associated_object)) = app.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata) { let header_check_cipher = Aes::new( kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::(), ); + let send_counter = Counter::new(); + let reply_counter = send_counter.next(); ( Some(Session:: { id: new_session_id, application_data: associated_object, - send_counter: Counter::new(), - psk, - noise_ss, + ratchet_counts: [AtomicU64::new(1), AtomicU64::new(0)], header_check_cipher, + receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], state: RwLock::new(SessionMutableState { + send_counters: [send_counter, Counter::new()], remote_session_id: Some(alice_session_id), - session_keys: [None, None, None], - cur_session_key_idx: 0, + session_keys: [None, None],//this is the only value which will be writen later + cur_session_key_id: false, offer: None, last_remote_offer: current_time, }), + psk, + noise_ss, remote_s_public_blob_hash: SHA384::hash(&alice_s_public_blob), remote_s_public_p384_bytes: alice_s_public.as_bytes().clone(), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), }), - None, - 0, + reply_counter, + false, + None ) } else { return Err(Error::NewSessionRejected); @@ -955,7 +981,6 @@ impl ReceiveContext { //////////////////////////////////////////////////////////////// let mut reply_buf = [0_u8; KEX_BUF_LEN]; - let reply_counter = session.send_counter.next(); let mut idx = HEADER_SIZE; idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?; @@ -987,6 +1012,7 @@ impl ReceiveContext { PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter, + key_id, )?; let reply_canonical_header = CanonicalHeader::make(alice_session_id.into(), PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter.to_u32()); @@ -1024,32 +1050,48 @@ impl ReceiveContext { ); idx = safe_write_all(&mut reply_buf, idx, &hmac)?; let packet_end = idx; + if session.receive_windows[key_id as usize].message_authenticated(counter) { + // Initial key offers should only check this if this is a rekey + let session_key = SessionKey::new( + session_key, + Role::Bob, + current_time, + hybrid_kk.is_some(), + ); - let session_key = SessionKey::new( - session_key, - Role::Bob, - current_time, - reply_counter, - counter, - last_ratchet_count + 1, - hybrid_kk.is_some(), - ); + let mut state = session.state.write().unwrap(); + let _ = state.session_keys[new_key_id as usize].replace(session_key); + let _ = state.remote_session_id.replace(alice_session_id); + let ratchet_count = session.ratchet_counts[key_id as usize].load(Ordering::SeqCst); + if state.cur_session_key_id != new_key_id {//this prevents anything from being reset twice if the key offer was made twice + debug_assert!(new_key_id != key_id); + // receive_windows only has race conditions with the counter of the remote party. It is theoretically possible that the local host receives counters under new_key_id while the receive_window is still in the process of resetting, but this is very unlikely. If it does happen, two things could happen: + // 1) The received counter is less than what is currently stored in the window, so a valid packet would be rejected. + // 2) The received counter is greater than what is currently stored in the window, so a valid packet would be accepted *but* its counter is deleted from the window so it can be replayed. + // To prevent these race conditions, we only update the ratchet_count for salting the check code after the window has reset. So if a counter passes the initial check code: it either means the thread sees ratchet count has been update therefore it sees receive_window has been reset (due to memory orderings), or it means a rare accidental check code collision has occurred. + session.receive_windows[new_key_id as usize].reset_for_new_key_offer(); + session.ratchet_counts[new_key_id as usize].fetch_add(2, Ordering::SeqCst); - let mut state = session.state.write().unwrap(); - let _ = state.remote_session_id.replace(alice_session_id); - let next_key_ptr = (state.cur_session_key_idx + 1) % KEY_HISTORY_SIZE; - let _ = state.session_keys[next_key_ptr].replace(session_key); - drop(state); + // If the following wasn't done inside a lock, a theoretical race condition exists where a thread uses the new key id before the counter is reset, or worse: a thread has held onto the previous key id equal to new_key_id, and attempts to use the reset counter for encryption. + // For this reason do not access send_counters without holding the read lock. + state.cur_session_key_id = new_key_id; + state.send_counters[new_key_id as usize].reset_for_new_key_offer(); + } + drop(state); - // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. + // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. - send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher); + send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, ratchet_count, &session.header_check_cipher); - if new_session.is_some() { - return Ok(ReceiveResult::OkNewSession(new_session.unwrap())); + if let Some(new_session) = new_session { + return Ok(ReceiveResult::OkNewSession(new_session)); + } else { + return Ok(ReceiveResult::Ok); + } } else { - return Ok(ReceiveResult::Ok); + return Ok(ReceiveResult::Ignored); } + } PACKET_TYPE_KEY_COUNTER_OFFER => { @@ -1103,8 +1145,9 @@ impl ReceiveContext { parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?; // Check that this is a counter offer to the original offer we sent. - if !offer.id.eq(offer_id) { + if !(offer.id.eq(offer_id)) { return Ok(ReceiveResult::Ignored); + } // Kyber1024 key agreement if enabled. @@ -1124,12 +1167,17 @@ impl ReceiveContext { let mut session_key = noise_ik_complete; // Mix ratchet key from previous session key (if any) and Kyber1024 hybrid shared key (if any). - let last_ratchet_count = if bob_ratchet_key_id.is_some() && offer.ratchet_key.is_some() { - session_key = Secret(hmac_sha512(offer.ratchet_key.as_ref().unwrap().as_bytes(), session_key.as_bytes())); - offer.ratchet_count - } else { - 0 - }; + // We either have a session, in which case they should have supplied a ratchet key fingerprint, or + // we don't and they should not have supplied one. + if let Some(cur_session_key) = state.session_keys[key_id as usize].as_ref() { + if bob_ratchet_key_id.is_some() { + session_key = Secret(hmac_sha512(cur_session_key.ratchet_key.as_bytes(), session_key.as_bytes())); + } else { + return Err(Error::FailedAuthentication); + } + } else if bob_ratchet_key_id.is_some() { + return Err(Error::FailedAuthentication); + } if let Some(hybrid_kk) = hybrid_kk.as_ref() { session_key = Secret(hmac_sha512(hybrid_kk.as_bytes(), session_key.as_bytes())); } @@ -1145,54 +1193,39 @@ impl ReceiveContext { ) { return Err(Error::FailedAuthentication); } + if session.receive_windows[key_id as usize].message_authenticated(counter) { + // Alice has now completed and validated the full hybrid exchange. - // Alice has now completed and validated the full hybrid exchange. + let session_key = SessionKey::new( + session_key, + Role::Alice, + current_time, + hybrid_kk.is_some(), + ); - let reply_counter = session.send_counter.next(); - let session_key = SessionKey::new( - session_key, - Role::Alice, - current_time, - reply_counter, - counter, - last_ratchet_count + 1, - hybrid_kk.is_some(), - ); + let new_key_id = offer.key_id; + drop(state); + let mut state = session.state.write().unwrap(); + let _ = state.remote_session_id.replace(bob_session_id); + let _ = state.session_keys[new_key_id as usize].replace(session_key); + if state.cur_session_key_id != new_key_id { + // When an brand new key offer is sent, it is sent using the new_key_id==false counter, we cannot reset it in that case. + // NOTE: the following code should be properly threadsafe, see the large comment above at the end of KEY_OFFER decoding for more info + session.receive_windows[new_key_id as usize].reset_for_new_key_offer(); + let _ = session.ratchet_counts[new_key_id as usize].fetch_add(2, Ordering::SeqCst); + state.cur_session_key_id = new_key_id; + state.send_counters[new_key_id as usize].reset_for_new_key_offer(); + } + let _ = state.offer.take(); - //////////////////////////////////////////////////////////////// - // packet encoding for post-noise session start ack - //////////////////////////////////////////////////////////////// - - let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; - create_packet_header( - &mut reply_buf, - HEADER_SIZE + AES_GCM_TAG_SIZE, - mtu, - PACKET_TYPE_NOP, - bob_session_id.into(), - reply_counter, - )?; - - let mut c = session_key.get_send_cipher(reply_counter)?; - c.reset_init_gcm( - CanonicalHeader::make(bob_session_id.into(), PACKET_TYPE_NOP, reply_counter.to_u32()).as_bytes(), - ); - let gcm_tag = c.finish_encrypt(); - safe_write_all(&mut reply_buf, HEADER_SIZE, &gcm_tag)?; - session_key.return_send_cipher(c); - - set_header_check_code(&mut reply_buf, &session.header_check_cipher); - send(&mut reply_buf); - - drop(state); - let mut state = session.state.write().unwrap(); - let _ = state.remote_session_id.replace(bob_session_id); - let next_key_idx = (state.cur_session_key_idx + 1) % KEY_HISTORY_SIZE; - let _ = state.session_keys[next_key_idx].replace(session_key); - let _ = state.offer.take(); - - return Ok(ReceiveResult::Ok); + return Ok(ReceiveResult::Ok); + } else { + return Ok(ReceiveResult::Ignored); + } } + } else { + unlikely_branch(); + return Err(Error::SessionNotEstablished); } // Just ignore counter-offers that are out of place. They probably indicate that this side @@ -1207,9 +1240,12 @@ impl ReceiveContext { } /// Create an send an ephemeral offer, populating ret_ephemeral_offer on success. +/// If there is no current session key set `current_key_id == new_key_id == false` fn send_ephemeral_offer( send: &mut SendFunction, counter: CounterValue, + current_key_id: bool, + new_key_id: bool, alice_session_id: SessionId, bob_session_id: Option, alice_s_public_blob: &[u8], @@ -1218,6 +1254,7 @@ fn send_ephemeral_offer( bob_s_public_blob_hash: &[u8], noise_ss: &Secret<48>, current_key: Option<&SessionKey>, + ratchet_count: u64, header_check_cipher: Option<&Aes>, // None to use one based on the recipient's public key for initial contact mtu: usize, current_time: i64, @@ -1236,13 +1273,6 @@ fn send_ephemeral_offer( None }; - // Get ratchet key for current key if one exists. - let (ratchet_key, ratchet_count) = if let Some(current_key) = current_key { - (Some(current_key.ratchet_key.clone()), current_key.ratchet_count) - } else { - (None, 0) - }; - // Random ephemeral offer ID let id: [u8; 16] = random::get_bytes_secure(); @@ -1257,7 +1287,6 @@ fn send_ephemeral_offer( let mut idx = HEADER_SIZE; idx = safe_write_all(&mut packet_buf, idx, &[SESSION_PROTOCOL_VERSION])?; - //TODO: check this, the below line is supposed to be the blob, not just the key, right? idx = safe_write_all(&mut packet_buf, idx, alice_e_keypair.public_key_bytes())?; let plaintext_end = idx; @@ -1273,9 +1302,9 @@ fn send_ephemeral_offer( } else { idx = safe_write_all(&mut packet_buf, idx, &[HYBRID_KEY_TYPE_NONE])?; } - if let Some(ratchet_key) = ratchet_key.as_ref() { + if let Some(current_key) = current_key { idx = safe_write_all(&mut packet_buf, idx, &[0x01])?; - idx = safe_write_all(&mut packet_buf, idx, &public_fingerprint_of_secret(ratchet_key.as_bytes())[..16])?; + idx = safe_write_all(&mut packet_buf, idx, &public_fingerprint_of_secret(current_key.ratchet_key.as_bytes())[..16])?; } else { idx = safe_write_all(&mut packet_buf, idx, &[0x00])?; } @@ -1295,6 +1324,7 @@ fn send_ephemeral_offer( PACKET_TYPE_INITIAL_KEY_OFFER, bob_session_id, counter, + current_key_id, )?; let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32()); @@ -1336,21 +1366,21 @@ fn send_ephemeral_offer( let packet_end = idx; if let Some(header_check_cipher) = header_check_cipher { - send_with_fragmentation(send, &mut packet_buf[..packet_end], mtu, header_check_cipher); + send_with_fragmentation(send, &mut packet_buf[..packet_end], mtu, ratchet_count, header_check_cipher); } else { send_with_fragmentation( send, &mut packet_buf[..packet_end], mtu, + ratchet_count, &Aes::new(kbkdf512(&bob_s_public_blob_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::()), ); } *ret_ephemeral_offer = Some(EphemeralOffer { id, + key_id: new_key_id, creation_time: current_time, - ratchet_count, - ratchet_key, ss_key, alice_e_keypair, alice_hk_keypair, @@ -1367,6 +1397,7 @@ fn create_packet_header( packet_type: u8, recipient_session_id: SessionId, counter: CounterValue, + key_id: bool ) -> Result<(), Error> { let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize; @@ -1379,13 +1410,14 @@ fn create_packet_header( if fragment_count <= MAX_FRAGMENTS { // Header indexed by bit/byte: - // [0-31]/[0-3] counter - // [32-63]/[4-7] header check code (computed later) + // [0-31]/[0-3] header check code (computed later) + // [32-32]/[4-] key id + // [33-63]/[-7] counter // [64-111]/[8-13] recipient's session ID (unique on their side) // [112-115]/[14-] packet type (0-15) // [116-121]/[-] number of fragments (0..63 for 1..64 fragments total) // [122-127]/[-15] fragment number (0, 1, 2, ...) - memory::store_raw((counter.to_u32() as u64).to_le(), header_destination_buffer); + memory::store_raw((counter.to_u32().wrapping_shl(1) | (key_id as u32)).to_le(), &mut header_destination_buffer[4..]); memory::store_raw( (u64::from(recipient_session_id) | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)) .to_le(), @@ -1403,6 +1435,7 @@ fn send_with_fragmentation( send: &mut SendFunction, packet: &mut [u8], mtu: usize, + ratchet_count: u64, header_check_cipher: &Aes, ) { let packet_len = packet.len(); @@ -1411,7 +1444,7 @@ fn send_with_fragmentation( let mut header: [u8; 16] = packet[..HEADER_SIZE].try_into().unwrap(); loop { let fragment = &mut packet[fragment_start..fragment_end]; - set_header_check_code(fragment, header_check_cipher); + set_header_check_code(fragment, ratchet_count, header_check_cipher); send(fragment); if fragment_end < packet_len { debug_assert!(header[15].wrapping_shr(2) < 63); @@ -1427,19 +1460,30 @@ fn send_with_fragmentation( } /// Set 32-bit header check code, used to make fragmentation mechanism robust. -fn set_header_check_code(packet: &mut [u8], header_check_cipher: &Aes) { +fn set_header_check_code(packet: &mut [u8], ratchet_count: u64, header_check_cipher: &Aes) { debug_assert!(packet.len() >= MIN_PACKET_SIZE); - let mut check_code = 0u128.to_ne_bytes(); - header_check_cipher.encrypt_block(&packet[8..24], &mut check_code); - packet[4..8].copy_from_slice(&check_code[..4]); + //4 bytes is the ratchet key + //12 bytes is the header we want to verify + let mut header_mac = 0u128.to_le_bytes(); + memory::store_raw((ratchet_count as u32).to_le_bytes(), &mut header_mac[0..4]); + header_mac[4..16].copy_from_slice(&packet[4..16]); + header_check_cipher.encrypt_block_in_place(&mut header_mac); + + packet[..4].copy_from_slice(&header_mac[..4]); } /// Verify 32-bit header check code. -fn verify_header_check_code(packet: &[u8], header_check_cipher: &Aes) -> bool { +/// This is not nearly enough entropy to be cryptographically secure, it only is meant for making DOS attacks very hard +fn verify_header_check_code(packet: &[u8], ratchet_count: u64, header_check_cipher: &Aes) -> bool { debug_assert!(packet.len() >= MIN_PACKET_SIZE); - let mut header_mac = 0u128.to_ne_bytes(); - header_check_cipher.encrypt_block(&packet[8..24], &mut header_mac); - memory::load_raw::(&packet[4..8]) == memory::load_raw::(&header_mac) + //4 bytes is the ratchet key + //12 bytes is the header we want to verify + let mut header_mac = 0u128.to_le_bytes(); + memory::store_raw((ratchet_count as u32).to_le_bytes(), &mut header_mac[0..4]); + header_mac[4..16].copy_from_slice(&packet[4..16]); + header_check_cipher.encrypt_block_in_place(&mut header_mac); + + memory::load_raw::(&packet[..4]) == memory::load_raw::(&header_mac[..4]) } /// Parse KEY_OFFER and KEY_COUNTER_OFFER starting after the unencrypted public key part. @@ -1452,15 +1496,15 @@ fn parse_dec_key_offer_after_header( let mut session_id_buf = 0_u64.to_ne_bytes(); session_id_buf[..SESSION_ID_SIZE].copy_from_slice(safe_read_exact(&mut p, SESSION_ID_SIZE)?); - let alice_session_id = SessionId::new_from_u64(u64::from_le_bytes(session_id_buf)).ok_or(Error::InvalidPacket)?; + let remote_session_id = SessionId::new_from_u64(u64::from_le_bytes(session_id_buf)).ok_or(Error::InvalidPacket)?; - let alice_s_public_blob_len = varint_safe_read(&mut p)?; - let alice_s_public_blob = safe_read_exact(&mut p, alice_s_public_blob_len as usize)?; + let remote_s_public_blob_len = varint_safe_read(&mut p)?; + let remote_s_public_blob = safe_read_exact(&mut p, remote_s_public_blob_len as usize)?; - let alice_metadata_len = varint_safe_read(&mut p)?; - let alice_metadata = safe_read_exact(&mut p, alice_metadata_len as usize)?; + let remote_metadata_len = varint_safe_read(&mut p)?; + let remote_metadata = safe_read_exact(&mut p, remote_metadata_len as usize)?; - let alice_hk_public_raw = match safe_read_exact(&mut p, 1)?[0] { + let remote_hk_public_raw = match safe_read_exact(&mut p, 1)?[0] { HYBRID_KEY_TYPE_KYBER1024 => { if packet_type == PACKET_TYPE_INITIAL_KEY_OFFER { safe_read_exact(&mut p, pqc_kyber::KYBER_PUBLICKEYBYTES)? @@ -1474,7 +1518,7 @@ fn parse_dec_key_offer_after_header( if p.is_empty() { return Err(Error::InvalidPacket); } - let alice_ratchet_key_fingerprint = if safe_read_exact(&mut p, 1)?[0] == 0x01 { + let remote_ratchet_key_fingerprint = if safe_read_exact(&mut p, 1)?[0] == 0x01 { Some(safe_read_exact(&mut p, 16)?) } else { None @@ -1482,11 +1526,11 @@ fn parse_dec_key_offer_after_header( Ok(( offer_id, //always 16 bytes - alice_session_id, - alice_s_public_blob, - alice_metadata, - alice_hk_public_raw, - alice_ratchet_key_fingerprint, //always 16 bytes + remote_session_id, + remote_s_public_blob, + remote_metadata, + remote_hk_public_raw, + remote_ratchet_key_fingerprint, //always 16 bytes )) } @@ -1496,9 +1540,6 @@ impl SessionKey { key: Secret<64>, role: Role, current_time: i64, - current_counter: CounterValue, - remote_counter: u32, - ratchet_count: u64, jedi: bool, ) -> Self { let a2b: Secret = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n_clone(); @@ -1510,16 +1551,14 @@ impl SessionKey { Self { secret_fingerprint: public_fingerprint_of_secret(key.as_bytes())[..16].try_into().unwrap(), creation_time: current_time, - creation_counter: current_counter, - receive_window: CounterWindow::new(remote_counter), - lifetime: KeyLifetime::new(current_counter, current_time), + lifetime: KeyLifetime::new(current_time), ratchet_key: kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_RATCHETING), receive_key, send_key, receive_cipher_pool: Mutex::new(Vec::with_capacity(2)), send_cipher_pool: Mutex::new(Vec::with_capacity(2)), - ratchet_count, jedi, + role, } } @@ -1559,12 +1598,9 @@ impl SessionKey { } impl KeyLifetime { - fn new(current_counter: CounterValue, current_time: i64) -> Self { + fn new(current_time: i64) -> Self { Self { - rekey_at_or_after_counter: current_counter - .counter_value_after_uses(REKEY_AFTER_USES) - .counter_value_after_uses((random::next_u32_secure() % REKEY_AFTER_USES_MAX_JITTER) as u64), - hard_expire_at_counter: current_counter.counter_value_after_uses(EXPIRE_AFTER_USES), + rekey_at_or_after_counter: REKEY_AFTER_USES as u32 + (random::next_u32_secure() % REKEY_AFTER_USES_MAX_JITTER), rekey_at_or_after_timestamp: current_time + REKEY_AFTER_TIME_MS + (random::next_u32_secure() % REKEY_AFTER_TIME_MS_MAX_JITTER) as i64, @@ -1572,11 +1608,11 @@ impl KeyLifetime { } fn should_rekey(&self, counter: CounterValue, current_time: i64) -> bool { - counter >= self.rekey_at_or_after_counter || current_time >= self.rekey_at_or_after_timestamp + counter.to_u32() >= self.rekey_at_or_after_counter || current_time >= self.rekey_at_or_after_timestamp } fn expired(&self, counter: CounterValue) -> bool { - counter >= self.hard_expire_at_counter + counter.to_u32() >= EXPIRE_AFTER_USES as u32 } }