diff --git a/core-crypto/src/aes.rs b/core-crypto/src/aes.rs index f456bd1d6..ae7c95825 100644 --- a/core-crypto/src/aes.rs +++ b/core-crypto/src/aes.rs @@ -22,7 +22,20 @@ mod fruit_flavored { const kCCOptionECBMode: i32 = 2; extern "C" { - fn CCCryptorCreateWithMode(op: i32, mode: i32, alg: i32, padding: i32, iv: *const c_void, key: *const c_void, key_len: usize, tweak: *const c_void, tweak_len: usize, num_rounds: c_int, options: i32, cryyptor_ref: *mut *mut c_void) -> i32; + fn CCCryptorCreateWithMode( + op: i32, + mode: i32, + alg: i32, + padding: i32, + iv: *const c_void, + key: *const c_void, + key_len: usize, + tweak: *const c_void, + tweak_len: usize, + num_rounds: c_int, + options: i32, + cryyptor_ref: *mut *mut c_void, + ) -> i32; fn CCCryptorUpdate(cryptor_ref: *mut c_void, data_in: *const c_void, data_in_len: usize, data_out: *mut c_void, data_out_len: usize, data_out_written: *mut usize) -> i32; //fn CCCryptorReset(cryptor_ref: *mut c_void, iv: *const c_void) -> i32; fn CCCryptorRelease(cryptor_ref: *mut c_void) -> i32; @@ -53,8 +66,40 @@ mod fruit_flavored { panic!("AES supports 128, 192, or 256 bits keys"); } let mut aes: Self = std::mem::zeroed(); - assert_eq!(CCCryptorCreateWithMode(kCCEncrypt, kCCModeECB, kCCAlgorithmAES, 0, null(), k.as_ptr().cast(), k.len(), null(), 0, 0, kCCOptionECBMode, &mut aes.0), 0); - assert_eq!(CCCryptorCreateWithMode(kCCDecrypt, kCCModeECB, kCCAlgorithmAES, 0, null(), k.as_ptr().cast(), k.len(), null(), 0, 0, kCCOptionECBMode, &mut aes.1), 0); + assert_eq!( + CCCryptorCreateWithMode( + kCCEncrypt, + kCCModeECB, + kCCAlgorithmAES, + 0, + null(), + k.as_ptr().cast(), + k.len(), + null(), + 0, + 0, + kCCOptionECBMode, + &mut aes.0 + ), + 0 + ); + assert_eq!( + CCCryptorCreateWithMode( + kCCDecrypt, + kCCModeECB, + kCCAlgorithmAES, + 0, + null(), + k.as_ptr().cast(), + k.len(), + null(), + 0, + 0, + kCCOptionECBMode, + &mut aes.1 + ), + 0 + ); aes } } @@ -189,7 +234,7 @@ mod fruit_flavored { } #[inline(always)] - pub fn finish(&mut self) -> [u8; 16] { + pub fn finish_encrypt(&mut self) -> [u8; 16] { let mut tag = 0_u128.to_ne_bytes(); unsafe { let mut tag_len = 16; @@ -200,30 +245,23 @@ mod fruit_flavored { } tag } + + #[inline(always)] + pub fn finish_decrypt(&mut self, expected_tag: &[u8]) -> bool { + self.finish_encrypt().eq(expected_tag) + } } unsafe impl Send for AesGcm {} } #[cfg(not(target_os = "macos"))] -mod openssl { +mod openssl_aes { use crate::secret::Secret; use openssl::symm::{Cipher, Crypter, Mode}; use std::cell::UnsafeCell; + use std::mem::MaybeUninit; - #[inline(always)] - fn aes_ctr_by_key_size(ks: usize) -> Cipher { - match ks { - 16 => Cipher::aes_128_ctr(), - 24 => Cipher::aes_192_ctr(), - 32 => Cipher::aes_256_ctr(), - _ => { - panic!("AES supports 128, 192, or 256 bits keys"); - } - } - } - - #[inline(always)] fn aes_gcm_by_key_size(ks: usize) -> Cipher { match ks { 16 => Cipher::aes_128_gcm(), @@ -235,7 +273,6 @@ mod openssl { } } - #[inline(always)] fn aes_ecb_by_key_size(ks: usize) -> Cipher { match ks { 16 => Cipher::aes_128_ecb(), @@ -250,9 +287,11 @@ mod openssl { pub struct Aes(UnsafeCell, UnsafeCell); impl Aes { - #[inline(always)] pub fn new(k: &[u8]) -> Self { - let (mut c, mut d) = (Crypter::new(aes_ecb_by_key_size(k.len()), Mode::Encrypt, k, None).unwrap(), Crypter::new(aes_ecb_by_key_size(k.len()), Mode::Decrypt, k, None).unwrap()); + let (mut c, mut d) = ( + Crypter::new(aes_ecb_by_key_size(k.len()), Mode::Encrypt, k, None).unwrap(), + Crypter::new(aes_ecb_by_key_size(k.len()), Mode::Decrypt, k, None).unwrap(), + ); c.pad(false); d.pad(false); Self(UnsafeCell::new(c), UnsafeCell::new(d)) @@ -260,7 +299,7 @@ mod openssl { #[inline(always)] pub fn encrypt_block(&self, plaintext: &[u8], ciphertext: &mut [u8]) { - let mut tmp = [0_u8; 32]; + let mut tmp: [u8; 32] = unsafe { MaybeUninit::uninit().assume_init() }; let c: &mut Crypter = unsafe { &mut *self.0.get() }; if c.update(plaintext, &mut tmp).unwrap() != 16 { assert_eq!(c.finalize(&mut tmp).unwrap(), 16); @@ -270,7 +309,7 @@ mod openssl { #[inline(always)] pub fn encrypt_block_in_place(&self, data: &mut [u8]) { - let mut tmp = [0_u8; 32]; + let mut tmp: [u8; 32] = unsafe { MaybeUninit::uninit().assume_init() }; let c: &mut Crypter = unsafe { &mut *self.0.get() }; if c.update(data, &mut tmp).unwrap() != 16 { assert_eq!(c.finalize(&mut tmp).unwrap(), 16); @@ -280,9 +319,9 @@ mod openssl { #[inline(always)] pub fn decrypt_block(&self, ciphertext: &[u8], plaintext: &mut [u8]) { - let mut tmp = [0_u8; 32]; + let mut tmp: [u8; 32] = unsafe { MaybeUninit::uninit().assume_init() }; let c: &mut Crypter = unsafe { &mut *self.1.get() }; - if c.update(plaintext, &mut tmp).unwrap() != 16 { + if c.update(ciphertext, &mut tmp).unwrap() != 16 { assert_eq!(c.finalize(&mut tmp).unwrap(), 16); } plaintext[..16].copy_from_slice(&tmp[..16]); @@ -290,7 +329,7 @@ mod openssl { #[inline(always)] pub fn decrypt_block_in_place(&self, data: &mut [u8]) { - let mut tmp = [0_u8; 32]; + let mut tmp: [u8; 32] = unsafe { MaybeUninit::uninit().assume_init() }; let c: &mut Crypter = unsafe { &mut *self.1.get() }; if c.update(data, &mut tmp).unwrap() != 16 { assert_eq!(c.finalize(&mut tmp).unwrap(), 16); @@ -307,7 +346,6 @@ mod openssl { impl AesGcm { /// Construct a new AES-GCM cipher. /// Key must be 16, 24, or 32 bytes in length or a panic will occur. - #[inline(always)] pub fn new(k: &[u8], encrypt: bool) -> Self { let mut s: Secret<32> = Secret::default(); match k.len() { @@ -323,9 +361,7 @@ mod openssl { /// Initialize AES-CTR for encryption or decryption with the given IV. /// If it's already been used, this also resets the cipher. There is no separate reset. - #[inline(always)] pub fn init(&mut self, iv: &[u8]) { - assert_eq!(iv.len(), 12); let mut c = Crypter::new( aes_gcm_by_key_size(self.1), if self.3 { @@ -338,37 +374,46 @@ mod openssl { ) .unwrap(); c.pad(false); + let _ = c.set_tag_len(16); let _ = self.2.replace(c); } #[inline(always)] pub fn aad(&mut self, aad: &[u8]) { - let _ = self.2.as_mut().unwrap().aad_update(aad); + assert!(self.2.as_mut().unwrap().aad_update(aad).is_ok()); } /// Encrypt or decrypt (same operation with CTR mode) #[inline(always)] pub fn crypt(&mut self, input: &[u8], output: &mut [u8]) { - let _ = self.2.as_mut().unwrap().update(input, output); + assert!(self.2.as_mut().unwrap().update(input, output).is_ok()); } /// Encrypt or decrypt in place (same operation with CTR mode) #[inline(always)] pub fn crypt_in_place(&mut self, data: &mut [u8]) { - let _ = self.2.as_mut().unwrap().update(unsafe { &*std::slice::from_raw_parts(data.as_ptr(), data.len()) }, data); + assert!(self.2.as_mut().unwrap().update(unsafe { &*std::slice::from_raw_parts(data.as_ptr(), data.len()) }, data).is_ok()); } #[inline(always)] - pub fn finish(&mut self) -> [u8; 16] { + pub fn finish_encrypt(&mut self) -> [u8; 16] { let mut tag = [0_u8; 16]; - let c = self.2.as_mut().unwrap(); - if c.finalize(&mut []).is_ok() { - if !c.get_tag(&mut tag).is_ok() { - tag.fill(0); - } - } + let mut c = self.2.take().unwrap(); + assert!(c.finalize(&mut tag).is_ok()); + assert!(c.get_tag(&mut tag).is_ok()); tag } + + #[inline(always)] + pub fn finish_decrypt(&mut self, expected_tag: &[u8]) -> bool { + let mut c = self.2.take().unwrap(); + if c.set_tag(expected_tag).is_ok() { + let result = c.finalize(&mut []).is_ok(); + result + } else { + false + } + } } unsafe impl Send for AesGcm {} @@ -378,7 +423,7 @@ mod openssl { pub use fruit_flavored::{Aes, AesGcm}; #[cfg(not(target_os = "macos"))] -pub use openssl::{Aes, AesGcm}; +pub use openssl_aes::{Aes, AesGcm}; #[cfg(test)] mod tests { @@ -391,7 +436,7 @@ mod tests { for i in 1..12345 { buf[i] = i as u8; } - let iv = [1_u8; 16]; + let iv = [1_u8; 12]; let mut c = AesGcm::new(&[1_u8; 32], true); @@ -402,7 +447,10 @@ mod tests { c.crypt_in_place(&mut buf); } let duration = SystemTime::now().duration_since(start).unwrap(); - println!("AES-256-GCM encrypt benchmark: {} MiB/sec", (((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64()); + println!( + "AES-256-GCM encrypt benchmark: {} MiB/sec", + (((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64() + ); let mut c = AesGcm::new(&[1_u8; 32], false); @@ -412,6 +460,9 @@ mod tests { c.crypt_in_place(&mut buf); } let duration = SystemTime::now().duration_since(start).unwrap(); - println!("AES-256-GCM decrypt benchmark: {} MiB/sec", (((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64()); + println!( + "AES-256-GCM decrypt benchmark: {} MiB/sec", + (((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64() + ); } } diff --git a/core-crypto/src/zssp.rs b/core-crypto/src/zssp.rs index 09641583a..ab29fcd9b 100644 --- a/core-crypto/src/zssp.rs +++ b/core-crypto/src/zssp.rs @@ -57,7 +57,6 @@ const KEY_EXCHANGE_MAX_FRAGMENTS: usize = 2; // enough room for p384 + ZT identi const HEADER_SIZE: usize = 16; const HEADER_CHECK_SIZE: usize = 4; const AES_GCM_TAG_SIZE: usize = 16; -const AES_GCM_NONCE_SIZE: usize = 12; const AES_GCM_NONCE_START: usize = 4; const AES_GCM_NONCE_END: usize = 16; const HMAC_SIZE: usize = 48; @@ -65,7 +64,7 @@ const SESSION_ID_SIZE: usize = 6; const KEY_HISTORY_SIZE_MAX: usize = 3; const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; -const KBKDF_KEY_USAGE_LABEL_HEADER_MAC: u8 = b'H'; +const KBKDF_KEY_USAGE_LABEL_HEADER_CHECK: u8 = b'H'; const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A'; const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; @@ -286,10 +285,10 @@ impl Session { if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) { if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) { let send_counter = Counter::new(); - let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()); + let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>()); let remote_s_public_hash = SHA384::hash(remote_s_public); - let outgoing_init_header_check_cipher = Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()); - if let Ok(offer) = EphemeralOffer::create_alice_offer( + let outgoing_init_header_check_cipher = Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>()); + if let Ok(offer) = create_initial_offer( &mut send, send_counter.next(), local_session_id, @@ -337,22 +336,27 @@ impl Session { send_with_fragmentation_init_header(mtu_buffer, packet_len, mtu_buffer.len(), PACKET_TYPE_DATA, remote_session_id.into(), counter); let mut c = key.get_send_cipher(counter)?; - c.init(mtu_buffer); + c.init(&mtu_buffer[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); if packet_len > mtu_buffer.len() { let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap(); let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE; let last_fragment_data_mtu = mtu_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); loop { - debug_assert!(data.len() > last_fragment_data_mtu); - c.crypt(&data[HEADER_CHECK_SIZE..fragment_data_mtu], &mut mtu_buffer[HEADER_SIZE..]); + let fragment_data_size = fragment_data_mtu.min(data.len()); + let fragment_size = fragment_data_size + HEADER_SIZE; + c.crypt(&data[..fragment_data_size], &mut mtu_buffer[HEADER_SIZE..fragment_size]); + data = &data[fragment_data_size..]; + let hc = header_check(mtu_buffer, &self.header_check_cipher); mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes()); - send(mtu_buffer); - data = &data[fragment_data_mtu..]; + + send(&mut mtu_buffer[..fragment_size]); + debug_assert!(header[7].wrapping_shr(2) < 63); header[7] += 0x04; // increment fragment number mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].copy_from_slice(&header); + if data.len() <= last_fragment_data_mtu { break; } @@ -362,9 +366,11 @@ impl Session { let gcm_tag_idx = data.len() + HEADER_SIZE; c.crypt(data, &mut mtu_buffer[HEADER_SIZE..gcm_tag_idx]); + mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish_encrypt()); + let hc = header_check(mtu_buffer, &self.header_check_cipher); mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes()); - mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish()); + send(&mut mtu_buffer[..packet_len]); key.return_send_cipher(c); @@ -375,13 +381,18 @@ impl Session { return Err(Error::SessionNotEstablished); } + pub fn established(&self) -> bool { + let state = self.state.read(); + state.remote_session_id.is_some() && !state.keys.is_empty() + } + #[inline] pub fn rekey_check(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) { let state = self.state.upgradable_read(); if let Some(key) = state.keys.front() { if force || (key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) { if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) { - if let Ok(offer) = EphemeralOffer::create_alice_offer( + if let Ok(offer) = create_initial_offer( &mut send, self.send_counter.next(), self.id, @@ -409,7 +420,7 @@ impl ReceiveContext { pub fn new(host: &H) -> Self { Self { initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), - incoming_init_header_check_cipher: Aes::new(kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()), + incoming_init_header_check_cipher: Aes::new(kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>()), } } @@ -536,10 +547,9 @@ impl ReceiveContext { } c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]); - let tag = c.finish(); + let ok = c.finish_decrypt(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]); key.return_receive_cipher(c); - - if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) { + if ok { // Drop obsolete keys if we had to iterate past the first key to get here. if key_index > 0 { unlikely_branch(); @@ -627,8 +637,7 @@ impl ReceiveContext { let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), false); c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); - let c = c.finish(); - if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) { + if !c.finish_decrypt(&incoming_packet[payload_end..aes_gcm_tag_end]) { return Err(Error::FailedAuthentication); } @@ -665,7 +674,7 @@ impl ReceiveContext { None } else { if let Some((new_session_id, psk, associated_object)) = host.accept_new_session(alice_s_public, alice_metadata) { - let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()); + let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>()); Some(Session:: { id: new_session_id, associated_object, @@ -736,7 +745,7 @@ impl ReceiveContext { let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true); c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut reply_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..reply_len]); - let c = c.finish(); + let c = c.finish_encrypt(); reply_buf[reply_len..(reply_len + AES_GCM_TAG_SIZE)].copy_from_slice(&c); reply_len += AES_GCM_TAG_SIZE; @@ -791,8 +800,7 @@ impl ReceiveContext { let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), false); c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); - let c = c.finish(); - if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) { + if !c.finish_decrypt(&incoming_packet[payload_end..aes_gcm_tag_end]) { return Err(Error::FailedAuthentication); } @@ -831,7 +839,7 @@ impl ReceiveContext { let mut c = key.get_send_cipher(counter)?; c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); - reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish()); + reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt()); key.return_send_cipher(c); let hc = header_check(&reply_buf, &session.header_check_cipher); @@ -928,97 +936,96 @@ struct EphemeralOffer { alice_e1_keypair: Option, } -impl EphemeralOffer { - fn create_alice_offer( - send: &mut SendFunction, - counter: CounterValue, - alice_session_id: SessionId, - bob_session_id: Option, - alice_s_public: &[u8], - alice_metadata: &[u8], - bob_s_public_p384: &P384PublicKey, - bob_s_public_hash: &[u8], - ss: &Secret<48>, - header_check_cipher: &Aes, - mtu: usize, - current_time: i64, - jedi: bool, - ) -> Result { - let alice_e0_keypair = P384KeyPair::generate(); - let e0s = alice_e0_keypair.agree(bob_s_public_p384); - if e0s.is_none() { - return Err(Error::InvalidPacket); +fn create_initial_offer( + send: &mut SendFunction, + counter: CounterValue, + alice_session_id: SessionId, + bob_session_id: Option, + alice_s_public: &[u8], + alice_metadata: &[u8], + bob_s_public_p384: &P384PublicKey, + bob_s_public_hash: &[u8], + ss: &Secret<48>, + header_check_cipher: &Aes, + mtu: usize, + current_time: i64, + jedi: bool, +) -> Result { + let alice_e0_keypair = P384KeyPair::generate(); + let e0s = alice_e0_keypair.agree(bob_s_public_p384); + if e0s.is_none() { + return Err(Error::InvalidPacket); + } + + let alice_e1_keypair = if jedi { + Some(pqc_kyber::keypair(&mut random::SecureRandom::get())) + } else { + None + }; + + const PACKET_BUF_SIZE: usize = MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS; + let mut packet_buf = [0_u8; PACKET_BUF_SIZE]; + let mut packet_len = { + let mut p = &mut packet_buf[HEADER_SIZE..]; + + p.write_all(&[SESSION_PROTOCOL_VERSION])?; + p.write_all(alice_e0_keypair.public_key_bytes())?; + + p.write_all(&alice_session_id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?; + varint::write(&mut p, alice_s_public.len() as u64)?; + p.write_all(alice_s_public)?; + varint::write(&mut p, alice_metadata.len() as u64)?; + p.write_all(alice_metadata)?; + if let Some(e1kp) = alice_e1_keypair { + p.write_all(&[E1_TYPE_KYBER1024])?; + p.write_all(&e1kp.public)?; + } else { + p.write_all(&[E1_TYPE_NONE])?; } - let alice_e1_keypair = if jedi { - Some(pqc_kyber::keypair(&mut random::SecureRandom::get())) - } else { - None - }; + PACKET_BUF_SIZE - p.len() + }; - const PACKET_BUF_SIZE: usize = MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS; - let mut packet_buf = [0_u8; PACKET_BUF_SIZE]; - let mut packet_len = { - let mut p = &mut packet_buf[HEADER_SIZE..]; + send_with_fragmentation_init_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter); - p.write_all(&[SESSION_PROTOCOL_VERSION])?; - p.write_all(alice_e0_keypair.public_key_bytes())?; + let key = Secret(hmac_sha512( + &hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()), + e0s.unwrap().as_bytes(), + )); - p.write_all(&alice_session_id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?; - varint::write(&mut p, alice_s_public.len() as u64)?; - p.write_all(alice_s_public)?; - varint::write(&mut p, alice_metadata.len() as u64)?; - p.write_all(alice_metadata)?; - if let Some(e1kp) = alice_e1_keypair { - p.write_all(&[E1_TYPE_KYBER1024])?; - p.write_all(&e1kp.public)?; - } else { - p.write_all(&[E1_TYPE_NONE])?; - } + let gcm_tag = { + let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true); + c.init(&packet_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); + c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]); + c.finish_encrypt() + }; + packet_buf[packet_len..(packet_len + AES_GCM_TAG_SIZE)].copy_from_slice(&gcm_tag); + packet_len += AES_GCM_TAG_SIZE; - PACKET_BUF_SIZE - p.len() - }; + let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes())); - send_with_fragmentation_init_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter); + let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[HEADER_CHECK_SIZE..packet_len]); + packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); + packet_len += HMAC_SIZE; - let key = Secret(hmac_sha512( - &hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()), - e0s.unwrap().as_bytes(), - )); + let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[HEADER_CHECK_SIZE..packet_len]); + packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); + packet_len += HMAC_SIZE; - let gcm_tag = { - let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true); - c.init(&packet_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); - c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]); - c.finish() - }; - packet_buf[packet_len..(packet_len + AES_GCM_TAG_SIZE)].copy_from_slice(&gcm_tag); - packet_len += AES_GCM_TAG_SIZE; + send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher); - let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes())); - - let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[HEADER_CHECK_SIZE..packet_len]); - packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); - packet_len += HMAC_SIZE; - - let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[HEADER_CHECK_SIZE..packet_len]); - packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); - packet_len += HMAC_SIZE; - - send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher); - - Ok(EphemeralOffer { - creation_time: current_time, - key, - alice_e0_keypair, - alice_e1_keypair, - }) - } + Ok(EphemeralOffer { + creation_time: current_time, + key, + alice_e0_keypair, + alice_e1_keypair, + }) } +/// Create a header to send a packet with optional fragmentation. #[inline(always)] fn send_with_fragmentation_init_header(header: &mut [u8], packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) { - let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize; + let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize; debug_assert!(mtu >= MIN_MTU); debug_assert!(packet_len >= HEADER_SIZE); debug_assert!(fragment_count <= MAX_FRAGMENTS); @@ -1068,6 +1075,13 @@ fn header_check(packet: &[u8], header_check_cipher: &Aes) -> u32 { /// Add a new session key to the key list, retiring older non-active keys if necessary. fn add_key(keys: &mut LinkedList, key: SessionKey) { + // Sanity check to make sure duplicates can't get in here. Should be impossible. + for k in keys.iter() { + if k.receive_key.eq(&key.receive_key) { + return; + } + } + debug_assert!(KEY_HISTORY_SIZE_MAX >= 2); while keys.len() >= KEY_HISTORY_SIZE_MAX { let current = keys.pop_front().unwrap(); @@ -1132,7 +1146,7 @@ struct SessionKey { receive_cipher_pool: Mutex>>, send_cipher_pool: Mutex>>, role: Role, - jedi: bool, // true if kyber was used + jedi: bool, // true if kyber was enabled on both sides } impl SessionKey { @@ -1266,13 +1280,16 @@ mod tests { #[test] fn establish_session() { let jedi = true; + + let mut data_buf = [0_u8; (1280 - 32) * MAX_FRAGMENTS]; + let mut mtu_buffer = [0_u8; 1280]; let mut psk: Secret<64> = Secret::default(); random::fill_bytes_secure(&mut psk.0); + let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob")); let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice")); let alice_rc: Box>> = Box::new(ReceiveContext::new(&alice_host)); let bob_rc: Box>> = Box::new(ReceiveContext::new(&bob_host)); - let mut data_buf = [0_u8; 4096]; //println!("zssp: size of session (bytes): {}", std::mem::size_of::>>()); @@ -1285,7 +1302,7 @@ mod tests { &[], &psk, 1, - 1280, + mtu_buffer.len(), 1, jedi, ) @@ -1293,7 +1310,7 @@ mod tests { )); let mut ts = 0; - for _ in 0..256 { + for _ in 0..4 { for host in [&alice_host, &bob_host] { let send_to_other = |data: &mut [u8]| { if std::ptr::eq(host, &alice_host) { @@ -1313,15 +1330,16 @@ mod tests { if let Some(qi) = host.queue.lock().pop_back() { let qi_len = qi.len(); ts += 1; - let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, jedi, ts); + let r = rc.receive(host, send_to_other, &mut data_buf, qi, mtu_buffer.len(), jedi, ts); if r.is_ok() { let r = r.unwrap(); match r { ReceiveResult::Ok => { - println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len); + //println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len); } ReceiveResult::OkData(data) => { - println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len()); + //println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len()); + assert!(!data.iter().any(|x| *x != 0x12)); } ReceiveResult::OkNewSession(new_session) => { println!("zssp: {} => {} ({}): OkNewSession ({})", host.other_name, host.this_name, qi_len, u64::from(new_session.id)); @@ -1341,6 +1359,15 @@ mod tests { break; } } + + data_buf.fill(0x12); + if let Some(session) = host.session.lock().as_ref().cloned() { + if session.established() { + for dl in 0..data_buf.len() { + assert!(session.send(send_to_other, &mut mtu_buffer, &data_buf[..dl]).is_ok()); + } + } + } } } }