From 826f0d3ab5b85a15b9a1febd4564e880b1963220 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Fri, 16 Dec 2022 10:13:16 -0500 Subject: [PATCH] Some more renaming to make code more readable. --- zssp/src/constants.rs | 4 +- zssp/src/zssp.rs | 157 ++++++++++++++++++++++-------------------- 2 files changed, 83 insertions(+), 78 deletions(-) diff --git a/zssp/src/constants.rs b/zssp/src/constants.rs index 4ed977baa..16f3c039b 100644 --- a/zssp/src/constants.rs +++ b/zssp/src/constants.rs @@ -52,10 +52,10 @@ pub(crate) const REKEY_AFTER_TIME_MS_MAX_JITTER: u32 = 1000 * 60 * 10; // 10 min pub(crate) const SESSION_PROTOCOL_VERSION: u8 = 0x00; /// Secondary key type: none, use only P-384 for forward secrecy. -pub(crate) const E1_TYPE_NONE: u8 = 0; +pub(crate) const HYBRID_KEY_TYPE_NONE: u8 = 0; /// Secondary key type: Kyber1024, PQ forward secrecy enabled. -pub(crate) const E1_TYPE_KYBER1024: u8 = 1; +pub(crate) const HYBRID_KEY_TYPE_KYBER1024: u8 = 1; /// Size of packet header pub(crate) const HEADER_SIZE: usize = 16; diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 1d118e4ba..b10313cb8 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -35,7 +35,7 @@ pub enum Error { /// Packet failed one or more authentication (MAC) checks FailedAuthentication, - /// New session was rejected via Host::check_new_session_attempt or Host::accept_new_session. + /// New session was rejected by the application layer. NewSessionRejected, /// Rekeying failed and session secret has reached its hard usage count limit @@ -56,7 +56,10 @@ pub enum Error { /// Data object is too large to send, even with fragmentation DataTooLarge, - /// An unexpected buffer overrun occured while attempting to encode or decode a packet, this can only ever happen if exceptionally large key blobs or metadata are being used, or as the result of an internal encoding bug. + /// An unexpected buffer overrun occured while attempting to encode or decode a packet. + /// + /// This can only ever happen if exceptionally large key blobs or metadata are being used, + /// or as the result of an internal encoding bug. UnexpectedBufferOverrun, } @@ -99,12 +102,12 @@ pub struct ReceiveContext { } /// ZSSP bi-directional packet transport channel. -pub struct Session { +pub struct Session { /// This side's session ID (unique on this side) pub id: SessionId, /// An arbitrary object associated with session (type defined in Host trait) - pub user_data: Layer::SessionUserData, + pub user_data: Application::SessionUserData, send_counter: Counter, // Outgoing packet counter and nonce state psk: Secret<64>, // Arbitrary PSK provided by external code @@ -114,7 +117,7 @@ pub struct Session { remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob) remote_s_public_raw: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key - defrag: Mutex, 8, 8>>, + defrag: Mutex, 8, 8>>, } struct SessionMutableState { @@ -208,7 +211,6 @@ fn safe_write_all(buffer: &mut [u8], idx: usize, src: &[u8]) -> Result Result { let mut b = [0_u8; varint::VARINT_MAX_SIZE_BYTES]; let i = varint::encode(&mut b, v); @@ -227,7 +229,6 @@ fn safe_read_exact<'a>(src: &mut &'a [u8], amt: usize) -> Result<&'a [u8], Error } } /// Read a variable length integer, which can consume up to 10 bytes. Uses varint_safe_read to do so. -#[inline(always)] fn varint_safe_read(src: &mut &[u8]) -> Result { let (v, amt) = varint::decode(*src).ok_or(Error::InvalidPacket)?; let (_, b) = src.split_at(amt); @@ -235,7 +236,7 @@ fn varint_safe_read(src: &mut &[u8]) -> Result { Ok(v) } -impl Session { +impl Session { /// Create a new session and send an initial key offer message to the other end. /// /// * `host` - Interface to application using ZSSP @@ -247,19 +248,19 @@ impl Session { /// * `mtu` - Physical wire maximum transmition unit /// * `current_time` - Current monotonic time in milliseconds pub fn start_new( - host: &Layer, + app: &Application, mut send: SendFunction, local_session_id: SessionId, remote_s_public_blob: &[u8], offer_metadata: &[u8], psk: &Secret<64>, - user_data: Layer::SessionUserData, + user_data: Application::SessionUserData, mtu: usize, current_time: i64, ) -> Result { let bob_s_public_blob = remote_s_public_blob; - if let Some(bob_s_public) = Layer::extract_s_public_from_raw(bob_s_public_blob) { - if let Some(noise_ss) = host.get_local_s_keypair().agree(&bob_s_public) { + if let Some(bob_s_public) = Application::extract_s_public_from_raw(bob_s_public_blob) { + if let Some(noise_ss) = app.get_local_s_keypair().agree(&bob_s_public) { let send_counter = Counter::new(); let bob_s_public_blob_hash = SHA384::hash(bob_s_public_blob); let header_check_cipher = @@ -270,7 +271,7 @@ impl Session { send_counter.next(), local_session_id, None, - host.get_local_s_public_blob(), + app.get_local_s_public_blob(), offer_metadata, &bob_s_public, &bob_s_public_blob_hash, @@ -310,16 +311,16 @@ impl Session { /// Send data over the session. /// /// * `send` - Function to call to send physical packet(s) - /// * `mtu_buffer` - A writable work buffer whose size also specifies the physical MTU + /// * `mtu_sized_buffer` - A writable work buffer whose size also specifies the physical MTU /// * `data` - Data to send #[inline] pub fn send( &self, mut send: SendFunction, - mtu_buffer: &mut [u8], + mtu_sized_buffer: &mut [u8], mut data: &[u8], ) -> Result<(), Error> { - debug_assert!(mtu_buffer.len() >= MIN_TRANSPORT_MTU); + debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); let state = self.state.read().unwrap(); if let Some(remote_session_id) = state.remote_session_id { if let Some(session_key) = state.session_keys[state.cur_session_key_idx].as_ref() { @@ -335,9 +336,9 @@ impl Session { // Create initial header for first fragment of packet and place in first HEADER_SIZE bytes of buffer. create_packet_header( - mtu_buffer, + mtu_sized_buffer, packet_len, - mtu_buffer.len(), + mtu_sized_buffer.len(), PACKET_TYPE_DATA, remote_session_id.into(), counter, @@ -350,21 +351,21 @@ impl Session { // Send first N-1 fragments of N total fragments. let last_fragment_size; - if packet_len > mtu_buffer.len() { - let mut header: [u8; 16] = mtu_buffer[..HEADER_SIZE].try_into().unwrap(); - let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE; - let last_fragment_data_mtu = mtu_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); + if packet_len > mtu_sized_buffer.len() { + let mut header: [u8; 16] = mtu_sized_buffer[..HEADER_SIZE].try_into().unwrap(); + let fragment_data_mtu = mtu_sized_buffer.len() - HEADER_SIZE; + let last_fragment_data_mtu = mtu_sized_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); loop { let fragment_data_size = fragment_data_mtu.min(data.len()); let fragment_size = fragment_data_size + HEADER_SIZE; - c.crypt(&data[..fragment_data_size], &mut mtu_buffer[HEADER_SIZE..fragment_size]); + c.crypt(&data[..fragment_data_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]); data = &data[fragment_data_size..]; - set_header_check_code(mtu_buffer, &self.header_check_cipher); - send(&mut mtu_buffer[..fragment_size]); + set_header_check_code(mtu_sized_buffer, &self.header_check_cipher); + send(&mut mtu_sized_buffer[..fragment_size]); debug_assert!(header[15].wrapping_shr(2) < 63); header[15] += 0x04; // increment fragment number - mtu_buffer[..HEADER_SIZE].copy_from_slice(&header); + mtu_sized_buffer[..HEADER_SIZE].copy_from_slice(&header); if data.len() <= last_fragment_data_mtu { break; @@ -377,11 +378,11 @@ impl Session { // Send final fragment (or only fragment if no fragmentation was needed) let payload_end = data.len() + HEADER_SIZE; - c.crypt(data, &mut mtu_buffer[HEADER_SIZE..payload_end]); + c.crypt(data, &mut mtu_sized_buffer[HEADER_SIZE..payload_end]); let gcm_tag = c.finish_encrypt(); - mtu_buffer[payload_end..last_fragment_size].copy_from_slice(&gcm_tag); - set_header_check_code(mtu_buffer, &self.header_check_cipher); - send(&mut mtu_buffer[..last_fragment_size]); + mtu_sized_buffer[payload_end..last_fragment_size].copy_from_slice(&gcm_tag); + set_header_check_code(mtu_sized_buffer, &self.header_check_cipher); + send(&mut mtu_sized_buffer[..last_fragment_size]); // Check reusable AES-GCM instance back into pool. session_key.return_send_cipher(c); @@ -425,7 +426,7 @@ impl Session { /// * `force_rekey` - Re-key the session now regardless of key aging (still subject to rate limiting) pub fn service( &self, - host: &Layer, + host: &Application, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, @@ -440,7 +441,7 @@ impl Session { && state .offer .as_ref() - .map_or(true, |o| (current_time - o.creation_time) > Layer::REKEY_RATE_LIMIT_MS) + .map_or(true, |o| (current_time - o.creation_time) > Application::REKEY_RATE_LIMIT_MS) { if let Some(remote_s_public) = P384PublicKey::from_bytes(&self.remote_s_public_raw) { let mut offer = None; @@ -474,8 +475,8 @@ impl Session { } } -impl ReceiveContext { - pub fn new(host: &Layer) -> Self { +impl ReceiveContext { + pub fn new(host: &Application) -> Self { Self { initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), incoming_init_header_check_cipher: Aes::new( @@ -495,14 +496,14 @@ impl ReceiveContext { #[inline] pub fn receive<'a, SendFunction: FnMut(&mut [u8])>( &self, - host: &Layer, - remote_address: &Layer::RemoteAddress, + app: &Application, + remote_address: &Application::RemoteAddress, mut send: SendFunction, data_buf: &'a mut [u8], - incoming_packet_buf: Layer::IncomingPacketBuffer, + incoming_packet_buf: Application::IncomingPacketBuffer, mtu: usize, current_time: i64, - ) -> Result, Error> { + ) -> Result, Error> { let incoming_packet = incoming_packet_buf.as_ref(); if incoming_packet.len() < MIN_PACKET_SIZE { unlikely_branch(); @@ -517,7 +518,7 @@ impl ReceiveContext { if let Some(local_session_id) = SessionId::new_from_u64(u64::from_le(memory::load_raw(&incoming_packet[8..16])) & 0xffffffffffffu64) { - if let Some(session) = host.lookup_session(local_session_id) { + if let Some(session) = app.lookup_session(local_session_id) { if verify_header_check_code(incoming_packet, &session.header_check_cipher) { let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); if fragment_count > 1 { @@ -527,7 +528,7 @@ impl ReceiveContext { if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { drop(defrag); // release lock return self.receive_complete( - host, + app, remote_address, &mut send, data_buf, @@ -545,7 +546,7 @@ impl ReceiveContext { } } else { return self.receive_complete( - host, + app, remote_address, &mut send, data_buf, @@ -576,7 +577,7 @@ impl ReceiveContext { if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { drop(defrag); // release lock return self.receive_complete( - host, + app, remote_address, &mut send, data_buf, @@ -590,7 +591,7 @@ impl ReceiveContext { } } else { return self.receive_complete( - host, + app, remote_address, &mut send, data_buf, @@ -618,17 +619,17 @@ impl ReceiveContext { #[inline] fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>( &self, - host: &Layer, - remote_address: &Layer::RemoteAddress, + app: &Application, + remote_address: &Application::RemoteAddress, send: &mut SendFunction, data_buf: &'a mut [u8], - canonical_header_bytes: &[u8; 12], - fragments: &[Layer::IncomingPacketBuffer], + canonical_header_bytes: &[u8; AES_GCM_TAG_SIZE], + fragments: &[Application::IncomingPacketBuffer], packet_type: u8, - session: Option, + session: Option, mtu: usize, current_time: i64, - ) -> Result, Error> { + ) -> Result, Error> { debug_assert!(fragments.len() >= 1); // The first 'if' below should capture both DATA and NOP but not other types. Sanity check this. @@ -769,7 +770,7 @@ impl ReceiveContext { // Check the second HMAC first, which proves that the sender knows the recipient's full static identity. let hmac2 = &kex_packet[hmac1_end..kex_packet_len]; if !hmac_sha384_2( - host.get_local_s_public_blob_hash(), + app.get_local_s_public_blob_hash(), canonical_header_bytes, &kex_packet[HEADER_SIZE..hmac1_end], ) @@ -780,11 +781,11 @@ impl ReceiveContext { // Check rate limits. if let Some(session) = session.as_ref() { - if (current_time - session.state.read().unwrap().last_remote_offer) < Layer::REKEY_RATE_LIMIT_MS { + if (current_time - session.state.read().unwrap().last_remote_offer) < Application::REKEY_RATE_LIMIT_MS { return Err(Error::RateLimited); } } else { - if !host.check_new_session(self, remote_address) { + if !app.check_new_session(self, remote_address) { return Err(Error::RateLimited); } } @@ -792,7 +793,7 @@ impl ReceiveContext { // Key agreement: alice (remote) ephemeral NIST P-384 <> local static NIST P-384 let alice_e_public = P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..plaintext_end]).ok_or(Error::FailedAuthentication)?; - let noise_es = host + let noise_es = app .get_local_s_keypair() .agree(&alice_e_public) .ok_or(Error::FailedAuthentication)?; @@ -832,10 +833,10 @@ impl ReceiveContext { } // Extract alice's static NIST P-384 public key from her public blob. - let alice_s_public = Layer::extract_s_public_from_raw(alice_s_public_blob).ok_or(Error::InvalidPacket)?; + let alice_s_public = Application::extract_s_public_from_raw(alice_s_public_blob).ok_or(Error::InvalidPacket)?; // Key agreement: both sides' static P-384 keys. - let noise_ss = host + let noise_ss = app .get_local_s_keypair() .agree(&alice_s_public) .ok_or(Error::FailedAuthentication)?; @@ -874,7 +875,7 @@ impl ReceiveContext { let state = session.state.read().unwrap(); for k in state.session_keys.iter() { if let Some(k) = k.as_ref() { - if secret_fingerprint(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { + if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { ratchet_key = Some(k.ratchet_key.clone()); ratchet_count = k.ratchet_count; break; @@ -888,13 +889,13 @@ impl ReceiveContext { (None, ratchet_key, ratchet_count) } else { if let Some((new_session_id, psk, associated_object)) = - host.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata) + app.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata) { let header_check_cipher = Aes::new( kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::(), ); ( - Some(Session:: { + Some(Session:: { id: new_session_id, user_data: associated_object, send_counter: Counter::new(), @@ -983,10 +984,10 @@ impl ReceiveContext { idx = varint_safe_write(&mut reply_buf, idx, 0)?; // they don't need our static public; they have it idx = varint_safe_write(&mut reply_buf, idx, 0)?; // no meta-data in counter-offers (could be used in the future) if let Some(bob_hk_public) = bob_hk_public.as_ref() { - idx = safe_write_all(&mut reply_buf, idx, &[E1_TYPE_KYBER1024])?; + idx = safe_write_all(&mut reply_buf, idx, &[HYBRID_KEY_TYPE_KYBER1024])?; idx = safe_write_all(&mut reply_buf, idx, bob_hk_public)?; } else { - idx = safe_write_all(&mut reply_buf, idx, &[E1_TYPE_NONE])?; + idx = safe_write_all(&mut reply_buf, idx, &[HYBRID_KEY_TYPE_NONE])?; } if ratchet_key.is_some() { idx = safe_write_all(&mut reply_buf, idx, &[0x01])?; @@ -1086,7 +1087,7 @@ impl ReceiveContext { let bob_e_public = P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..plaintext_end]) .ok_or(Error::FailedAuthentication)?; let noise_ee = offer.alice_e_keypair.agree(&bob_e_public).ok_or(Error::FailedAuthentication)?; - let noise_se = host.get_local_s_keypair().agree(&bob_e_public).ok_or(Error::FailedAuthentication)?; + let noise_se = app.get_local_s_keypair().agree(&bob_e_public).ok_or(Error::FailedAuthentication)?; let noise_ik_key = Secret(hmac_sha512( session.psk.as_bytes(), @@ -1229,7 +1230,7 @@ fn send_ephemeral_offer( // Perform key agreement with the other side's static P-384 public key. let noise_es = alice_e_keypair.agree(bob_s_public).ok_or(Error::InvalidPacket)?; - // Generate a Kyber1024 pair if enabled. + // Generate a Kyber1024 (hybrid PQ crypto) pair if enabled. let alice_hk_keypair = if JEDI { Some(pqc_kyber::keypair(&mut random::SecureRandom::get())) } else { @@ -1268,14 +1269,14 @@ fn send_ephemeral_offer( idx = varint_safe_write(&mut packet_buf, idx, alice_metadata.len() as u64)?; idx = safe_write_all(&mut packet_buf, idx, alice_metadata)?; if let Some(hkp) = alice_hk_keypair { - idx = safe_write_all(&mut packet_buf, idx, &[E1_TYPE_KYBER1024])?; + idx = safe_write_all(&mut packet_buf, idx, &[HYBRID_KEY_TYPE_KYBER1024])?; idx = safe_write_all(&mut packet_buf, idx, &hkp.public)?; } else { - idx = safe_write_all(&mut packet_buf, idx, &[E1_TYPE_NONE])?; + idx = safe_write_all(&mut packet_buf, idx, &[HYBRID_KEY_TYPE_NONE])?; } if let Some(ratchet_key) = ratchet_key.as_ref() { idx = safe_write_all(&mut packet_buf, idx, &[0x01])?; - idx = safe_write_all(&mut packet_buf, idx, &secret_fingerprint(ratchet_key.as_bytes())[..16])?; + idx = safe_write_all(&mut packet_buf, idx, &public_fingerprint_of_secret(ratchet_key.as_bytes())[..16])?; } else { idx = safe_write_all(&mut packet_buf, idx, &[0x00])?; } @@ -1362,7 +1363,7 @@ fn send_ephemeral_offer( /// Populate all but the header check code in the first 16 bytes of a packet or fragment. #[inline(always)] fn create_packet_header( - header: &mut [u8], + header_destination_buffer: &mut [u8], packet_len: usize, mtu: usize, packet_type: u8, @@ -1371,7 +1372,7 @@ fn create_packet_header( ) -> Result<(), Error> { let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize; - debug_assert!(header.len() >= HEADER_SIZE); + debug_assert!(header_destination_buffer.len() >= HEADER_SIZE); debug_assert!(mtu >= MIN_TRANSPORT_MTU); debug_assert!(packet_len >= MIN_PACKET_SIZE); debug_assert!(fragment_count > 0); @@ -1386,11 +1387,11 @@ fn create_packet_header( // [112-115] packet type (0-15) // [116-121] number of fragments (0..63 for 1..64 fragments total) // [122-127] fragment number (0, 1, 2, ...) - memory::store_raw((counter.to_u32() as u64).to_le(), header); + memory::store_raw((counter.to_u32() as u64).to_le(), header_destination_buffer); memory::store_raw( (u64::from(recipient_session_id) | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)) .to_le(), - &mut header[8..], + &mut header_destination_buffer[8..], ); Ok(()) } else { @@ -1464,7 +1465,7 @@ fn parse_dec_key_offer_after_header( let alice_metadata = safe_read_exact(&mut p, alice_metadata_len as usize)?; let alice_hk_public_raw = match safe_read_exact(&mut p, 1)?[0] { - E1_TYPE_KYBER1024 => { + HYBRID_KEY_TYPE_KYBER1024 => { if packet_type == PACKET_TYPE_INITIAL_KEY_OFFER { safe_read_exact(&mut p, pqc_kyber::KYBER_PUBLICKEYBYTES)? } else { @@ -1473,6 +1474,7 @@ fn parse_dec_key_offer_after_header( } _ => &[], }; + if p.is_empty() { return Err(Error::InvalidPacket); } @@ -1481,6 +1483,7 @@ fn parse_dec_key_offer_after_header( } else { None }; + Ok(( offer_id, //always 16 bytes alice_session_id, @@ -1501,7 +1504,7 @@ impl SessionKey { Role::Bob => (a2b, b2a), }; Self { - secret_fingerprint: secret_fingerprint(key.as_bytes())[..16].try_into().unwrap(), + secret_fingerprint: public_fingerprint_of_secret(key.as_bytes())[..16].try_into().unwrap(), establish_time: current_time, establish_counter: current_counter, lifetime: KeyLifetime::new(current_counter, current_time), @@ -1599,16 +1602,18 @@ fn hmac_sha384_2(key: &[u8], a: &[u8], b: &[u8]) -> [u8; 48] { hmac.finish() } -/// HMAC-SHA512 key derivation function modeled on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 12) -/// Cryptographically this isn't really different from HMAC(key, [label]) with just one byte. +/// HMAC-SHA512 key derivation based on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 12) +/// +/// Cryptographically this isn't meaningfully different from HMAC(key, [label]), +/// but NIST seems to like it this way. fn kbkdf512(key: &[u8], label: u8) -> Secret<64> { Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00])) } -/// Get a hash of a secret key that can be used as a public fingerprint. -fn secret_fingerprint(key: &[u8]) -> [u8; 48] { +/// Get a hash of a secret that can be used as a public key fingerprint to check ratcheting during key exchange. +fn public_fingerprint_of_secret(key: &[u8]) -> [u8; 48] { let mut tmp = SHA384::new(); - tmp.update("fp".as_bytes()); + tmp.update(&[0xf0, 0x0d]); // arbitrary salt tmp.update(key); tmp.finish() }