Fix OpenSSL AES, and ZSSP now passes some pretty solid tests.

This commit is contained in:
Adam Ierymenko 2022-09-07 16:08:51 -04:00
parent 42f6f016e9
commit b4e17b38e8
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
2 changed files with 228 additions and 150 deletions

View file

@ -22,7 +22,20 @@ mod fruit_flavored {
const kCCOptionECBMode: i32 = 2;
extern "C" {
fn CCCryptorCreateWithMode(op: i32, mode: i32, alg: i32, padding: i32, iv: *const c_void, key: *const c_void, key_len: usize, tweak: *const c_void, tweak_len: usize, num_rounds: c_int, options: i32, cryyptor_ref: *mut *mut c_void) -> i32;
fn CCCryptorCreateWithMode(
op: i32,
mode: i32,
alg: i32,
padding: i32,
iv: *const c_void,
key: *const c_void,
key_len: usize,
tweak: *const c_void,
tweak_len: usize,
num_rounds: c_int,
options: i32,
cryyptor_ref: *mut *mut c_void,
) -> i32;
fn CCCryptorUpdate(cryptor_ref: *mut c_void, data_in: *const c_void, data_in_len: usize, data_out: *mut c_void, data_out_len: usize, data_out_written: *mut usize) -> i32;
//fn CCCryptorReset(cryptor_ref: *mut c_void, iv: *const c_void) -> i32;
fn CCCryptorRelease(cryptor_ref: *mut c_void) -> i32;
@ -53,8 +66,40 @@ mod fruit_flavored {
panic!("AES supports 128, 192, or 256 bits keys");
}
let mut aes: Self = std::mem::zeroed();
assert_eq!(CCCryptorCreateWithMode(kCCEncrypt, kCCModeECB, kCCAlgorithmAES, 0, null(), k.as_ptr().cast(), k.len(), null(), 0, 0, kCCOptionECBMode, &mut aes.0), 0);
assert_eq!(CCCryptorCreateWithMode(kCCDecrypt, kCCModeECB, kCCAlgorithmAES, 0, null(), k.as_ptr().cast(), k.len(), null(), 0, 0, kCCOptionECBMode, &mut aes.1), 0);
assert_eq!(
CCCryptorCreateWithMode(
kCCEncrypt,
kCCModeECB,
kCCAlgorithmAES,
0,
null(),
k.as_ptr().cast(),
k.len(),
null(),
0,
0,
kCCOptionECBMode,
&mut aes.0
),
0
);
assert_eq!(
CCCryptorCreateWithMode(
kCCDecrypt,
kCCModeECB,
kCCAlgorithmAES,
0,
null(),
k.as_ptr().cast(),
k.len(),
null(),
0,
0,
kCCOptionECBMode,
&mut aes.1
),
0
);
aes
}
}
@ -189,7 +234,7 @@ mod fruit_flavored {
}
#[inline(always)]
pub fn finish(&mut self) -> [u8; 16] {
pub fn finish_encrypt(&mut self) -> [u8; 16] {
let mut tag = 0_u128.to_ne_bytes();
unsafe {
let mut tag_len = 16;
@ -200,30 +245,23 @@ mod fruit_flavored {
}
tag
}
#[inline(always)]
pub fn finish_decrypt(&mut self, expected_tag: &[u8]) -> bool {
self.finish_encrypt().eq(expected_tag)
}
}
unsafe impl Send for AesGcm {}
}
#[cfg(not(target_os = "macos"))]
mod openssl {
mod openssl_aes {
use crate::secret::Secret;
use openssl::symm::{Cipher, Crypter, Mode};
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
#[inline(always)]
fn aes_ctr_by_key_size(ks: usize) -> Cipher {
match ks {
16 => Cipher::aes_128_ctr(),
24 => Cipher::aes_192_ctr(),
32 => Cipher::aes_256_ctr(),
_ => {
panic!("AES supports 128, 192, or 256 bits keys");
}
}
}
#[inline(always)]
fn aes_gcm_by_key_size(ks: usize) -> Cipher {
match ks {
16 => Cipher::aes_128_gcm(),
@ -235,7 +273,6 @@ mod openssl {
}
}
#[inline(always)]
fn aes_ecb_by_key_size(ks: usize) -> Cipher {
match ks {
16 => Cipher::aes_128_ecb(),
@ -250,9 +287,11 @@ mod openssl {
pub struct Aes(UnsafeCell<Crypter>, UnsafeCell<Crypter>);
impl Aes {
#[inline(always)]
pub fn new(k: &[u8]) -> Self {
let (mut c, mut d) = (Crypter::new(aes_ecb_by_key_size(k.len()), Mode::Encrypt, k, None).unwrap(), Crypter::new(aes_ecb_by_key_size(k.len()), Mode::Decrypt, k, None).unwrap());
let (mut c, mut d) = (
Crypter::new(aes_ecb_by_key_size(k.len()), Mode::Encrypt, k, None).unwrap(),
Crypter::new(aes_ecb_by_key_size(k.len()), Mode::Decrypt, k, None).unwrap(),
);
c.pad(false);
d.pad(false);
Self(UnsafeCell::new(c), UnsafeCell::new(d))
@ -260,7 +299,7 @@ mod openssl {
#[inline(always)]
pub fn encrypt_block(&self, plaintext: &[u8], ciphertext: &mut [u8]) {
let mut tmp = [0_u8; 32];
let mut tmp: [u8; 32] = unsafe { MaybeUninit::uninit().assume_init() };
let c: &mut Crypter = unsafe { &mut *self.0.get() };
if c.update(plaintext, &mut tmp).unwrap() != 16 {
assert_eq!(c.finalize(&mut tmp).unwrap(), 16);
@ -270,7 +309,7 @@ mod openssl {
#[inline(always)]
pub fn encrypt_block_in_place(&self, data: &mut [u8]) {
let mut tmp = [0_u8; 32];
let mut tmp: [u8; 32] = unsafe { MaybeUninit::uninit().assume_init() };
let c: &mut Crypter = unsafe { &mut *self.0.get() };
if c.update(data, &mut tmp).unwrap() != 16 {
assert_eq!(c.finalize(&mut tmp).unwrap(), 16);
@ -280,9 +319,9 @@ mod openssl {
#[inline(always)]
pub fn decrypt_block(&self, ciphertext: &[u8], plaintext: &mut [u8]) {
let mut tmp = [0_u8; 32];
let mut tmp: [u8; 32] = unsafe { MaybeUninit::uninit().assume_init() };
let c: &mut Crypter = unsafe { &mut *self.1.get() };
if c.update(plaintext, &mut tmp).unwrap() != 16 {
if c.update(ciphertext, &mut tmp).unwrap() != 16 {
assert_eq!(c.finalize(&mut tmp).unwrap(), 16);
}
plaintext[..16].copy_from_slice(&tmp[..16]);
@ -290,7 +329,7 @@ mod openssl {
#[inline(always)]
pub fn decrypt_block_in_place(&self, data: &mut [u8]) {
let mut tmp = [0_u8; 32];
let mut tmp: [u8; 32] = unsafe { MaybeUninit::uninit().assume_init() };
let c: &mut Crypter = unsafe { &mut *self.1.get() };
if c.update(data, &mut tmp).unwrap() != 16 {
assert_eq!(c.finalize(&mut tmp).unwrap(), 16);
@ -307,7 +346,6 @@ mod openssl {
impl AesGcm {
/// Construct a new AES-GCM cipher.
/// Key must be 16, 24, or 32 bytes in length or a panic will occur.
#[inline(always)]
pub fn new(k: &[u8], encrypt: bool) -> Self {
let mut s: Secret<32> = Secret::default();
match k.len() {
@ -323,9 +361,7 @@ mod openssl {
/// Initialize AES-CTR for encryption or decryption with the given IV.
/// If it's already been used, this also resets the cipher. There is no separate reset.
#[inline(always)]
pub fn init(&mut self, iv: &[u8]) {
assert_eq!(iv.len(), 12);
let mut c = Crypter::new(
aes_gcm_by_key_size(self.1),
if self.3 {
@ -338,37 +374,46 @@ mod openssl {
)
.unwrap();
c.pad(false);
let _ = c.set_tag_len(16);
let _ = self.2.replace(c);
}
#[inline(always)]
pub fn aad(&mut self, aad: &[u8]) {
let _ = self.2.as_mut().unwrap().aad_update(aad);
assert!(self.2.as_mut().unwrap().aad_update(aad).is_ok());
}
/// Encrypt or decrypt (same operation with CTR mode)
#[inline(always)]
pub fn crypt(&mut self, input: &[u8], output: &mut [u8]) {
let _ = self.2.as_mut().unwrap().update(input, output);
assert!(self.2.as_mut().unwrap().update(input, output).is_ok());
}
/// Encrypt or decrypt in place (same operation with CTR mode)
#[inline(always)]
pub fn crypt_in_place(&mut self, data: &mut [u8]) {
let _ = self.2.as_mut().unwrap().update(unsafe { &*std::slice::from_raw_parts(data.as_ptr(), data.len()) }, data);
assert!(self.2.as_mut().unwrap().update(unsafe { &*std::slice::from_raw_parts(data.as_ptr(), data.len()) }, data).is_ok());
}
#[inline(always)]
pub fn finish(&mut self) -> [u8; 16] {
pub fn finish_encrypt(&mut self) -> [u8; 16] {
let mut tag = [0_u8; 16];
let c = self.2.as_mut().unwrap();
if c.finalize(&mut []).is_ok() {
if !c.get_tag(&mut tag).is_ok() {
tag.fill(0);
}
}
let mut c = self.2.take().unwrap();
assert!(c.finalize(&mut tag).is_ok());
assert!(c.get_tag(&mut tag).is_ok());
tag
}
#[inline(always)]
pub fn finish_decrypt(&mut self, expected_tag: &[u8]) -> bool {
let mut c = self.2.take().unwrap();
if c.set_tag(expected_tag).is_ok() {
let result = c.finalize(&mut []).is_ok();
result
} else {
false
}
}
}
unsafe impl Send for AesGcm {}
@ -378,7 +423,7 @@ mod openssl {
pub use fruit_flavored::{Aes, AesGcm};
#[cfg(not(target_os = "macos"))]
pub use openssl::{Aes, AesGcm};
pub use openssl_aes::{Aes, AesGcm};
#[cfg(test)]
mod tests {
@ -391,7 +436,7 @@ mod tests {
for i in 1..12345 {
buf[i] = i as u8;
}
let iv = [1_u8; 16];
let iv = [1_u8; 12];
let mut c = AesGcm::new(&[1_u8; 32], true);
@ -402,7 +447,10 @@ mod tests {
c.crypt_in_place(&mut buf);
}
let duration = SystemTime::now().duration_since(start).unwrap();
println!("AES-256-GCM encrypt benchmark: {} MiB/sec", (((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64());
println!(
"AES-256-GCM encrypt benchmark: {} MiB/sec",
(((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64()
);
let mut c = AesGcm::new(&[1_u8; 32], false);
@ -412,6 +460,9 @@ mod tests {
c.crypt_in_place(&mut buf);
}
let duration = SystemTime::now().duration_since(start).unwrap();
println!("AES-256-GCM decrypt benchmark: {} MiB/sec", (((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64());
println!(
"AES-256-GCM decrypt benchmark: {} MiB/sec",
(((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64()
);
}
}

View file

@ -57,7 +57,6 @@ const KEY_EXCHANGE_MAX_FRAGMENTS: usize = 2; // enough room for p384 + ZT identi
const HEADER_SIZE: usize = 16;
const HEADER_CHECK_SIZE: usize = 4;
const AES_GCM_TAG_SIZE: usize = 16;
const AES_GCM_NONCE_SIZE: usize = 12;
const AES_GCM_NONCE_START: usize = 4;
const AES_GCM_NONCE_END: usize = 16;
const HMAC_SIZE: usize = 48;
@ -65,7 +64,7 @@ const SESSION_ID_SIZE: usize = 6;
const KEY_HISTORY_SIZE_MAX: usize = 3;
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
const KBKDF_KEY_USAGE_LABEL_HEADER_MAC: u8 = b'H';
const KBKDF_KEY_USAGE_LABEL_HEADER_CHECK: u8 = b'H';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';
@ -286,10 +285,10 @@ impl<H: Host> Session<H> {
if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) {
if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) {
let send_counter = Counter::new();
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>());
let remote_s_public_hash = SHA384::hash(remote_s_public);
let outgoing_init_header_check_cipher = Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
if let Ok(offer) = EphemeralOffer::create_alice_offer(
let outgoing_init_header_check_cipher = Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>());
if let Ok(offer) = create_initial_offer(
&mut send,
send_counter.next(),
local_session_id,
@ -337,22 +336,27 @@ impl<H: Host> Session<H> {
send_with_fragmentation_init_header(mtu_buffer, packet_len, mtu_buffer.len(), PACKET_TYPE_DATA, remote_session_id.into(), counter);
let mut c = key.get_send_cipher(counter)?;
c.init(mtu_buffer);
c.init(&mtu_buffer[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
if packet_len > mtu_buffer.len() {
let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap();
let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE;
let last_fragment_data_mtu = mtu_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
loop {
debug_assert!(data.len() > last_fragment_data_mtu);
c.crypt(&data[HEADER_CHECK_SIZE..fragment_data_mtu], &mut mtu_buffer[HEADER_SIZE..]);
let fragment_data_size = fragment_data_mtu.min(data.len());
let fragment_size = fragment_data_size + HEADER_SIZE;
c.crypt(&data[..fragment_data_size], &mut mtu_buffer[HEADER_SIZE..fragment_size]);
data = &data[fragment_data_size..];
let hc = header_check(mtu_buffer, &self.header_check_cipher);
mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
send(mtu_buffer);
data = &data[fragment_data_mtu..];
send(&mut mtu_buffer[..fragment_size]);
debug_assert!(header[7].wrapping_shr(2) < 63);
header[7] += 0x04; // increment fragment number
mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].copy_from_slice(&header);
if data.len() <= last_fragment_data_mtu {
break;
}
@ -362,9 +366,11 @@ impl<H: Host> Session<H> {
let gcm_tag_idx = data.len() + HEADER_SIZE;
c.crypt(data, &mut mtu_buffer[HEADER_SIZE..gcm_tag_idx]);
mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish_encrypt());
let hc = header_check(mtu_buffer, &self.header_check_cipher);
mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes());
mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish());
send(&mut mtu_buffer[..packet_len]);
key.return_send_cipher(c);
@ -375,13 +381,18 @@ impl<H: Host> Session<H> {
return Err(Error::SessionNotEstablished);
}
pub fn established(&self) -> bool {
let state = self.state.read();
state.remote_session_id.is_some() && !state.keys.is_empty()
}
#[inline]
pub fn rekey_check<SendFunction: FnMut(&mut [u8])>(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) {
let state = self.state.upgradable_read();
if let Some(key) = state.keys.front() {
if force || (key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) {
if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) {
if let Ok(offer) = EphemeralOffer::create_alice_offer(
if let Ok(offer) = create_initial_offer(
&mut send,
self.send_counter.next(),
self.id,
@ -409,7 +420,7 @@ impl<H: Host> ReceiveContext<H> {
pub fn new(host: &H) -> Self {
Self {
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
incoming_init_header_check_cipher: Aes::new(kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()),
incoming_init_header_check_cipher: Aes::new(kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>()),
}
}
@ -536,10 +547,9 @@ impl<H: Host> ReceiveContext<H> {
}
c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]);
let tag = c.finish();
let ok = c.finish_decrypt(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]);
key.return_receive_cipher(c);
if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) {
if ok {
// Drop obsolete keys if we had to iterate past the first key to get here.
if key_index > 0 {
unlikely_branch();
@ -627,8 +637,7 @@ impl<H: Host> ReceiveContext<H> {
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), false);
c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
let c = c.finish();
if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) {
if !c.finish_decrypt(&incoming_packet[payload_end..aes_gcm_tag_end]) {
return Err(Error::FailedAuthentication);
}
@ -665,7 +674,7 @@ impl<H: Host> ReceiveContext<H> {
None
} else {
if let Some((new_session_id, psk, associated_object)) = host.accept_new_session(alice_s_public, alice_metadata) {
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>());
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>());
Some(Session::<H> {
id: new_session_id,
associated_object,
@ -736,7 +745,7 @@ impl<H: Host> ReceiveContext<H> {
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true);
c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut reply_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..reply_len]);
let c = c.finish();
let c = c.finish_encrypt();
reply_buf[reply_len..(reply_len + AES_GCM_TAG_SIZE)].copy_from_slice(&c);
reply_len += AES_GCM_TAG_SIZE;
@ -791,8 +800,7 @@ impl<H: Host> ReceiveContext<H> {
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), false);
c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
let c = c.finish();
if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) {
if !c.finish_decrypt(&incoming_packet[payload_end..aes_gcm_tag_end]) {
return Err(Error::FailedAuthentication);
}
@ -831,7 +839,7 @@ impl<H: Host> ReceiveContext<H> {
let mut c = key.get_send_cipher(counter)?;
c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish());
reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt());
key.return_send_cipher(c);
let hc = header_check(&reply_buf, &session.header_check_cipher);
@ -928,97 +936,96 @@ struct EphemeralOffer {
alice_e1_keypair: Option<pqc_kyber::Keypair>,
}
impl EphemeralOffer {
fn create_alice_offer<SendFunction: FnMut(&mut [u8])>(
send: &mut SendFunction,
counter: CounterValue,
alice_session_id: SessionId,
bob_session_id: Option<SessionId>,
alice_s_public: &[u8],
alice_metadata: &[u8],
bob_s_public_p384: &P384PublicKey,
bob_s_public_hash: &[u8],
ss: &Secret<48>,
header_check_cipher: &Aes,
mtu: usize,
current_time: i64,
jedi: bool,
) -> Result<EphemeralOffer, Error> {
let alice_e0_keypair = P384KeyPair::generate();
let e0s = alice_e0_keypair.agree(bob_s_public_p384);
if e0s.is_none() {
return Err(Error::InvalidPacket);
fn create_initial_offer<SendFunction: FnMut(&mut [u8])>(
send: &mut SendFunction,
counter: CounterValue,
alice_session_id: SessionId,
bob_session_id: Option<SessionId>,
alice_s_public: &[u8],
alice_metadata: &[u8],
bob_s_public_p384: &P384PublicKey,
bob_s_public_hash: &[u8],
ss: &Secret<48>,
header_check_cipher: &Aes,
mtu: usize,
current_time: i64,
jedi: bool,
) -> Result<EphemeralOffer, Error> {
let alice_e0_keypair = P384KeyPair::generate();
let e0s = alice_e0_keypair.agree(bob_s_public_p384);
if e0s.is_none() {
return Err(Error::InvalidPacket);
}
let alice_e1_keypair = if jedi {
Some(pqc_kyber::keypair(&mut random::SecureRandom::get()))
} else {
None
};
const PACKET_BUF_SIZE: usize = MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS;
let mut packet_buf = [0_u8; PACKET_BUF_SIZE];
let mut packet_len = {
let mut p = &mut packet_buf[HEADER_SIZE..];
p.write_all(&[SESSION_PROTOCOL_VERSION])?;
p.write_all(alice_e0_keypair.public_key_bytes())?;
p.write_all(&alice_session_id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?;
varint::write(&mut p, alice_s_public.len() as u64)?;
p.write_all(alice_s_public)?;
varint::write(&mut p, alice_metadata.len() as u64)?;
p.write_all(alice_metadata)?;
if let Some(e1kp) = alice_e1_keypair {
p.write_all(&[E1_TYPE_KYBER1024])?;
p.write_all(&e1kp.public)?;
} else {
p.write_all(&[E1_TYPE_NONE])?;
}
let alice_e1_keypair = if jedi {
Some(pqc_kyber::keypair(&mut random::SecureRandom::get()))
} else {
None
};
PACKET_BUF_SIZE - p.len()
};
const PACKET_BUF_SIZE: usize = MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS;
let mut packet_buf = [0_u8; PACKET_BUF_SIZE];
let mut packet_len = {
let mut p = &mut packet_buf[HEADER_SIZE..];
send_with_fragmentation_init_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter);
p.write_all(&[SESSION_PROTOCOL_VERSION])?;
p.write_all(alice_e0_keypair.public_key_bytes())?;
let key = Secret(hmac_sha512(
&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()),
e0s.unwrap().as_bytes(),
));
p.write_all(&alice_session_id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?;
varint::write(&mut p, alice_s_public.len() as u64)?;
p.write_all(alice_s_public)?;
varint::write(&mut p, alice_metadata.len() as u64)?;
p.write_all(alice_metadata)?;
if let Some(e1kp) = alice_e1_keypair {
p.write_all(&[E1_TYPE_KYBER1024])?;
p.write_all(&e1kp.public)?;
} else {
p.write_all(&[E1_TYPE_NONE])?;
}
let gcm_tag = {
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true);
c.init(&packet_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]);
c.finish_encrypt()
};
packet_buf[packet_len..(packet_len + AES_GCM_TAG_SIZE)].copy_from_slice(&gcm_tag);
packet_len += AES_GCM_TAG_SIZE;
PACKET_BUF_SIZE - p.len()
};
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
send_with_fragmentation_init_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter);
let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[HEADER_CHECK_SIZE..packet_len]);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;
let key = Secret(hmac_sha512(
&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()),
e0s.unwrap().as_bytes(),
));
let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[HEADER_CHECK_SIZE..packet_len]);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;
let gcm_tag = {
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true);
c.init(&packet_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]);
c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]);
c.finish()
};
packet_buf[packet_len..(packet_len + AES_GCM_TAG_SIZE)].copy_from_slice(&gcm_tag);
packet_len += AES_GCM_TAG_SIZE;
send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher);
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[HEADER_CHECK_SIZE..packet_len]);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;
let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[HEADER_CHECK_SIZE..packet_len]);
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE;
send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher);
Ok(EphemeralOffer {
creation_time: current_time,
key,
alice_e0_keypair,
alice_e1_keypair,
})
}
Ok(EphemeralOffer {
creation_time: current_time,
key,
alice_e0_keypair,
alice_e1_keypair,
})
}
/// Create a header to send a packet with optional fragmentation.
#[inline(always)]
fn send_with_fragmentation_init_header(header: &mut [u8], packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) {
let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize;
let fragment_count = ((packet_len as f32) / (mtu - HEADER_SIZE) as f32).ceil() as usize;
debug_assert!(mtu >= MIN_MTU);
debug_assert!(packet_len >= HEADER_SIZE);
debug_assert!(fragment_count <= MAX_FRAGMENTS);
@ -1068,6 +1075,13 @@ fn header_check(packet: &[u8], header_check_cipher: &Aes) -> u32 {
/// Add a new session key to the key list, retiring older non-active keys if necessary.
fn add_key(keys: &mut LinkedList<SessionKey>, key: SessionKey) {
// Sanity check to make sure duplicates can't get in here. Should be impossible.
for k in keys.iter() {
if k.receive_key.eq(&key.receive_key) {
return;
}
}
debug_assert!(KEY_HISTORY_SIZE_MAX >= 2);
while keys.len() >= KEY_HISTORY_SIZE_MAX {
let current = keys.pop_front().unwrap();
@ -1132,7 +1146,7 @@ struct SessionKey {
receive_cipher_pool: Mutex<Vec<Box<AesGcm>>>,
send_cipher_pool: Mutex<Vec<Box<AesGcm>>>,
role: Role,
jedi: bool, // true if kyber was used
jedi: bool, // true if kyber was enabled on both sides
}
impl SessionKey {
@ -1266,13 +1280,16 @@ mod tests {
#[test]
fn establish_session() {
let jedi = true;
let mut data_buf = [0_u8; (1280 - 32) * MAX_FRAGMENTS];
let mut mtu_buffer = [0_u8; 1280];
let mut psk: Secret<64> = Secret::default();
random::fill_bytes_secure(&mut psk.0);
let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob"));
let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice"));
let alice_rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new(&alice_host));
let bob_rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new(&bob_host));
let mut data_buf = [0_u8; 4096];
//println!("zssp: size of session (bytes): {}", std::mem::size_of::<Session<Box<TestHost>>>());
@ -1285,7 +1302,7 @@ mod tests {
&[],
&psk,
1,
1280,
mtu_buffer.len(),
1,
jedi,
)
@ -1293,7 +1310,7 @@ mod tests {
));
let mut ts = 0;
for _ in 0..256 {
for _ in 0..4 {
for host in [&alice_host, &bob_host] {
let send_to_other = |data: &mut [u8]| {
if std::ptr::eq(host, &alice_host) {
@ -1313,15 +1330,16 @@ mod tests {
if let Some(qi) = host.queue.lock().pop_back() {
let qi_len = qi.len();
ts += 1;
let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, jedi, ts);
let r = rc.receive(host, send_to_other, &mut data_buf, qi, mtu_buffer.len(), jedi, ts);
if r.is_ok() {
let r = r.unwrap();
match r {
ReceiveResult::Ok => {
println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len);
//println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len);
}
ReceiveResult::OkData(data) => {
println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len());
//println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len());
assert!(!data.iter().any(|x| *x != 0x12));
}
ReceiveResult::OkNewSession(new_session) => {
println!("zssp: {} => {} ({}): OkNewSession ({})", host.other_name, host.this_name, qi_len, u64::from(new_session.id));
@ -1341,6 +1359,15 @@ mod tests {
break;
}
}
data_buf.fill(0x12);
if let Some(session) = host.session.lock().as_ref().cloned() {
if session.established() {
for dl in 0..data_buf.len() {
assert!(session.send(send_to_other, &mut mtu_buffer, &data_buf[..dl]).is_ok());
}
}
}
}
}
}