ZSSP comments and cleanup.

This commit is contained in:
Adam Ierymenko 2022-12-01 12:03:52 -05:00
parent 580496cbd7
commit e433b670fc

View file

@ -82,6 +82,9 @@ const E1_TYPE_KYBER1024: u8 = 1;
/// Size of packet header /// Size of packet header
const HEADER_SIZE: usize = 16; const HEADER_SIZE: usize = 16;
/// Size of AES-GCM keys (256 bits)
const AES_KEY_SIZE: usize = 32;
/// Size of AES-GCM MAC tags /// Size of AES-GCM MAC tags
const AES_GCM_TAG_SIZE: usize = 16; const AES_GCM_TAG_SIZE: usize = 16;
@ -109,6 +112,9 @@ const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A'; // AES-GCM in A->B
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; // AES-GCM in B->A direction const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; // AES-GCM in B->A direction
const KBKDF_KEY_USAGE_LABEL_RATCHETING: u8 = b'R'; // Key input for next ephemeral ratcheting const KBKDF_KEY_USAGE_LABEL_RATCHETING: u8 = b'R'; // Key input for next ephemeral ratcheting
// AES key size for header check code generation
const HEADER_CHECK_AES_KEY_SIZE: usize = 16;
/// Aribitrary starting value for master key derivation. /// Aribitrary starting value for master key derivation.
/// ///
/// It doesn't matter very much what this is but it's good for it to be unique. It should /// It doesn't matter very much what this is but it's good for it to be unique. It should
@ -214,7 +220,7 @@ pub enum ReceiveResult<'a, H: Host> {
Ignored, Ignored,
} }
/// 48-bit session ID (most significant 24 bits of u64 are unused) /// 48-bit session ID (most significant 16 bits of u64 are unused)
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)] #[repr(transparent)]
pub struct SessionId(u64); pub struct SessionId(u64);
@ -347,7 +353,7 @@ pub struct Session<H: Host> {
remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob) remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob)
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, MAX_FRAGMENTS>, 16, 4>>, defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, MAX_FRAGMENTS>, 8, 8>>,
} }
struct SessionMutableState { struct SessionMutableState {
@ -359,8 +365,9 @@ struct SessionMutableState {
} }
impl<H: Host> Session<H> { impl<H: Host> Session<H> {
/// Create a new session and send the first key offer message. /// Create a new session and send an initial key offer message to the other end.
/// ///
/// * `host` - Interface to application using ZSSP
/// * `local_session_id` - ID for this side of the session, must be locally unique /// * `local_session_id` - ID for this side of the session, must be locally unique
/// * `remote_s_public` - Remote side's public key/identity /// * `remote_s_public` - Remote side's public key/identity
/// * `offer_metadata` - Arbitrary meta-data to send with key offer (empty if none) /// * `offer_metadata` - Arbitrary meta-data to send with key offer (empty if none)
@ -382,11 +389,9 @@ impl<H: Host> Session<H> {
if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) { if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) {
if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) { if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) {
let send_counter = Counter::new(); let send_counter = Counter::new();
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>());
let remote_s_public_hash = SHA384::hash(remote_s_public); let remote_s_public_hash = SHA384::hash(remote_s_public);
let outgoing_init_header_check_cipher = let header_check_cipher =
Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>()); Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>());
if let Ok(offer) = send_ephemeral_offer( if let Ok(offer) = send_ephemeral_offer(
&mut send, &mut send,
send_counter.next(), send_counter.next(),
@ -398,7 +403,7 @@ impl<H: Host> Session<H> {
&remote_s_public_hash, &remote_s_public_hash,
&ss, &ss,
None, None,
&outgoing_init_header_check_cipher, None,
mtu, mtu,
current_time, current_time,
) { ) {
@ -428,7 +433,8 @@ impl<H: Host> Session<H> {
/// Send data over the session. /// Send data over the session.
/// ///
/// * `mtu_buffer` - A writable work buffer whose size must be equal to the physical MTU /// * `send` - Function to call to send physical packet(s)
/// * `mtu_buffer` - A writable work buffer whose size also specifies the physical MTU
/// * `data` - Data to send /// * `data` - Data to send
#[inline] #[inline]
pub fn send<SendFunction: FnMut(&mut [u8])>( pub fn send<SendFunction: FnMut(&mut [u8])>(
@ -516,7 +522,7 @@ impl<H: Host> Session<H> {
/// ///
/// This returns a tuple of: the key fingerprint, the time it was established, the length of its ratchet chain, /// This returns a tuple of: the key fingerprint, the time it was established, the length of its ratchet chain,
/// and whether Kyber1024 was used. None is returned if the session isn't established. /// and whether Kyber1024 was used. None is returned if the session isn't established.
pub fn security_info(&self) -> Option<([u8; 16], i64, u64, bool)> { pub fn status(&self) -> Option<([u8; 16], i64, u64, bool)> {
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
if let Some(key) = state.keys[state.key_ptr].as_ref() { if let Some(key) = state.keys[state.key_ptr].as_ref() {
Some((key.secret_fingerprint, key.establish_time, key.ratchet_count, key.jedi)) Some((key.secret_fingerprint, key.establish_time, key.ratchet_count, key.jedi))
@ -527,6 +533,8 @@ impl<H: Host> Session<H> {
/// This function needs to be called on each session at least every SERVICE_INTERVAL milliseconds. /// This function needs to be called on each session at least every SERVICE_INTERVAL milliseconds.
/// ///
/// * `host` - Interface to application using ZSSP
/// * `send` - Function to call to send physical packet(s)
/// * `offer_metadata' - Any meta-data to include with initial key offers sent. /// * `offer_metadata' - Any meta-data to include with initial key offers sent.
/// * `mtu` - Physical MTU for sent packets /// * `mtu` - Physical MTU for sent packets
/// * `current_time` - Current monotonic time in milliseconds /// * `current_time` - Current monotonic time in milliseconds
@ -551,7 +559,6 @@ impl<H: Host> Session<H> {
.map_or(true, |o| (current_time - o.creation_time) > H::REKEY_RATE_LIMIT_MS) .map_or(true, |o| (current_time - o.creation_time) > H::REKEY_RATE_LIMIT_MS)
{ {
if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) { if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) {
let mut tmp_header_check_cipher = None;
if let Ok(offer) = send_ephemeral_offer( if let Ok(offer) = send_ephemeral_offer(
&mut send, &mut send,
self.send_counter.next(), self.send_counter.next(),
@ -564,12 +571,9 @@ impl<H: Host> Session<H> {
&self.ss, &self.ss,
state.keys[state.key_ptr].as_ref(), state.keys[state.key_ptr].as_ref(),
if state.remote_session_id.is_some() { if state.remote_session_id.is_some() {
&self.header_check_cipher Some(&self.header_check_cipher)
} else { } else {
let _ = tmp_header_check_cipher.insert(Aes::new( None
kbkdf512(&self.remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>(),
));
tmp_header_check_cipher.as_ref().unwrap()
}, },
mtu, mtu,
current_time, current_time,
@ -587,17 +591,20 @@ impl<H: Host> ReceiveContext<H> {
Self { Self {
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
incoming_init_header_check_cipher: Aes::new( incoming_init_header_check_cipher: Aes::new(
kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>(), kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>(),
), ),
} }
} }
/// Receive, authenticate, decrypt, and process a physical wire packet. /// Receive, authenticate, decrypt, and process a physical wire packet.
/// ///
/// `data_buf` - Data buffer that must be as large as the largest supported data object to be transferred (or you'll get errors) /// * `host` - Interface to application using ZSSP
/// `incoming_packet_buf` - Buffer containing incoming wire packet, ownership taken by receive(). /// * `remote_address` - Remote physical address of source endpoint
/// `mtu` - Physical wire MTU /// * `data_buf` - Buffer to receive decrypted and authenticated object data (an error is returned if too small)
/// `current_time` - Current monotonic time in milliseconds /// * `incoming_packet_buf` - Buffer containing incoming wire packet (receive() takes ownership)
/// * `mtu` - Physical wire MTU for sending packets
/// * `current_time` - Current monotonic time in milliseconds
#[inline]
pub fn receive<'a, SendFunction: FnMut(&mut [u8])>( pub fn receive<'a, SendFunction: FnMut(&mut [u8])>(
&self, &self,
host: &H, host: &H,
@ -717,7 +724,10 @@ impl<H: Host> ReceiveContext<H> {
} }
/// Called internally when all fragments of a packet are received. /// Called internally when all fragments of a packet are received.
/// Header check codes will already have been validated for each fragment. ///
/// NOTE: header check codes will already have been validated on receipt of each fragment. AEAD authentication
/// and decryption has NOT yet been performed, and is done here.
#[inline]
fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>( fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>(
&self, &self,
host: &H, host: &H,
@ -733,7 +743,7 @@ impl<H: Host> ReceiveContext<H> {
) -> Result<ReceiveResult<'a, H>, Error> { ) -> Result<ReceiveResult<'a, H>, Error> {
debug_assert!(fragments.len() >= 1); debug_assert!(fragments.len() >= 1);
// These just confirm that the first 'if' below does what it should do. // The first 'if' below should capture both DATA and NOP but not other types. Sanity check this.
debug_assert_eq!(PACKET_TYPE_DATA, 0); debug_assert_eq!(PACKET_TYPE_DATA, 0);
debug_assert_eq!(PACKET_TYPE_NOP, 1); debug_assert_eq!(PACKET_TYPE_NOP, 1);
@ -743,17 +753,12 @@ impl<H: Host> ReceiveContext<H> {
for p in 0..KEY_HISTORY_SIZE { for p in 0..KEY_HISTORY_SIZE {
let key_ptr = (state.key_ptr + p) % KEY_HISTORY_SIZE; let key_ptr = (state.key_ptr + p) % KEY_HISTORY_SIZE;
if let Some(key) = state.keys[key_ptr].as_ref() { if let Some(key) = state.keys[key_ptr].as_ref() {
let tail = fragments.last().unwrap().as_ref();
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
return Err(Error::InvalidPacket);
}
let mut c = key.get_receive_cipher(); let mut c = key.get_receive_cipher();
c.init(canonical_header_bytes); c.init(canonical_header_bytes);
let mut data_len = 0; let mut data_len = 0;
// Decrypt fragments 0..N-1 where N is the number of fragments.
for f in fragments[..(fragments.len() - 1)].iter() { for f in fragments[..(fragments.len() - 1)].iter() {
let f = f.as_ref(); let f = f.as_ref();
debug_assert!(f.len() >= HEADER_SIZE); debug_assert!(f.len() >= HEADER_SIZE);
@ -767,21 +772,28 @@ impl<H: Host> ReceiveContext<H> {
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
} }
// Decrypt final fragment (or only fragment if not fragmented)
let current_frag_data_start = data_len; let current_frag_data_start = data_len;
data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); let last_fragment = fragments.last().unwrap().as_ref();
if last_fragment.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
return Err(Error::InvalidPacket);
}
data_len += last_fragment.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
if data_len > data_buf.len() { if data_len > data_buf.len() {
unlikely_branch(); unlikely_branch();
key.return_receive_cipher(c); key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall); return Err(Error::DataBufferTooSmall);
} }
c.crypt( c.crypt(
&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &last_fragment[HEADER_SIZE..(last_fragment.len() - AES_GCM_TAG_SIZE)],
&mut data_buf[current_frag_data_start..data_len], &mut data_buf[current_frag_data_start..data_len],
); );
let ok = c.finish_decrypt(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]); let aead_authentication_ok = c.finish_decrypt(&last_fragment[(last_fragment.len() - AES_GCM_TAG_SIZE)..]);
key.return_receive_cipher(c); key.return_receive_cipher(c);
if ok {
if aead_authentication_ok {
// Select this key as the new default if it's newer than the current key. // Select this key as the new default if it's newer than the current key.
if p > 0 if p > 0
&& state.keys[state.key_ptr] && state.keys[state.key_ptr]
@ -801,6 +813,7 @@ impl<H: Host> ReceiveContext<H> {
} }
} }
} }
if packet_type == PACKET_TYPE_DATA { if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else { } else {
@ -810,6 +823,8 @@ impl<H: Host> ReceiveContext<H> {
} }
} }
} }
// If no known key authenticated the packet, decryption has failed.
return Err(Error::FailedAuthentication); return Err(Error::FailedAuthentication);
} else { } else {
unlikely_branch(); unlikely_branch();
@ -819,26 +834,27 @@ impl<H: Host> ReceiveContext<H> {
unlikely_branch(); unlikely_branch();
// To greatly simplify logic handling key exchange packets, assemble these first. // To greatly simplify logic handling key exchange packets, assemble these first.
// This adds some extra memory copying but this is not the fast path. // Handling KEX packets isn't the fast path so the extra copying isn't significant.
let mut incoming_packet_buf = [0_u8; 4096]; const KEX_BUF_LEN: usize = MIN_TRANSPORT_MTU * KEY_EXCHANGE_MAX_FRAGMENTS;
let mut incoming_packet_len = 0; let mut kex_packet = [0_u8; KEX_BUF_LEN];
let mut kex_packet_len = 0;
for i in 0..fragments.len() { for i in 0..fragments.len() {
let mut ff = fragments[i].as_ref(); let mut ff = fragments[i].as_ref();
debug_assert!(ff.len() >= MIN_PACKET_SIZE); debug_assert!(ff.len() >= MIN_PACKET_SIZE);
if i > 0 { if i > 0 {
ff = &ff[HEADER_SIZE..]; ff = &ff[HEADER_SIZE..];
} }
let j = incoming_packet_len + ff.len(); let j = kex_packet_len + ff.len();
if j > incoming_packet_buf.len() { if j > KEX_BUF_LEN {
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
incoming_packet_buf[incoming_packet_len..j].copy_from_slice(ff); kex_packet[kex_packet_len..j].copy_from_slice(ff);
incoming_packet_len = j; kex_packet_len = j;
} }
let original_ciphertext = incoming_packet_buf.clone(); let kex_packet_saved_ciphertext = kex_packet.clone(); // save for HMAC check later
let incoming_packet = &mut incoming_packet_buf[..incoming_packet_len];
if incoming_packet[HEADER_SIZE] != SESSION_PROTOCOL_VERSION { // Key exchange packets begin (after header) with the session protocol version.
if kex_packet[HEADER_SIZE] != SESSION_PROTOCOL_VERSION {
return Err(Error::UnknownProtocolVersion); return Err(Error::UnknownProtocolVersion);
} }
@ -846,20 +862,20 @@ impl<H: Host> ReceiveContext<H> {
PACKET_TYPE_KEY_OFFER => { PACKET_TYPE_KEY_OFFER => {
// alice (remote) -> bob (local) // alice (remote) -> bob (local)
if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE) { if kex_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE) {
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE); let payload_end = kex_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE);
let aes_gcm_tag_end = incoming_packet_len - (HMAC_SIZE + HMAC_SIZE); let aes_gcm_tag_end = kex_packet_len - (HMAC_SIZE + HMAC_SIZE);
let hmac1_end = incoming_packet_len - HMAC_SIZE; let hmac1_end = kex_packet_len - HMAC_SIZE;
// Check that the sender knows this host's identity before doing anything else. // Check the second HMAC first, which proves that the sender knows the recipient's full static identity.
if !hmac_sha384_2( if !hmac_sha384_2(
host.get_local_s_public_hash(), host.get_local_s_public_hash(),
canonical_header_bytes, canonical_header_bytes,
&incoming_packet[HEADER_SIZE..hmac1_end], &kex_packet[HEADER_SIZE..hmac1_end],
) )
.eq(&incoming_packet[hmac1_end..]) .eq(&kex_packet[hmac1_end..])
{ {
return Err(Error::FailedAuthentication); return Err(Error::FailedAuthentication);
} }
@ -877,7 +893,7 @@ impl<H: Host> ReceiveContext<H> {
// Key agreement: alice (remote) ephemeral NIST P-384 <> local static NIST P-384 // Key agreement: alice (remote) ephemeral NIST P-384 <> local static NIST P-384
let (alice_e0_public, e0s) = let (alice_e0_public, e0s) =
P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
.and_then(|pk| host.get_local_s_keypair_p384().agree(&pk).map(move |s| (pk, s))) .and_then(|pk| host.get_local_s_keypair_p384().agree(&pk).map(move |s| (pk, s)))
.ok_or(Error::FailedAuthentication)?; .ok_or(Error::FailedAuthentication)?;
@ -886,18 +902,18 @@ impl<H: Host> ReceiveContext<H> {
// Decrypt the encrypted part of the packet payload and authenticate the above key exchange via AES-GCM auth. // Decrypt the encrypted part of the packet payload and authenticate the above key exchange via AES-GCM auth.
let mut c = AesGcm::new( let mut c = AesGcm::new(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<AES_KEY_SIZE>(),
false, false,
); );
c.init(canonical_header_bytes); c.init(canonical_header_bytes);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); c.crypt_in_place(&mut kex_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
if !c.finish_decrypt(&incoming_packet[payload_end..aes_gcm_tag_end]) { if !c.finish_decrypt(&kex_packet[payload_end..aes_gcm_tag_end]) {
return Err(Error::FailedAuthentication); return Err(Error::FailedAuthentication);
} }
// Parse payload and get alice's session ID, alice's public blob, metadata, and (if present) Alice's Kyber1024 public. // Parse payload and get alice's session ID, alice's public blob, metadata, and (if present) Alice's Kyber1024 public.
let (offer_id, alice_session_id, alice_s_public, alice_metadata, alice_e1_public, alice_ratchet_key_fingerprint) = let (offer_id, alice_session_id, alice_s_public, alice_metadata, alice_e1_public, alice_ratchet_key_fingerprint) =
parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?; parse_key_offer_after_header(&kex_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?;
// We either have a session, in which case they should have supplied a ratchet key fingerprint, or // We either have a session, in which case they should have supplied a ratchet key fingerprint, or
// we don't and they should not have supplied one. // we don't and they should not have supplied one.
@ -922,9 +938,9 @@ impl<H: Host> ReceiveContext<H> {
if !hmac_sha384_2( if !hmac_sha384_2(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
canonical_header_bytes, canonical_header_bytes,
&original_ciphertext[HEADER_SIZE..aes_gcm_tag_end], &kex_packet_saved_ciphertext[HEADER_SIZE..aes_gcm_tag_end],
) )
.eq(&incoming_packet[aes_gcm_tag_end..hmac1_end]) .eq(&kex_packet[aes_gcm_tag_end..hmac1_end])
{ {
return Err(Error::FailedAuthentication); return Err(Error::FailedAuthentication);
} }
@ -962,7 +978,9 @@ impl<H: Host> ReceiveContext<H> {
if let Some((new_session_id, psk, associated_object)) = if let Some((new_session_id, psk, associated_object)) =
host.accept_new_session(self, remote_address, alice_s_public, alice_metadata) host.accept_new_session(self, remote_address, alice_s_public, alice_metadata)
{ {
let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<16>()); let header_check_cipher = Aes::new(
kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>(),
);
( (
Some(Session::<H> { Some(Session::<H> {
id: new_session_id, id: new_session_id,
@ -1031,8 +1049,7 @@ impl<H: Host> ReceiveContext<H> {
}; };
// Create reply packet. // Create reply packet.
const REPLY_BUF_LEN: usize = MIN_TRANSPORT_MTU * KEY_EXCHANGE_MAX_FRAGMENTS; let mut reply_buf = [0_u8; KEX_BUF_LEN];
let mut reply_buf = [0_u8; REPLY_BUF_LEN];
let reply_counter = session.send_counter.next(); let reply_counter = session.send_counter.next();
let mut reply_len = { let mut reply_len = {
let mut rp = &mut reply_buf[HEADER_SIZE..]; let mut rp = &mut reply_buf[HEADER_SIZE..];
@ -1057,7 +1074,7 @@ impl<H: Host> ReceiveContext<H> {
rp.write_all(&[0x00])?; rp.write_all(&[0x00])?;
} }
REPLY_BUF_LEN - rp.len() KEX_BUF_LEN - rp.len()
}; };
create_packet_header( create_packet_header(
&mut reply_buf, &mut reply_buf,
@ -1073,7 +1090,7 @@ impl<H: Host> ReceiveContext<H> {
// Encrypt reply packet using final Noise_IK key BEFORE mixing hybrid or ratcheting, since the other side // Encrypt reply packet using final Noise_IK key BEFORE mixing hybrid or ratcheting, since the other side
// must decrypt before doing these things. // must decrypt before doing these things.
let mut c = AesGcm::new( let mut c = AesGcm::new(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<AES_KEY_SIZE>(),
true, true,
); );
c.init(reply_canonical_header.as_bytes()); c.init(reply_canonical_header.as_bytes());
@ -1124,17 +1141,17 @@ impl<H: Host> ReceiveContext<H> {
PACKET_TYPE_KEY_COUNTER_OFFER => { PACKET_TYPE_KEY_COUNTER_OFFER => {
// bob (remote) -> alice (local) // bob (remote) -> alice (local)
if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) { if kex_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) {
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE); let payload_end = kex_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE);
let aes_gcm_tag_end = incoming_packet_len - HMAC_SIZE; let aes_gcm_tag_end = kex_packet_len - HMAC_SIZE;
if let Some(session) = session { if let Some(session) = session {
let state = session.state.read().unwrap(); let state = session.state.read().unwrap();
if let Some(offer) = state.offer.as_ref() { if let Some(offer) = state.offer.as_ref() {
let (bob_e0_public, e0e0) = let (bob_e0_public, e0e0) =
P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
.and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s))) .and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s)))
.ok_or(Error::FailedAuthentication)?; .ok_or(Error::FailedAuthentication)?;
let se0 = host let se0 = host
@ -1151,19 +1168,19 @@ impl<H: Host> ReceiveContext<H> {
)); ));
let mut c = AesGcm::new( let mut c = AesGcm::new(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<AES_KEY_SIZE>(),
false, false,
); );
c.init(canonical_header_bytes); c.init(canonical_header_bytes);
c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); c.crypt_in_place(&mut kex_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
if !c.finish_decrypt(&incoming_packet[payload_end..aes_gcm_tag_end]) { if !c.finish_decrypt(&kex_packet[payload_end..aes_gcm_tag_end]) {
return Err(Error::FailedAuthentication); return Err(Error::FailedAuthentication);
} }
// Alice has now completed Noise_IK with NIST P-384 and verified with GCM auth, but now for hybrid... // Alice has now completed Noise_IK with NIST P-384 and verified with GCM auth, but now for hybrid...
let (offer_id, bob_session_id, _, _, bob_e1_public, bob_ratchet_key_id) = let (offer_id, bob_session_id, _, _, bob_e1_public, bob_ratchet_key_id) =
parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?; parse_key_offer_after_header(&kex_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?;
if !offer.id.eq(&offer_id) { if !offer.id.eq(&offer_id) {
return Ok(ReceiveResult::Ignored); return Ok(ReceiveResult::Ignored);
@ -1191,9 +1208,9 @@ impl<H: Host> ReceiveContext<H> {
if !hmac_sha384_2( if !hmac_sha384_2(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
canonical_header_bytes, canonical_header_bytes,
&original_ciphertext[HEADER_SIZE..aes_gcm_tag_end], &kex_packet_saved_ciphertext[HEADER_SIZE..aes_gcm_tag_end],
) )
.eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()]) .eq(&kex_packet[aes_gcm_tag_end..kex_packet.len()])
{ {
return Err(Error::FailedAuthentication); return Err(Error::FailedAuthentication);
} }
@ -1244,6 +1261,7 @@ impl<H: Host> ReceiveContext<H> {
} }
/// Outgoing packet counter with strictly ordered atomic semantics. /// Outgoing packet counter with strictly ordered atomic semantics.
#[repr(transparent)]
struct Counter(AtomicU64); struct Counter(AtomicU64);
impl Counter { impl Counter {
@ -1333,7 +1351,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
bob_s_public_hash: &[u8], bob_s_public_hash: &[u8],
ss: &Secret<48>, ss: &Secret<48>,
current_key: Option<&SessionKey>, current_key: Option<&SessionKey>,
header_check_cipher: &Aes, header_check_cipher: Option<&Aes>, // None to use one based on the recipient's public key for initial contact
mtu: usize, mtu: usize,
current_time: i64, current_time: i64,
) -> Result<Box<EphemeralOffer>, Error> { ) -> Result<Box<EphemeralOffer>, Error> {
@ -1405,7 +1423,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
// Encrypt packet and attach AES-GCM tag. // Encrypt packet and attach AES-GCM tag.
let gcm_tag = { let gcm_tag = {
let mut c = AesGcm::new( let mut c = AesGcm::new(
kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<AES_KEY_SIZE>(),
true, true,
); );
c.init(canonical_header.as_bytes()); c.init(canonical_header.as_bytes());
@ -1432,7 +1450,16 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac);
packet_len += HMAC_SIZE; packet_len += HMAC_SIZE;
if let Some(header_check_cipher) = header_check_cipher {
send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher); send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher);
} else {
send_with_fragmentation(
send,
&mut packet_buf[..packet_len],
mtu,
&Aes::new(kbkdf512(&bob_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>()),
);
}
Ok(Box::new(EphemeralOffer { Ok(Box::new(EphemeralOffer {
id, id,
@ -1646,8 +1673,8 @@ struct SessionKey {
establish_counter: u64, // Counter value at which session was established establish_counter: u64, // Counter value at which session was established
lifetime: KeyLifetime, // Key expiration time and counter lifetime: KeyLifetime, // Key expiration time and counter
ratchet_key: Secret<64>, // Ratchet key for deriving the next session key ratchet_key: Secret<64>, // Ratchet key for deriving the next session key
receive_key: Secret<32>, // Receive side AES-GCM key receive_key: Secret<AES_KEY_SIZE>, // Receive side AES-GCM key
send_key: Secret<32>, // Send side AES-GCM key send_key: Secret<AES_KEY_SIZE>, // Send side AES-GCM key
receive_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of initialized sending ciphers receive_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of initialized sending ciphers
send_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of initialized receiving ciphers send_cipher_pool: Mutex<Vec<Box<AesGcm>>>, // Pool of initialized receiving ciphers
ratchet_count: u64, // Number of new keys negotiated in this session ratchet_count: u64, // Number of new keys negotiated in this session
@ -1657,8 +1684,8 @@ struct SessionKey {
impl SessionKey { impl SessionKey {
/// Create a new symmetric shared session key and set its key expiration times, etc. /// Create a new symmetric shared session key and set its key expiration times, etc.
fn new(key: Secret<64>, role: Role, current_time: i64, current_counter: CounterValue, ratchet_count: u64, jedi: bool) -> Self { fn new(key: Secret<64>, role: Role, current_time: i64, current_counter: CounterValue, ratchet_count: u64, jedi: bool) -> Self {
let a2b: Secret<32> = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n_clone(); let a2b: Secret<AES_KEY_SIZE> = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n_clone();
let b2a: Secret<32> = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n_clone(); let b2a: Secret<AES_KEY_SIZE> = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n_clone();
let (receive_key, send_key) = match role { let (receive_key, send_key) = match role {
Role::Alice => (b2a, a2b), Role::Alice => (b2a, a2b),
Role::Bob => (a2b, b2a), Role::Bob => (a2b, b2a),
@ -1929,7 +1956,7 @@ mod tests {
if session.established() { if session.established() {
{ {
let mut key_id = host.key_id.lock().unwrap(); let mut key_id = host.key_id.lock().unwrap();
let security_info = session.security_info().unwrap(); let security_info = session.status().unwrap();
if !security_info.0.eq(key_id.as_ref()) { if !security_info.0.eq(key_id.as_ref()) {
*key_id = security_info.0; *key_id = security_info.0;
println!( println!(