From 73e6be7959210cbc0bfe183b04c618a966bfef38 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Fri, 6 Jan 2023 19:51:09 -0500 Subject: [PATCH] Re-implement most of what Monica originally did, but with some variations: - Went back to a single session counter instead of two counter states - Went to a full 64-bit counter in the header as recommended by Noise, turns out there is a good reason. It simplifies everything. - Implemented Monica's simpler stateless counter window algorithm, but also only one on the whole session. - Simplified some counter logic generally. - Header check codes are temporarily gone, coming back in a different form. This is being committed "on top" of what was there instead of reverting the old commits to preserve the history. --- zssp/src/applicationlayer.rs | 16 +- zssp/src/constants.rs | 16 +- zssp/src/counter.rs | 88 ---- zssp/src/error.rs | 26 +- zssp/src/lib.rs | 1 - zssp/src/tests.rs | 71 +--- zssp/src/zssp.rs | 763 ++++++++++++++++------------------- 7 files changed, 384 insertions(+), 597 deletions(-) delete mode 100644 zssp/src/counter.rs diff --git a/zssp/src/applicationlayer.rs b/zssp/src/applicationlayer.rs index 80810944b..af6dab6a6 100644 --- a/zssp/src/applicationlayer.rs +++ b/zssp/src/applicationlayer.rs @@ -64,19 +64,9 @@ pub trait ApplicationLayer: Sized { /// Check whether a new session should be accepted. /// - /// On success a tuple of local session ID, psk, and associated object is returned. - /// Set psk to all zeros if one is not in use with the remote party. - /// - /// When `accept_new_session` is called, `remote_static_public` and `remote_metadata` have not yet been - /// authenticated. As such avoid mutating state until OkNewSession(Session) is returned, as the connection - /// may be adversarial. - /// - /// When `remote_static_public` and `remote_metadata` are eventually authenticated, the zssp protocol cannot - /// guarantee that they are unique, i.e. `remote_static_public` and `remote_metadata` may be duplicates from - /// an old attempt to establish a session, and may even have been replayed by an adversary. If your use-case - /// needs uniqueness for reliability or security, consider either including a timestamp in the metadata, or - /// sending the metadata as an extra transport packet after the session is fully established. - /// It is guaranteed they will be unique for at least the lifetime of the returned session. + /// On success a tuple of local session ID, static secret, and associated object is returned. The + /// static secret is whatever results from agreement between the local and remote static public + /// keys. fn accept_new_session( &self, receive_context: &ReceiveContext, diff --git a/zssp/src/constants.rs b/zssp/src/constants.rs index fc3b6600b..20bd3b345 100644 --- a/zssp/src/constants.rs +++ b/zssp/src/constants.rs @@ -27,20 +27,15 @@ pub(crate) const MAX_FRAGMENTS: usize = 48; // hard protocol max: 63 pub(crate) const KEY_EXCHANGE_MAX_FRAGMENTS: usize = 2; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc. /// Start attempting to rekey after a key has been used to send packets this many times. -/// -/// This is 1/4 the NIST recommended maximum and 1/8 the absolute limit where u32 wraps. -/// As such it should leave plenty of margin against nearing key reuse bounds w/AES-GCM. +/// This is 1/4 the recommended NIST limit for AES-GCM key lifetimes under most conditions. pub(crate) const REKEY_AFTER_USES: u64 = 536870912; -/// Maximum random jitter to add to rekey-after usage count. -pub(crate) const REKEY_AFTER_USES_MAX_JITTER: u32 = 1048576; - /// Hard expiration after this many uses. /// /// Use of the key beyond this point is prohibited. If we reach this number of key uses /// the key will be destroyed in memory and the session will cease to function. A hard /// error is also generated. -pub(crate) const EXPIRE_AFTER_USES: u64 = (u32::MAX - 1024) as u64; +pub(crate) const EXPIRE_AFTER_USES: u64 = REKEY_AFTER_USES * 2; /// Start attempting to rekey after a key has been in use for this many milliseconds. pub(crate) const REKEY_AFTER_TIME_MS: i64 = 1000 * 60 * 60; // 1 hour @@ -75,12 +70,13 @@ pub(crate) const HMAC_SIZE: usize = 48; pub(crate) const SESSION_ID_SIZE: usize = 6; /// Maximum difference between out-of-order incoming packet counters, and size of deduplication buffer. -pub(crate) const COUNTER_MAX_ALLOWED_OOO: usize = 16; +pub(crate) const COUNTER_MAX_DELTA: usize = 16; // Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use pub(crate) const PACKET_TYPE_DATA: u8 = 0; -pub(crate) const PACKET_TYPE_INITIAL_KEY_OFFER: u8 = 1; // "alice" -pub(crate) const PACKET_TYPE_KEY_COUNTER_OFFER: u8 = 2; // "bob" +pub(crate) const PACKET_TYPE_NOP: u8 = 1; +pub(crate) const PACKET_TYPE_INITIAL_KEY_OFFER: u8 = 2; // "alice" +pub(crate) const PACKET_TYPE_KEY_COUNTER_OFFER: u8 = 3; // "bob" // Key usage labels for sub-key derivation using NIST-style KBKDF (basically just HMAC KDF). pub(crate) const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; // HMAC-SHA384 authentication for key exchanges diff --git a/zssp/src/counter.rs b/zssp/src/counter.rs deleted file mode 100644 index 2eaa558cc..000000000 --- a/zssp/src/counter.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::sync::atomic::{Ordering, AtomicU32}; - -use crate::constants::COUNTER_MAX_ALLOWED_OOO; - -/// Outgoing packet counter with strictly ordered atomic semantics. -/// Count sequence always starts at 1u32, it must never be allowed to overflow -/// -#[repr(transparent)] -pub(crate) struct Counter(AtomicU32); - -impl Counter { - #[inline(always)] - pub fn new() -> Self { - // Using a random value has no security implication. Zero would be fine. This just - // helps randomize packet contents a bit. - Self(AtomicU32::new(1u32)) - } - - #[inline(always)] - pub fn reset_for_new_key_offer(&self) { - self.0.store(1u32, Ordering::SeqCst); - } - - /// Get the value most recently used to send a packet. - #[inline(always)] - pub fn current(&self) -> CounterValue { - CounterValue(self.0.load(Ordering::SeqCst)) - } - - /// Get a counter value for the next packet being sent. - #[inline(always)] - pub fn next(&self) -> CounterValue { - CounterValue(self.0.fetch_add(1, Ordering::SeqCst)) - } -} - -/// A value of the outgoing packet counter. -#[repr(transparent)] -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub(crate) struct CounterValue(u32); - -impl CounterValue { - /// Get the 32-bit counter value used to build packets. - #[inline(always)] - pub fn to_u32(&self) -> u32 { - self.0 as u32 - } -} - -/// Incoming packet deduplication and replay protection window. -pub(crate) struct CounterWindow([AtomicU32; COUNTER_MAX_ALLOWED_OOO]); - -impl CounterWindow { - #[inline(always)] - pub fn new() -> Self { - Self(std::array::from_fn(|_| AtomicU32::new(0))) - } - ///this creates a counter window that rejects everything - pub fn new_invalid() -> Self { - Self(std::array::from_fn(|_| AtomicU32::new(u32::MAX))) - } - pub fn reset_for_new_key_offer(&self) { - for i in 0..COUNTER_MAX_ALLOWED_OOO { - self.0[i].store(0, Ordering::SeqCst) - } - } - pub fn invalidate(&self) { - for i in 0..COUNTER_MAX_ALLOWED_OOO { - self.0[i].store(u32::MAX, Ordering::SeqCst) - } - } - - #[inline(always)] - pub fn message_received(&self, received_counter_value: u32) -> bool { - let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize; - //it is highly likely this can be a Relaxed ordering, but I want someone else to confirm that is the case - let pre = self.0[idx].load(Ordering::SeqCst); - return pre < received_counter_value; - } - - #[inline(always)] - pub fn message_authenticated(&self, received_counter_value: u32) -> bool { - //if a valid message is received but one of its fragments was lost, it can technically be replayed. However since the message is incomplete, we know it still exists in the gather array, so the gather array will deduplicate the replayed message. Even if the gather array gets flushed, that flush still effectively deduplicates the replayed message. - //eventually the counter of that kind of message will be too OOO to be accepted anymore so it can't be used to DOS. - let idx = (received_counter_value % COUNTER_MAX_ALLOWED_OOO as u32) as usize; - return self.0[idx].fetch_max(received_counter_value, Ordering::SeqCst) < received_counter_value; - } -} diff --git a/zssp/src/error.rs b/zssp/src/error.rs index 70f1f42f6..20ab254cc 100644 --- a/zssp/src/error.rs +++ b/zssp/src/error.rs @@ -1,43 +1,39 @@ use crate::sessionid::SessionId; pub enum Error { - /// The packet was addressed to an unrecognized local session (should usually be ignored). + /// The packet was addressed to an unrecognized local session (should usually be ignored) UnknownLocalSessionId(SessionId), - /// Packet was not well formed. + /// Packet was not well formed InvalidPacket, - /// An invalid parameter was supplied to the function. + /// An invalid parameter was supplied to the function InvalidParameter, - /// Packet failed one or more authentication (MAC) checks. - /// - /// **IMPORTANT**: Do not reply to a peer who has sent a packet that has failed authentication. - /// Any response at all will leak to an attacker what authentication step their packet failed at - /// (timing attack), which lowers the total authentication entropy they have to brute force. - /// There is a safe way to reply if absolutely necessary; by sending the reply back after a constant - /// amount of time, but this is very difficult to get correct. + /// Packet failed one or more authentication (MAC) checks + /// IMPORTANT: Do not reply to a peer who has sent a packet that has failed authentication. Any response at all will leak to an attacker what authentication step their packet failed at (timing attack), which lowers the total authentication entropy they have to brute force. + /// There is a safe way to reply if absolutely necessary, by sending the reply back after a constant amount of time, but this is difficult to get correct. FailedAuthentication, /// New session was rejected by the application layer. NewSessionRejected, - /// Rekeying failed and session secret has reached its hard usage count limit. + /// Rekeying failed and session secret has reached its hard usage count limit MaxKeyLifetimeExceeded, - /// Attempt to send using session without established key. + /// Attempt to send using session without established key SessionNotEstablished, /// Packet ignored by rate limiter. RateLimited, - /// The other peer specified an unrecognized protocol version. + /// The other peer specified an unrecognized protocol version UnknownProtocolVersion, - /// Caller supplied data buffer is too small to receive data. + /// Caller supplied data buffer is too small to receive data DataBufferTooSmall, - /// Data object is too large to send, even with fragmentation. + /// Data object is too large to send, even with fragmentation DataTooLarge, /// An unexpected buffer overrun occured while attempting to encode or decode a packet. diff --git a/zssp/src/lib.rs b/zssp/src/lib.rs index ed3068740..46e634cfa 100644 --- a/zssp/src/lib.rs +++ b/zssp/src/lib.rs @@ -1,5 +1,4 @@ mod applicationlayer; -mod counter; mod error; mod sessionid; mod tests; diff --git a/zssp/src/tests.rs b/zssp/src/tests.rs index d0f65efbc..18e1c772a 100644 --- a/zssp/src/tests.rs +++ b/zssp/src/tests.rs @@ -9,7 +9,6 @@ mod tests { use zerotier_crypto::secret::Secret; use zerotier_utils::hex; - use crate::counter::CounterWindow; use crate::*; use constants::*; @@ -17,7 +16,7 @@ mod tests { local_s: P384KeyPair, local_s_hash: [u8; 48], psk: Secret<64>, - session: Mutex>>>, + session: Mutex>>>>, session_id_counter: Mutex, queue: Mutex>>, key_id: Mutex<[u8; 16]>, @@ -43,9 +42,9 @@ mod tests { } } - impl ApplicationLayer for TestHost { + impl ApplicationLayer for Box { type Data = u32; - type SessionRef<'a> = Arc>; + type SessionRef<'a> = Arc>>; type IncomingPacketBuffer = Vec; type RemoteAddress = u32; @@ -98,10 +97,10 @@ mod tests { let mut psk: Secret<64> = Secret::default(); random::fill_bytes_secure(&mut psk.0); - let alice_host = TestHost::new(psk.clone(), "alice", "bob"); - let bob_host = TestHost::new(psk.clone(), "bob", "alice"); - let alice_rc: ReceiveContext = ReceiveContext::new(&alice_host); - let bob_rc: ReceiveContext = ReceiveContext::new(&bob_host); + let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob")); + let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice")); + let alice_rc: Box>> = Box::new(ReceiveContext::new(&alice_host)); + let bob_rc: Box>> = Box::new(ReceiveContext::new(&bob_host)); //println!("zssp: size of session (bytes): {}", std::mem::size_of::>>()); @@ -195,8 +194,8 @@ mod tests { "zssp: new key at {}: fingerprint {} ratchet {} kyber {}", host.this_name, hex::to_string(key_id.as_ref()), - security_info.2, - security_info.3 + security_info.1, + security_info.2 ); } } @@ -209,7 +208,7 @@ mod tests { ) .is_ok()); } - if (test_loop % 8) == 0 && test_loop >= 8 { + if (test_loop % 8) == 0 && test_loop >= 8 && host.this_name.eq("alice") { session.service(host, send_to_other, &[], mtu_buffer.len(), test_loop as i64, true); } } @@ -217,54 +216,4 @@ mod tests { } } } - - - #[inline(always)] - pub fn xorshift64(x: &mut u64) -> u32 { - *x ^= x.wrapping_shl(13); - *x ^= x.wrapping_shr(7); - *x ^= x.wrapping_shl(17); - *x as u32 - } - #[test] - fn counter_window() { - let mut rng = 844632; - let mut counter = 1u32; - let mut history = Vec::new(); - - let w = CounterWindow::new(); - for _i in 0..1000000 { - let p = xorshift64(&mut rng) as f32/(u32::MAX as f32 + 1.0); - let c; - if p < 0.5 { - let r = xorshift64(&mut rng); - c = counter + (r%(COUNTER_MAX_ALLOWED_OOO - 1) as u32 + 1); - } else if p < 0.8 { - counter = counter + (1); - c = counter; - } else if p < 0.9 { - if history.len() > 0 { - let idx = xorshift64(&mut rng) as usize%history.len(); - let c = history[idx]; - assert!(!w.message_authenticated(c)); - } - continue; - } else if p < 0.999 { - c = xorshift64(&mut rng); - w.message_received(c); - continue; - } else { - w.reset_for_new_key_offer(); - counter = 1u32; - history = Vec::new(); - continue; - } - if history.contains(&c) { - assert!(!w.message_authenticated(c)); - } else { - assert!(w.message_authenticated(c)); - history.push(c); - } - } - } } diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 5c0b8d804..a97c4673a 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -3,7 +3,7 @@ // ZSSP: ZeroTier Secure Session Protocol // FIPS compliant Noise_IK with Jedi powers and built-in attack-resistant large payload (fragmentation) support. -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::{Mutex, RwLock}; use zerotier_crypto::aes::{Aes, AesGcm}; @@ -21,7 +21,6 @@ use zerotier_utils::varint; use crate::applicationlayer::ApplicationLayer; use crate::constants::*; -use crate::counter::{Counter, CounterValue, CounterWindow}; use crate::error::Error; use crate::sessionid::SessionId; @@ -40,10 +39,7 @@ pub enum ReceiveResult<'a, H: ApplicationLayer> { /// The session will have already been gated by the accept_new_session() method in ApplicationLayer. OkNewSession(Session), - /// Packet superficially appears valid but was ignored e.g. as a duplicate. - /// - /// **IMPORTANT**: This packet was not authenticated, so for the most part treat this the same - /// as an `Error::FailedAuthentication`. + /// Packet appears valid but was ignored e.g. as a duplicate. Ignored, } @@ -51,9 +47,6 @@ pub enum ReceiveResult<'a, H: ApplicationLayer> { /// /// Note that the role can switch through the course of a session. It's the side that most recently /// initiated a session or a rekey event. Initiator is Alice, responder is Bob. -/// -/// We require that after every rekeying event Alice and Bob switch roles. -#[derive(PartialEq, Eq)] pub enum Role { Alice, Bob, @@ -65,7 +58,7 @@ pub enum Role { /// existing session, which would be new attempts to create sessions. Typically one of these is associated /// with a single listen socket, local bound port, or other inbound endpoint. pub struct ReceiveContext { - initial_offer_defrag: Mutex, 1024, 128>>, + initial_offer_defrag: Mutex, 1024, 128>>, incoming_init_header_check_cipher: Aes, } @@ -77,52 +70,49 @@ pub struct Session { /// An arbitrary application defined object associated with each session pub application_data: Application::Data, - ratchet_counts: [AtomicU64; 2], // Number of session keys in ratchet, starts at 1 - header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) - receive_windows: [CounterWindow; 2], // Receive window for anti-replay and deduplication - state: RwLock, // Mutable parts of state (other than defrag buffers) + send_counter: AtomicU64, // Outgoing packet counter and nonce state + receive_window: [AtomicU64; COUNTER_MAX_DELTA], // Receive window for anti-replay and deduplication psk: Secret<64>, // Arbitrary PSK provided by external code noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key + header_check_cipher: Aes, // Cipher used for header check codes (not Noise related) + state: RwLock, // Mutable parts of state (other than defrag buffers) remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob) remote_s_public_p384_bytes: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key - defrag: Mutex, 8, 8>>, + defrag: Mutex, 8, 8>>, } struct SessionMutableState { remote_session_id: Option, // The other side's 48-bit session ID - send_counters: [Counter; 2], // Outgoing packet counter and nonce state, starts at 1 - session_keys: [Option; 2], // Buffers to store current and previous session key - cur_session_key_id: bool, // Pointer used for keys[] circular buffer + session_keys: [Option; 2], // Buffers to store last and latest key by 1-bit key index + cur_session_key_idx: usize, // Pointer to latest session key other side is confirmed to have offer: Option, // Most recent ephemeral offer sent to remote last_remote_offer: i64, // Time of most recent ephemeral offer (ms) } /// A shared symmetric session key. struct SessionKey { + ratchet_count: u64, // Number of preceding session keys in ratchet + rekey_at_time: i64, // Rekey at or after this time (ticks) + rekey_at_counter: u64, // Rekey at or after this counter + expire_at_counter: u64, // Hard error when this counter value is reached or exceeded secret_fingerprint: [u8; 16], // First 128 bits of a SHA384 computed from the secret - creation_time: i64, // Time session key was established - lifetime: KeyLifetime, // Key expiration time and counter ratchet_key: Secret<64>, // Ratchet key for deriving the next session key receive_key: Secret, // Receive side AES-GCM key send_key: Secret, // Send side AES-GCM key receive_cipher_pool: Mutex>>, // Pool of reusable sending ciphers send_cipher_pool: Mutex>>, // Pool of reusable receiving ciphers + rekey_needed: AtomicBool, // We have reached or exceeded the counter + confirmed: bool, // We have confirmed that the other side has this key jedi: bool, // True if Kyber1024 was used (both sides enabled) - role: Role, // The role of the local party that created this key -} - -/// Key lifetime state -struct KeyLifetime { - rekey_at_or_after_counter: u32, - rekey_at_or_after_timestamp: i64, } /// Alice's KEY_OFFER, remembered so Noise agreement process can resume on KEY_COUNTER_OFFER. struct EphemeralOffer { id: [u8; 16], // Arbitrary random offer ID - key_id: bool, // The key_id bound to this offer, for handling OOO rekeying creation_time: i64, // Local time when offer was created + ratchet_count: u64, // Ratchet count (starting at zero) for initial offer + ratchet_key: Option>, // Ratchet key from previous offer or None if first offer ss_key: Secret<64>, // Noise session key "under construction" at state after offer sent alice_e_keypair: P384KeyPair, // NIST P-384 key pair (Noise ephemeral key for Alice) alice_hk_keypair: Option, // Kyber1024 key pair (PQ hybrid ephemeral key for Alice) @@ -163,16 +153,13 @@ impl Session { let bob_s_public_blob = remote_s_public_blob; if let Some(bob_s_public) = Application::extract_s_public_from_raw(bob_s_public_blob) { if let Some(noise_ss) = app.get_local_s_keypair().agree(&bob_s_public) { - let send_counter = Counter::new(); let bob_s_public_blob_hash = SHA384::hash(bob_s_public_blob); let header_check_cipher = Aes::new(kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::()); let mut offer = None; if send_ephemeral_offer( &mut send, - send_counter.next(), - false, - false, + 1, local_session_id, None, app.get_local_s_public_blob(), @@ -181,7 +168,6 @@ impl Session { &bob_s_public_blob_hash, &noise_ss, None, - 1, None, mtu, current_time, @@ -192,19 +178,18 @@ impl Session { return Ok(Self { id: local_session_id, application_data, - ratchet_counts: [AtomicU64::new(1), AtomicU64::new(0)], + send_counter: AtomicU64::new(2), // 1 was used above + receive_window: std::array::from_fn(|_| AtomicU64::new(0)), + psk: psk.clone(), + noise_ss, header_check_cipher, - receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], state: RwLock::new(SessionMutableState { - send_counters: [send_counter, Counter::new()], remote_session_id: None, session_keys: [None, None], - cur_session_key_id: false, + cur_session_key_idx: 0, offer, last_remote_offer: i64::MIN, }), - psk: psk.clone(), - noise_ss, remote_s_public_blob_hash: bob_s_public_blob_hash, remote_s_public_p384_bytes: bob_s_public.as_bytes().clone(), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), @@ -230,14 +215,12 @@ impl Session { debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); let state = self.state.read().unwrap(); if let Some(remote_session_id) = state.remote_session_id { - let key_id = state.cur_session_key_id; - if let Some(session_key) = state.session_keys[key_id as usize].as_ref() { + if let Some(session_key) = state.session_keys[state.cur_session_key_idx].as_ref() { // Total size of the armored packet we are going to send (may end up being fragmented) let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; - //key ratchet count to be used for salting - let ratchet_count = self.ratchet_counts[key_id as usize].load(Ordering::Relaxed); + // This outgoing packet's nonce counter value. - let counter = state.send_counters[key_id as usize].next(); + let counter = self.send_counter.fetch_add(1, Ordering::SeqCst); //////////////////////////////////////////////////////////////// // packet encoding for post-noise transport @@ -250,14 +233,14 @@ impl Session { mtu_sized_buffer.len(), PACKET_TYPE_DATA, remote_session_id.into(), + session_key.ratchet_count, counter, - key_id, )?; // Get an initialized AES-GCM cipher and re-initialize with a 96-bit IV built from remote session ID, // packet type, and counter. let mut c = session_key.get_send_cipher(counter)?; - c.reset_init_gcm(CanonicalHeader::make(remote_session_id, PACKET_TYPE_DATA, counter.to_u32()).as_bytes()); + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter)); // Send first N-1 fragments of N total fragments. let last_fragment_size; @@ -270,7 +253,7 @@ impl Session { let fragment_size = fragment_data_size + HEADER_SIZE; c.crypt(&data[..fragment_data_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]); data = &data[fragment_data_size..]; - set_header_check_code(mtu_sized_buffer, ratchet_count, &self.header_check_cipher); + //set_header_check_code(mtu_sized_buffer, &self.header_check_cipher); send(&mut mtu_sized_buffer[..fragment_size]); debug_assert!(header[15].wrapping_shr(2) < 63); @@ -291,7 +274,7 @@ impl Session { c.crypt(data, &mut mtu_sized_buffer[HEADER_SIZE..payload_end]); let gcm_tag = c.finish_encrypt(); mtu_sized_buffer[payload_end..last_fragment_size].copy_from_slice(&gcm_tag); - set_header_check_code(mtu_sized_buffer, ratchet_count, &self.header_check_cipher); + //set_header_check_code(mtu_sized_buffer, &self.header_check_cipher); send(&mut mtu_sized_buffer[..last_fragment_size]); // Check reusable AES-GCM instance back into pool. @@ -310,26 +293,15 @@ impl Session { /// Check whether this session is established. pub fn established(&self) -> bool { let state = self.state.read().unwrap(); - state.remote_session_id.is_some() && state.session_keys[state.cur_session_key_id as usize].is_some() + state.remote_session_id.is_some() && state.session_keys[state.cur_session_key_idx].is_some() } - /// Get information about this session's security state. - /// - /// This returns a tuple of: the key fingerprint, the time it was established, the length of its ratchet chain, - /// and whether Kyber1024 was used. None is returned if the session isn't established. - pub fn status(&self) -> Option<([u8; 16], i64, u64, bool)> { + /// Get the shared key fingerprint, ratchet count, and whether Kyber was used, or None if not yet established. + pub fn status(&self) -> Option<([u8; 16], u64, bool)> { let state = self.state.read().unwrap(); - let key_id = state.cur_session_key_id; - if let Some(key) = state.session_keys[key_id as usize].as_ref() { - Some(( - key.secret_fingerprint, - key.creation_time, - self.ratchet_counts[key_id as usize].load(Ordering::Relaxed), - key.jedi, - )) - } else { - None - } + state.session_keys[state.cur_session_key_idx] + .as_ref() + .map(|k| (k.secret_fingerprint, k.ratchet_count, k.jedi)) } /// This function needs to be called on each session at least every SERVICE_INTERVAL milliseconds. @@ -339,7 +311,7 @@ impl Session { /// * `offer_metadata' - Any meta-data to include with initial key offers sent. /// * `mtu` - Current physical transport MTU /// * `current_time` - Current monotonic time in milliseconds - /// * `assume_key_is_too_old` - Re-key the session now if the protocol allows it (subject to rate limits and whether it is the local party's turn to rekey) + /// * `force_expire` - Re-key now regardless of key aging (if it is our turn!) pub fn service( &self, app: &Application, @@ -347,35 +319,23 @@ impl Session { offer_metadata: &[u8], mtu: usize, current_time: i64, - assume_key_is_too_old: bool, + force_expire: bool, ) { let state = self.state.read().unwrap(); - let current_key_id = state.cur_session_key_id; - if (state.session_keys[state.cur_session_key_id as usize].as_ref().map_or(true, |key| { - key.role == Role::Bob - && (assume_key_is_too_old - || key - .lifetime - .should_rekey(state.send_counters[current_key_id as usize].current(), current_time)) - })) && state - .offer - .as_ref() - .map_or(true, |o| (current_time - o.creation_time) > Application::REKEY_RATE_LIMIT_MS) + if (force_expire + || state.session_keys[state.cur_session_key_idx] + .as_ref() + .map_or(true, |k| k.rekey_needed.load(Ordering::Relaxed) || current_time < k.rekey_at_time)) + && state + .offer + .as_ref() + .map_or(true, |o| (current_time - o.creation_time) > Application::REKEY_RATE_LIMIT_MS) { if let Some(remote_s_public) = P384PublicKey::from_bytes(&self.remote_s_public_p384_bytes) { - // This routine handles sending a rekeying packet, resending lost rekeying packets, and resending lost initial offer packets - // The protocol has been designed such that initial rekeying packets are identical to resent rekeying packets, except for the counter, so we can reuse the same code for doing both - let has_existing_session = state.remote_session_id.is_some(); - // Mark the previous key as no longer being supported because it is about to be overwritten - // It should not be possible for a session to accidentally invalidate the key currently in use solely because of the read lock - self.receive_windows[(!current_key_id) as usize].invalidate(); let mut offer = None; - // The session will keep sending ephemeral offers until rekeying is successful if send_ephemeral_offer( &mut send, - state.send_counters[current_key_id as usize].next(), - current_key_id, - has_existing_session && !current_key_id, + self.send_counter.fetch_add(1, Ordering::SeqCst), self.id, state.remote_session_id, app.get_local_s_public_blob(), @@ -383,9 +343,8 @@ impl Session { &remote_s_public, &self.remote_s_public_blob_hash, &self.noise_ss, - state.session_keys[current_key_id as usize].as_ref(), - self.ratchet_counts[current_key_id as usize].load(Ordering::Relaxed), - if has_existing_session { + state.session_keys[state.cur_session_key_idx].as_ref(), + if state.remote_session_id.is_some() { Some(&self.header_check_cipher) } else { None @@ -397,11 +356,24 @@ impl Session { .is_ok() { drop(state); - self.state.write().unwrap().offer = offer; + let _ = self.state.write().unwrap().offer.replace(offer.unwrap()); } } } } + + /// Check the receive window without mutating state. + #[inline(always)] + fn check_receive_window(&self, counter: u64) -> bool { + self.receive_window[(counter as usize) % COUNTER_MAX_DELTA].load(Ordering::Acquire) < counter + } + + /// Update the receive window, returning true if the packet is still valid. + /// This should only be called after the packet is authenticated. + #[inline(always)] + fn update_receive_window(&self, counter: u64) -> bool { + self.receive_window[(counter as usize) % COUNTER_MAX_DELTA].fetch_max(counter, Ordering::AcqRel) < counter + } } impl ReceiveContext { @@ -439,123 +411,100 @@ impl ReceiveContext { return Err(Error::InvalidPacket); } - let raw_counter = u32::from_le(memory::load_raw(&incoming_packet[4..8])); - let key_id = (raw_counter & 1) > 0; - let counter = raw_counter.wrapping_shr(1); - let packet_type_fragment_info = u16::from_le(memory::load_raw(&incoming_packet[14..16])); - let packet_type = (packet_type_fragment_info & 0x0f) as u8; - let fragment_count = ((packet_type_fragment_info.wrapping_shr(4) + 1) as u8) & 63; - let fragment_no = packet_type_fragment_info.wrapping_shr(10) as u8; // & 63 not needed + let raw_counter_and_key_index = u64::from_le(memory::load_raw(&incoming_packet[0..8])); + let counter = raw_counter_and_key_index.wrapping_shr(1); + let key_index = (raw_counter_and_key_index & 1) as usize; + let raw_session_id_type_and_fragment_info = u64::from_le(memory::load_raw(&incoming_packet[8..16])); + let local_session_id = SessionId::new_from_u64(raw_session_id_type_and_fragment_info & 0xffffffffffffu64); + let packet_type = (raw_session_id_type_and_fragment_info.wrapping_shr(48) & 0x0f) as u8; + let fragment_count = ((raw_session_id_type_and_fragment_info.wrapping_shr(48 + 4) & 63) + 1) as u8; + let fragment_no = raw_session_id_type_and_fragment_info.wrapping_shr(48 + 10) as u8; // & 63 not needed - if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64) - { + if let Some(local_session_id) = local_session_id { if let Some(session) = app.lookup_session(local_session_id) { - // This is the only time ratchet_counts is ever accessed outside of a lock - // As such this read can be wrong, but that is incredibly unlikely since we are tracking the last two ratchet counts, and if it's wrong it just means we drop a packet that would have been dropped anyways for being too old or too new - let ratchet_count = session.ratchet_counts[key_id as usize].load(Ordering::SeqCst); - if verify_header_check_code(incoming_packet, ratchet_count, &session.header_check_cipher) { - if session.receive_windows[key_id as usize].message_received(counter) { - let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); - if fragment_count > 1 { - if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { - let mut defrag = session.defrag.lock().unwrap(); - // by using the counter + the key_id as the key we can prevent packet collisions, this only works if defrag hashes - let fragment_gather_array = defrag.get_or_create_mut(&raw_counter, || GatherArray::new(fragment_count)); - if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { - drop(defrag); // release lock - return self.receive_complete( - app, - remote_address, - &mut send, - data_buf, - counter, - key_id, - canonical_header.as_bytes(), - assembled_packet.as_ref(), - packet_type, - Some(session), - mtu, - current_time, - ); - } - } else { - unlikely_branch(); - return Err(Error::InvalidPacket); + if session.check_receive_window(counter) { + if fragment_count > 1 { + if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { + let mut defrag = session.defrag.lock().unwrap(); + let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); + if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { + drop(defrag); // release lock + return self.receive_complete( + app, + remote_address, + &mut send, + data_buf, + counter, + assembled_packet.as_ref(), + packet_type, + Some(session), + key_index, + mtu, + current_time, + ); } } else { - return self.receive_complete( - app, - remote_address, - &mut send, - data_buf, - counter, - key_id, - canonical_header.as_bytes(), - &[incoming_packet_buf], - packet_type, - Some(session), - mtu, - current_time, - ); + unlikely_branch(); + return Err(Error::InvalidPacket); } } else { - unlikely_branch(); - return Ok(ReceiveResult::Ignored); - } - } else { - unlikely_branch(); - return Err(Error::FailedAuthentication); - } - } else { - unlikely_branch(); - return Err(Error::UnknownLocalSessionId(local_session_id)); - } - } else { - unlikely_branch(); // We want data receive to be the priority branch, this is only occasionally used - // Salt with a known value so new sessions can be established - // NOTE: this check is trivial to bypass by just replaying recorded packets - // This check isn't security critical so that is fine - if verify_header_check_code(incoming_packet, 1u64, &self.incoming_init_header_check_cipher) { - let canonical_header = CanonicalHeader::make(SessionId::NIL, packet_type, counter); - if fragment_count > 1 { - let mut defrag = self.initial_offer_defrag.lock().unwrap(); - let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); - if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { - drop(defrag); // release lock return self.receive_complete( app, remote_address, &mut send, data_buf, counter, - key_id, - canonical_header.as_bytes(), - assembled_packet.as_ref(), + &[incoming_packet_buf], packet_type, - None, + Some(session), + key_index, mtu, current_time, ); } } else { + unlikely_branch(); + return Ok(ReceiveResult::Ignored); + } + } else { + unlikely_branch(); + return Err(Error::UnknownLocalSessionId(local_session_id)); + } + } else { + unlikely_branch(); // we want data receive to be the priority branch, this is only occasionally used + if fragment_count > 1 { + let mut defrag = self.initial_offer_defrag.lock().unwrap(); + let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); + if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { + drop(defrag); // release lock return self.receive_complete( app, remote_address, &mut send, data_buf, counter, - key_id, - canonical_header.as_bytes(), - &[incoming_packet_buf], + assembled_packet.as_ref(), packet_type, None, + key_index, mtu, current_time, ); } } else { - unlikely_branch(); - return Err(Error::FailedAuthentication); + return self.receive_complete( + app, + remote_address, + &mut send, + data_buf, + counter, + &[incoming_packet_buf], + packet_type, + None, + key_index, + mtu, + current_time, + ); } }; @@ -572,26 +521,33 @@ impl ReceiveContext { remote_address: &Application::RemoteAddress, send: &mut SendFunction, data_buf: &'a mut [u8], - counter: u32, - key_id: bool, - canonical_header_bytes: &[u8; 12], + counter: u64, fragments: &[Application::IncomingPacketBuffer], packet_type: u8, session: Option>, + key_index: usize, mtu: usize, current_time: i64, ) -> Result, Error> { debug_assert!(fragments.len() >= 1); - if packet_type == PACKET_TYPE_DATA { + + // The first 'if' below should capture both DATA and NOP but not other types. Sanity check this. + debug_assert_eq!(PACKET_TYPE_DATA, 0); + debug_assert_eq!(PACKET_TYPE_NOP, 1); + + let message_nonce = create_message_nonce(packet_type, counter); + + if packet_type <= PACKET_TYPE_NOP { if let Some(session) = session { let state = session.state.read().unwrap(); - if let Some(session_key) = state.session_keys[key_id as usize].as_ref() { - let mut c = session_key.get_receive_cipher(); - c.reset_init_gcm(canonical_header_bytes); + if let Some(session_key) = state.session_keys[key_index].as_ref() { //////////////////////////////////////////////////////////////// // packet decoding for post-noise transport //////////////////////////////////////////////////////////////// + let mut c = session_key.get_receive_cipher(); + c.reset_init_gcm(&message_nonce); + let mut data_len = 0; // Decrypt fragments 0..N-1 where N is the number of fragments. @@ -632,13 +588,38 @@ impl ReceiveContext { session_key.return_receive_cipher(c); if aead_authentication_ok { - if session.receive_windows[key_id as usize].message_authenticated(counter) { - return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); + if session.update_receive_window(counter) { + // If the packet authenticated, this confirms that the other side indeed + // knows this session key. In that case mark the session key as confirmed + // and if the current active key is older switch it to point to this one. + if !session_key.confirmed { + unlikely_branch(); + let this_ratchet_count = session_key.ratchet_count; + drop(state); + let mut state = session.state.write().unwrap(); + + state.session_keys[key_index].as_mut().unwrap().confirmed = true; + if state.cur_session_key_idx != key_index { + if let Some(other_session_key) = state.session_keys[state.cur_session_key_idx].as_ref() { + if other_session_key.ratchet_count < this_ratchet_count { + state.cur_session_key_idx = key_index; + } + } + } + } + + if packet_type == PACKET_TYPE_DATA { + return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); + } else { + unlikely_branch(); + return Ok(ReceiveResult::Ok); + } + } else { + unlikely_branch(); + return Ok(ReceiveResult::Ignored); } } } - - // If no known key authenticated the packet, decryption has failed. return Err(Error::FailedAuthentication); } else { unlikely_branch(); @@ -696,7 +677,7 @@ impl ReceiveContext { if !secure_eq( &hmac_sha384_2( app.get_local_s_public_blob_hash(), - canonical_header_bytes, + &message_nonce, &kex_packet[HEADER_SIZE..hmac1_end], ), &kex_packet[hmac1_end..kex_packet_len], @@ -706,7 +687,7 @@ impl ReceiveContext { // Check rate limits. if let Some(session) = session.as_ref() { - if current_time < session.state.read().unwrap().last_remote_offer + Application::REKEY_RATE_LIMIT_MS { + if (current_time - session.state.read().unwrap().last_remote_offer) < Application::REKEY_RATE_LIMIT_MS { return Err(Error::RateLimited); } } else { @@ -734,7 +715,7 @@ impl ReceiveContext { kbkdf512(noise_ik_incomplete_es.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::(), false, ); - c.reset_init_gcm(canonical_header_bytes); + c.reset_init_gcm(&message_nonce); c.crypt_in_place(&mut kex_packet[plaintext_end..payload_end]); let gcm_tag = &kex_packet[payload_end..aes_gcm_tag_end]; if !c.finish_decrypt(gcm_tag) { @@ -775,7 +756,7 @@ impl ReceiveContext { if !secure_eq( &hmac_sha384_2( kbkdf512(noise_ik_incomplete_es_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), - canonical_header_bytes, + &message_nonce, &kex_packet_saved_ciphertext[HEADER_SIZE..aes_gcm_tag_end], ), &kex_packet[aes_gcm_tag_end..hmac1_end], @@ -787,7 +768,7 @@ impl ReceiveContext { // Perform checks and match ratchet key if there's an existing session, or gate (via host) and // then create new sessions. - let (new_session, reply_counter, new_key_id, ratchet_key) = if let Some(session) = session.as_ref() { + let (new_session, ratchet_key, last_ratchet_count) = if let Some(session) = session.as_ref() { // Existing session identity must match the one in this offer. if !secure_eq(&session.remote_s_public_blob_hash, &SHA384::hash(&alice_s_public_blob)) { return Err(Error::FailedAuthentication); @@ -796,60 +777,51 @@ impl ReceiveContext { // Match ratchet key fingerprint and fail if no match, which likely indicates an old offer packet. let alice_ratchet_key_fingerprint = alice_ratchet_key_fingerprint.unwrap(); let mut ratchet_key = None; + let mut last_ratchet_count = 0; let state = session.state.read().unwrap(); - if let Some(k) = state.session_keys[key_id as usize].as_ref() { - if k.role == Role::Bob { - // The local party is not allowed to be Bob twice in a row - // This prevents rekeying failure from both parties attempting to rekey at the same time - return Ok(ReceiveResult::Ignored); - } - if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { - ratchet_key = Some(k.ratchet_key.clone()); + for k in state.session_keys.iter() { + if let Some(k) = k.as_ref() { + if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { + ratchet_key = Some(k.ratchet_key.clone()); + last_ratchet_count = k.ratchet_count; + break; + } } } if ratchet_key.is_none() { return Ok(ReceiveResult::Ignored); // old packet? } - (None, state.send_counters[key_id as usize].next(), !key_id, ratchet_key) + (None, ratchet_key, last_ratchet_count) } else { - if key_id != false { - // All new sessions must start with key_id 0 - // This has no security implications, it just makes programming the initial offer easier - return Ok(ReceiveResult::Ignored); - } if let Some((new_session_id, psk, associated_object)) = app.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata) { let header_check_cipher = Aes::new( kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::(), ); - let send_counter = Counter::new(); - let reply_counter = send_counter.next(); ( Some(Session:: { id: new_session_id, application_data: associated_object, - ratchet_counts: [AtomicU64::new(1), AtomicU64::new(0)], + receive_window: std::array::from_fn(|_| AtomicU64::new(0)), + send_counter: AtomicU64::new(1), + psk, + noise_ss, header_check_cipher, - receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()], state: RwLock::new(SessionMutableState { - send_counters: [send_counter, Counter::new()], remote_session_id: Some(alice_session_id), - session_keys: [None, None], //this is the only value which will be writen later - cur_session_key_id: false, + session_keys: [None, None], + cur_session_key_idx: 0, offer: None, last_remote_offer: current_time, }), - psk, - noise_ss, remote_s_public_blob_hash: SHA384::hash(&alice_s_public_blob), remote_s_public_p384_bytes: alice_s_public.as_bytes().clone(), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), }), - reply_counter, - false, None, + 0, ) } else { return Err(Error::NewSessionRejected); @@ -860,6 +832,10 @@ impl ReceiveContext { let existing_session = session; let session = existing_session.as_ref().map_or_else(|| new_session.as_ref().unwrap(), |s| &*s); + if !session.update_receive_window(counter) { + return Ok(ReceiveResult::Ignored); + } + // Generate our ephemeral NIST P-384 key pair. let bob_e_keypair = P384KeyPair::generate(); @@ -909,7 +885,10 @@ impl ReceiveContext { // <- e, ee, se //////////////////////////////////////////////////////////////// + let next_ratchet_count = last_ratchet_count + 1; + let mut reply_buf = [0_u8; KEX_BUF_LEN]; + let reply_counter = session.send_counter.fetch_add(1, Ordering::SeqCst); let mut idx = HEADER_SIZE; idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?; @@ -940,11 +919,11 @@ impl ReceiveContext { mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), + next_ratchet_count, reply_counter, - key_id, )?; - let reply_canonical_header = - CanonicalHeader::make(alice_session_id.into(), PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter.to_u32()); + + let reply_message_nonce = create_message_nonce(PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter); // Encrypt reply packet using final Noise_IK key BEFORE mixing hybrid or ratcheting, since the other side // must decrypt before doing these things. @@ -952,7 +931,7 @@ impl ReceiveContext { kbkdf512(noise_ik_complete.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::(), true, ); - c.reset_init_gcm(reply_canonical_header.as_bytes()); + c.reset_init_gcm(&reply_message_nonce); c.crypt_in_place(&mut reply_buf[plaintext_end..payload_end]); let gcm_tag = c.finish_encrypt(); @@ -974,47 +953,35 @@ impl ReceiveContext { // Kyber exchange, but you'd need a not-yet-existing quantum computer for that. let hmac = hmac_sha384_2( kbkdf512(session_key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), - reply_canonical_header.as_bytes(), + &reply_message_nonce, &reply_buf[HEADER_SIZE..aes_gcm_tag_end], ); idx = safe_write_all(&mut reply_buf, idx, &hmac)?; let packet_end = idx; - if session.receive_windows[key_id as usize].message_authenticated(counter) { - // Initial key offers should only check this if this is a rekey - let session_key = SessionKey::new(session_key, Role::Bob, current_time, hybrid_kk.is_some()); - let mut state = session.state.write().unwrap(); - let _ = state.session_keys[new_key_id as usize].replace(session_key); - let _ = state.remote_session_id.replace(alice_session_id); - let ratchet_count = session.ratchet_counts[key_id as usize].load(Ordering::SeqCst); - if state.cur_session_key_id != new_key_id { - //this prevents anything from being reset twice if the key offer was made twice - debug_assert!(new_key_id != key_id); - // receive_windows only has race conditions with the counter of the remote party. It is theoretically possible that the local host receives counters under new_key_id while the receive_window is still in the process of resetting, but this is very unlikely. If it does happen, two things could happen: - // 1) The received counter is less than what is currently stored in the window, so a valid packet would be rejected. - // 2) The received counter is greater than what is currently stored in the window, so a valid packet would be accepted *but* its counter is deleted from the window so it can be replayed. - // To prevent these race conditions, we only update the ratchet_count for salting the check code after the window has reset. So if a counter passes the initial check code: it either means the thread sees ratchet count has been update therefore it sees receive_window has been reset (due to memory orderings), or it means a rare accidental check code collision has occurred. - session.receive_windows[new_key_id as usize].reset_for_new_key_offer(); - session.ratchet_counts[new_key_id as usize].fetch_add(2, Ordering::SeqCst); + let session_key = SessionKey::new( + session_key, + Role::Bob, + current_time, + reply_counter, + next_ratchet_count, + false, // Bob can't know yet if Alice got the counter offer + hybrid_kk.is_some(), + ); - // If the following wasn't done inside a lock, a theoretical race condition exists where a thread uses the new key id before the counter is reset, or worse: a thread has held onto the previous key id equal to new_key_id, and attempts to use the reset counter for encryption. - // For this reason do not access send_counters without holding the read lock. - state.cur_session_key_id = new_key_id; - state.send_counters[new_key_id as usize].reset_for_new_key_offer(); - } - drop(state); + let mut state = session.state.write().unwrap(); + let _ = state.remote_session_id.replace(alice_session_id); + let _ = state.session_keys[(next_ratchet_count as usize) & 1].replace(session_key); + drop(state); - // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. + // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. - send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, ratchet_count, &session.header_check_cipher); + send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher); - if let Some(new_session) = new_session { - return Ok(ReceiveResult::OkNewSession(new_session)); - } else { - return Ok(ReceiveResult::Ok); - } + if new_session.is_some() { + return Ok(ReceiveResult::OkNewSession(new_session.unwrap())); } else { - return Ok(ReceiveResult::Ignored); + return Ok(ReceiveResult::Ok); } } @@ -1056,7 +1023,7 @@ impl ReceiveContext { .first_n::(), false, ); - c.reset_init_gcm(canonical_header_bytes); + c.reset_init_gcm(&message_nonce); c.crypt_in_place(&mut kex_packet[plaintext_end..payload_end]); let gcm_tag = &kex_packet[payload_end..aes_gcm_tag_end]; if !c.finish_decrypt(gcm_tag) { @@ -1069,7 +1036,7 @@ impl ReceiveContext { parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?; // Check that this is a counter offer to the original offer we sent. - if !(offer.id.eq(offer_id)) { + if !offer.id.eq(offer_id) { return Ok(ReceiveResult::Ignored); } @@ -1090,17 +1057,12 @@ impl ReceiveContext { let mut session_key = noise_ik_complete; // Mix ratchet key from previous session key (if any) and Kyber1024 hybrid shared key (if any). - // We either have a session, in which case they should have supplied a ratchet key fingerprint, or - // we don't and they should not have supplied one. - if let Some(cur_session_key) = state.session_keys[key_id as usize].as_ref() { - if bob_ratchet_key_id.is_some() { - session_key = Secret(hmac_sha512(cur_session_key.ratchet_key.as_bytes(), session_key.as_bytes())); - } else { - return Err(Error::FailedAuthentication); - } - } else if bob_ratchet_key_id.is_some() { - return Err(Error::FailedAuthentication); - } + let last_ratchet_count = if bob_ratchet_key_id.is_some() && offer.ratchet_key.is_some() { + session_key = Secret(hmac_sha512(offer.ratchet_key.as_ref().unwrap().as_bytes(), session_key.as_bytes())); + offer.ratchet_count + } else { + 0 + }; if let Some(hybrid_kk) = hybrid_kk.as_ref() { session_key = Secret(hmac_sha512(hybrid_kk.as_bytes(), session_key.as_bytes())); } @@ -1109,41 +1071,65 @@ impl ReceiveContext { if !secure_eq( &hmac_sha384_2( kbkdf512(session_key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), - canonical_header_bytes, + &message_nonce, &kex_packet_saved_ciphertext[HEADER_SIZE..aes_gcm_tag_end], ), &kex_packet[aes_gcm_tag_end..kex_packet_len], ) { return Err(Error::FailedAuthentication); } - if session.receive_windows[key_id as usize].message_authenticated(counter) { - // Alice has now completed and validated the full hybrid exchange. - let session_key = SessionKey::new(session_key, Role::Alice, current_time, hybrid_kk.is_some()); + // Alice has now completed and validated the full hybrid exchange. - let new_key_id = offer.key_id; - drop(state); - let mut state = session.state.write().unwrap(); - let _ = state.remote_session_id.replace(bob_session_id); - let _ = state.session_keys[new_key_id as usize].replace(session_key); - if state.cur_session_key_id != new_key_id { - // When an brand new key offer is sent, it is sent using the new_key_id==false counter, we cannot reset it in that case. - // NOTE: the following code should be properly threadsafe, see the large comment above at the end of KEY_OFFER decoding for more info - session.receive_windows[new_key_id as usize].reset_for_new_key_offer(); - let _ = session.ratchet_counts[new_key_id as usize].fetch_add(2, Ordering::SeqCst); - state.cur_session_key_id = new_key_id; - state.send_counters[new_key_id as usize].reset_for_new_key_offer(); - } - let _ = state.offer.take(); + let reply_counter = session.send_counter.fetch_add(1, Ordering::SeqCst); + let next_ratchet_count = last_ratchet_count + 1; - return Ok(ReceiveResult::Ok); - } else { - return Ok(ReceiveResult::Ignored); - } + let session_key = SessionKey::new( + session_key, + Role::Alice, + current_time, + reply_counter, + next_ratchet_count, + true, // Alice knows Bob got the offer + hybrid_kk.is_some(), + ); + + //////////////////////////////////////////////////////////////// + // packet encoding for post-noise session start ack + //////////////////////////////////////////////////////////////// + + let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; + + create_packet_header( + &mut reply_buf, + HEADER_SIZE + AES_GCM_TAG_SIZE, + mtu, + PACKET_TYPE_NOP, + bob_session_id.into(), + next_ratchet_count, + reply_counter, + )?; + + let mut c = session_key.get_send_cipher(reply_counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, reply_counter)); + let gcm_tag = c.finish_encrypt(); + safe_write_all(&mut reply_buf, HEADER_SIZE, &gcm_tag)?; + session_key.return_send_cipher(c); + + //set_header_check_code(&mut reply_buf, &session.header_check_cipher); + send(&mut reply_buf); + + drop(state); + let mut state = session.state.write().unwrap(); + + let _ = state.remote_session_id.replace(bob_session_id); + let next_key_index = (next_ratchet_count as usize) & 1; + let _ = state.session_keys[next_key_index].replace(session_key); + state.cur_session_key_idx = next_key_index; + let _ = state.offer.take(); + + return Ok(ReceiveResult::Ok); } - } else { - unlikely_branch(); - return Err(Error::SessionNotEstablished); } // Just ignore counter-offers that are out of place. They probably indicate that this side @@ -1158,12 +1144,9 @@ impl ReceiveContext { } /// Create an send an ephemeral offer, populating ret_ephemeral_offer on success. -/// If there is no current session key set `current_key_id == new_key_id == false` fn send_ephemeral_offer( send: &mut SendFunction, - counter: CounterValue, - current_key_id: bool, - new_key_id: bool, + counter: u64, alice_session_id: SessionId, bob_session_id: Option, alice_s_public_blob: &[u8], @@ -1172,7 +1155,6 @@ fn send_ephemeral_offer( bob_s_public_blob_hash: &[u8], noise_ss: &Secret<48>, current_key: Option<&SessionKey>, - ratchet_count: u64, header_check_cipher: Option<&Aes>, // None to use one based on the recipient's public key for initial contact mtu: usize, current_time: i64, @@ -1191,6 +1173,13 @@ fn send_ephemeral_offer( None }; + // Get ratchet key for current key if one exists. + let (ratchet_key, ratchet_count) = if let Some(current_key) = current_key { + (Some(current_key.ratchet_key.clone()), current_key.ratchet_count) + } else { + (None, 0) + }; + // Random ephemeral offer ID let id: [u8; 16] = random::get_bytes_secure(); @@ -1205,6 +1194,7 @@ fn send_ephemeral_offer( let mut idx = HEADER_SIZE; idx = safe_write_all(&mut packet_buf, idx, &[SESSION_PROTOCOL_VERSION])?; + //TODO: check this, the below line is supposed to be the blob, not just the key, right? idx = safe_write_all(&mut packet_buf, idx, alice_e_keypair.public_key_bytes())?; let plaintext_end = idx; @@ -1220,13 +1210,9 @@ fn send_ephemeral_offer( } else { idx = safe_write_all(&mut packet_buf, idx, &[HYBRID_KEY_TYPE_NONE])?; } - if let Some(current_key) = current_key { + if let Some(ratchet_key) = ratchet_key.as_ref() { idx = safe_write_all(&mut packet_buf, idx, &[0x01])?; - idx = safe_write_all( - &mut packet_buf, - idx, - &public_fingerprint_of_secret(current_key.ratchet_key.as_bytes())[..16], - )?; + idx = safe_write_all(&mut packet_buf, idx, &public_fingerprint_of_secret(ratchet_key.as_bytes())[..16])?; } else { idx = safe_write_all(&mut packet_buf, idx, &[0x00])?; } @@ -1245,11 +1231,11 @@ fn send_ephemeral_offer( mtu, PACKET_TYPE_INITIAL_KEY_OFFER, bob_session_id, + ratchet_count, counter, - current_key_id, )?; - let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32()); + let message_nonce = create_message_nonce(PACKET_TYPE_INITIAL_KEY_OFFER, counter); // Encrypt packet and attach AES-GCM tag. let gcm_tag = { @@ -1257,7 +1243,7 @@ fn send_ephemeral_offer( kbkdf512(es_key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::(), true, ); - c.reset_init_gcm(canonical_header.as_bytes()); + c.reset_init_gcm(&message_nonce); c.crypt_in_place(&mut packet_buf[plaintext_end..payload_end]); c.finish_encrypt() }; @@ -1272,37 +1258,33 @@ fn send_ephemeral_offer( // HMAC packet using static + ephemeral key. let hmac1 = hmac_sha384_2( kbkdf512(ss_key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), - canonical_header.as_bytes(), + &message_nonce, &packet_buf[HEADER_SIZE..aes_gcm_tag_end], ); idx = safe_write_all(&mut packet_buf, idx, &hmac1)?; let hmac1_end = idx; // Add secondary HMAC to verify that the caller knows the recipient's full static public identity. - let hmac2 = hmac_sha384_2( - bob_s_public_blob_hash, - canonical_header.as_bytes(), - &packet_buf[HEADER_SIZE..hmac1_end], - ); + let hmac2 = hmac_sha384_2(bob_s_public_blob_hash, &message_nonce, &packet_buf[HEADER_SIZE..hmac1_end]); idx = safe_write_all(&mut packet_buf, idx, &hmac2)?; let packet_end = idx; if let Some(header_check_cipher) = header_check_cipher { - send_with_fragmentation(send, &mut packet_buf[..packet_end], mtu, ratchet_count, header_check_cipher); + send_with_fragmentation(send, &mut packet_buf[..packet_end], mtu, header_check_cipher); } else { send_with_fragmentation( send, &mut packet_buf[..packet_end], mtu, - ratchet_count, &Aes::new(kbkdf512(&bob_s_public_blob_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::()), ); } *ret_ephemeral_offer = Some(EphemeralOffer { id, - key_id: new_key_id, creation_time: current_time, + ratchet_count, + ratchet_key, ss_key, alice_e_keypair, alice_hk_keypair, @@ -1318,8 +1300,8 @@ fn create_packet_header( mtu: usize, packet_type: u8, recipient_session_id: SessionId, - counter: CounterValue, - key_id: bool, + ratchet_count: u64, + counter: u64, ) -> Result<(), Error> { let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize; @@ -1332,16 +1314,15 @@ fn create_packet_header( if fragment_count <= MAX_FRAGMENTS { // Header indexed by bit/byte: - // [0-31]/[0-3] header check code (computed later) - // [32-32]/[4-] key id - // [33-63]/[-7] counter - // [64-111]/[8-13] recipient's session ID (unique on their side) - // [112-115]/[14-] packet type (0-15) - // [116-121]/[-] number of fragments (0..63 for 1..64 fragments total) - // [122-127]/[-15] fragment number (0, 1, 2, ...) + // [0-0] ratchet count least significant bit + // [1-63] counter + // [64-111] recipient's session ID (unique on their side) + // [112-115] packet type (0-15) + // [116-121] number of fragments (0..63 for 1..64 fragments total) + // [122-127] fragment number (0, 1, 2, ...) memory::store_raw( - (counter.to_u32().wrapping_shl(1) | (key_id as u32)).to_le(), - &mut header_destination_buffer[4..], + (counter.wrapping_shl(1) | (ratchet_count & 1)).to_le(), + &mut header_destination_buffer[0..], ); memory::store_raw( (u64::from(recipient_session_id) | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)) @@ -1355,12 +1336,29 @@ fn create_packet_header( } } +#[derive(Clone, Copy)] +#[repr(C, packed)] +struct MessageNonce(u64, u32); + +/// Create a 12-bit AES-GCM nonce. +/// +/// The primary information that we want to be contained here is the counter and the +/// packet type. The former makes this unique and the latter's inclusion authenticates +/// it as effectively AAD. Other elements of the header are either not authenticated, +/// like fragmentation info, or their authentication is implied via key exchange like +/// the session ID. +/// +/// This is also used as part of HMAC authentication for key exchange packets. +#[inline(always)] +fn create_message_nonce(packet_type: u8, counter: u64) -> [u8; 12] { + memory::to_byte_array(MessageNonce(counter.to_le(), (packet_type as u32).to_le())) +} + /// Break a packet into fragments and send them all. fn send_with_fragmentation( send: &mut SendFunction, packet: &mut [u8], mtu: usize, - ratchet_count: u64, header_check_cipher: &Aes, ) { let packet_len = packet.len(); @@ -1369,7 +1367,7 @@ fn send_with_fragmentation( let mut header: [u8; 16] = packet[..HEADER_SIZE].try_into().unwrap(); loop { let fragment = &mut packet[fragment_start..fragment_end]; - set_header_check_code(fragment, ratchet_count, header_check_cipher); + //set_header_check_code(fragment, header_check_cipher); send(fragment); if fragment_end < packet_len { debug_assert!(header[15].wrapping_shr(2) < 63); @@ -1384,33 +1382,6 @@ fn send_with_fragmentation( } } -/// Set 32-bit header check code, used to make fragmentation mechanism robust. -fn set_header_check_code(packet: &mut [u8], ratchet_count: u64, header_check_cipher: &Aes) { - debug_assert!(packet.len() >= MIN_PACKET_SIZE); - //4 bytes is the ratchet key - //12 bytes is the header we want to verify - let mut header_mac = 0u128.to_le_bytes(); - memory::store_raw((ratchet_count as u32).to_le_bytes(), &mut header_mac[0..4]); - header_mac[4..16].copy_from_slice(&packet[4..16]); - header_check_cipher.encrypt_block_in_place(&mut header_mac); - - packet[..4].copy_from_slice(&header_mac[..4]); -} - -/// Verify 32-bit header check code. -/// This is not nearly enough entropy to be cryptographically secure, it only is meant for making DOS attacks very hard -fn verify_header_check_code(packet: &[u8], ratchet_count: u64, header_check_cipher: &Aes) -> bool { - debug_assert!(packet.len() >= MIN_PACKET_SIZE); - //4 bytes is the ratchet key - //12 bytes is the header we want to verify - let mut header_mac = 0u128.to_le_bytes(); - memory::store_raw((ratchet_count as u32).to_le_bytes(), &mut header_mac[0..4]); - header_mac[4..16].copy_from_slice(&packet[4..16]); - header_check_cipher.encrypt_block_in_place(&mut header_mac); - - memory::load_raw::(&packet[..4]) == memory::load_raw::(&header_mac[..4]) -} - /// Parse KEY_OFFER and KEY_COUNTER_OFFER starting after the unencrypted public key part. fn parse_dec_key_offer_after_header( incoming_packet: &[u8], @@ -1421,15 +1392,15 @@ fn parse_dec_key_offer_after_header( let mut session_id_buf = 0_u64.to_ne_bytes(); session_id_buf[..SESSION_ID_SIZE].copy_from_slice(safe_read_exact(&mut p, SESSION_ID_SIZE)?); - let remote_session_id = SessionId::new_from_u64(u64::from_le_bytes(session_id_buf)).ok_or(Error::InvalidPacket)?; + let alice_session_id = SessionId::new_from_u64(u64::from_le_bytes(session_id_buf)).ok_or(Error::InvalidPacket)?; - let remote_s_public_blob_len = varint_safe_read(&mut p)?; - let remote_s_public_blob = safe_read_exact(&mut p, remote_s_public_blob_len as usize)?; + let alice_s_public_blob_len = varint_safe_read(&mut p)?; + let alice_s_public_blob = safe_read_exact(&mut p, alice_s_public_blob_len as usize)?; - let remote_metadata_len = varint_safe_read(&mut p)?; - let remote_metadata = safe_read_exact(&mut p, remote_metadata_len as usize)?; + let alice_metadata_len = varint_safe_read(&mut p)?; + let alice_metadata = safe_read_exact(&mut p, alice_metadata_len as usize)?; - let remote_hk_public_raw = match safe_read_exact(&mut p, 1)?[0] { + let alice_hk_public_raw = match safe_read_exact(&mut p, 1)?[0] { HYBRID_KEY_TYPE_KYBER1024 => { if packet_type == PACKET_TYPE_INITIAL_KEY_OFFER { safe_read_exact(&mut p, pqc_kyber::KYBER_PUBLICKEYBYTES)? @@ -1443,7 +1414,7 @@ fn parse_dec_key_offer_after_header( if p.is_empty() { return Err(Error::InvalidPacket); } - let remote_ratchet_key_fingerprint = if safe_read_exact(&mut p, 1)?[0] == 0x01 { + let alice_ratchet_key_fingerprint = if safe_read_exact(&mut p, 1)?[0] == 0x01 { Some(safe_read_exact(&mut p, 16)?) } else { None @@ -1451,17 +1422,17 @@ fn parse_dec_key_offer_after_header( Ok(( offer_id, //always 16 bytes - remote_session_id, - remote_s_public_blob, - remote_metadata, - remote_hk_public_raw, - remote_ratchet_key_fingerprint, //always 16 bytes + alice_session_id, + alice_s_public_blob, + alice_metadata, + alice_hk_public_raw, + alice_ratchet_key_fingerprint, //always 16 bytes )) } impl SessionKey { /// Create a new symmetric shared session key and set its key expiration times, etc. - fn new(key: Secret<64>, role: Role, current_time: i64, jedi: bool) -> Self { + fn new(key: Secret<64>, role: Role, current_time: i64, current_counter: u64, ratchet_count: u64, confirmed: bool, jedi: bool) -> Self { let a2b: Secret = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n_clone(); let b2a: Secret = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n_clone(); let (receive_key, send_key) = match role { @@ -1469,21 +1440,28 @@ impl SessionKey { Role::Bob => (a2b, b2a), }; Self { + ratchet_count, + rekey_at_time: current_time + .checked_add(REKEY_AFTER_TIME_MS + ((random::xorshift64_random() as u32) % REKEY_AFTER_TIME_MS_MAX_JITTER) as i64) + .unwrap(), + rekey_at_counter: current_counter.checked_add(REKEY_AFTER_USES).unwrap(), + expire_at_counter: current_counter.checked_add(EXPIRE_AFTER_USES).unwrap(), secret_fingerprint: public_fingerprint_of_secret(key.as_bytes())[..16].try_into().unwrap(), - creation_time: current_time, - lifetime: KeyLifetime::new(current_time), ratchet_key: kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_RATCHETING), receive_key, send_key, receive_cipher_pool: Mutex::new(Vec::with_capacity(2)), send_cipher_pool: Mutex::new(Vec::with_capacity(2)), + rekey_needed: AtomicBool::new(false), + confirmed, jedi, - role, } } - fn get_send_cipher(&self, counter: CounterValue) -> Result, Error> { - if !self.lifetime.expired(counter) { + fn get_send_cipher(&self, counter: u64) -> Result, Error> { + if counter < self.expire_at_counter { + self.rekey_needed + .store(counter >= self.rekey_at_counter, std::sync::atomic::Ordering::Relaxed); Ok(self .send_cipher_pool .lock() @@ -1491,6 +1469,8 @@ impl SessionKey { .pop() .unwrap_or_else(|| Box::new(AesGcm::new(self.send_key.as_bytes(), true)))) } else { + unlikely_branch(); + // Not only do we return an error, but we also destroy the key. let mut scp = self.send_cipher_pool.lock().unwrap(); scp.clear(); @@ -1517,39 +1497,6 @@ impl SessionKey { } } -impl KeyLifetime { - fn new(current_time: i64) -> Self { - Self { - rekey_at_or_after_counter: REKEY_AFTER_USES as u32 + (random::next_u32_secure() % REKEY_AFTER_USES_MAX_JITTER), - rekey_at_or_after_timestamp: current_time - + REKEY_AFTER_TIME_MS - + (random::next_u32_secure() % REKEY_AFTER_TIME_MS_MAX_JITTER) as i64, - } - } - - fn should_rekey(&self, counter: CounterValue, current_time: i64) -> bool { - counter.to_u32() >= self.rekey_at_or_after_counter || current_time >= self.rekey_at_or_after_timestamp - } - - fn expired(&self, counter: CounterValue) -> bool { - counter.to_u32() >= EXPIRE_AFTER_USES as u32 - } -} - -impl CanonicalHeader { - pub fn make(session_id: SessionId, packet_type: u8, counter: u32) -> Self { - CanonicalHeader( - (u64::from(session_id) | (packet_type as u64).wrapping_shl(48)).to_le(), - counter.to_le(), - ) - } - - #[inline(always)] - pub fn as_bytes(&self) -> &[u8; 12] { - memory::as_byte_array(self) - } -} - /// Write src into buffer starting at the index idx. If buffer cannot fit src at that location, nothing at all is written and Error::UnexpectedBufferOverrun is returned. No other errors can be returned by this function. An idx incremented by the amount written is returned. fn safe_write_all(buffer: &mut [u8], idx: usize, src: &[u8]) -> Result { let dest = &mut buffer[idx..]; @@ -1599,9 +1546,7 @@ fn hmac_sha384_2(key: &[u8], a: &[u8], b: &[u8]) -> [u8; 48] { } /// HMAC-SHA512 key derivation based on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 12) -/// -/// Cryptographically this isn't meaningfully different from HMAC(key, [label]), -/// but NIST seems to like it this way. +/// Cryptographically this isn't meaningfully different from HMAC(key, [label]) but this is how NIST rolls. fn kbkdf512(key: &[u8], label: u8) -> Secret<64> { Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00])) }