made encoding and decoding more symmetric

This commit is contained in:
mamoniot 2022-12-14 19:51:55 -05:00
parent 194a12c180
commit fc656a02d1

View file

@ -173,7 +173,7 @@ impl std::fmt::Debug for Error {
}
// Write src into buffer starting at the index idx. If buffer cannot fit src at that location, nothing at all is written and Error::UnexpectedBufferOverrun is returned. No other errors can be returned by this function. An idx incremented by the amount written is returned.
/// Write src into buffer starting at the index idx. If buffer cannot fit src at that location, nothing at all is written and Error::UnexpectedBufferOverrun is returned. No other errors can be returned by this function. An idx incremented by the amount written is returned.
fn safe_write_all(buffer: &mut [u8], idx: usize, src: &[u8]) -> Result<usize, Error> {
let dest = &mut buffer[idx..];
let amt = src.len();
@ -193,7 +193,7 @@ fn varint_safe_write(buffer: &mut [u8], idx: usize, v: u64) -> Result<usize, Err
safe_write_all(buffer, idx, &b[0..i])
}
// Read exactly amt bytes from src and return the slice those bytes reside in. If src is smaller than amt, Error::InvalidPacket is returned. if the read was successful src is incremented to start at the first unread byte.
/// Read exactly amt bytes from src and return the slice those bytes reside in. If src is smaller than amt, Error::InvalidPacket is returned. if the read was successful src is incremented to start at the first unread byte.
fn safe_read_exact<'a>(src: &mut &'a [u8], amt: usize) -> Result<&'a [u8], Error> {
if src.len() >= amt {
let (a, b) = src.split_at(amt);
@ -301,7 +301,7 @@ impl<Layer: ApplicationLayer> Session<Layer> {
if let Some(remote_session_id) = state.remote_session_id {
if let Some(sym_key) = state.session_keys[state.cur_session_key_idx].as_ref() {
// Total size of the armored packet we are going to send (may end up being fragmented)
let mut packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
// This outgoing packet's nonce counter value.
let counter = self.send_counter.next();
@ -326,6 +326,7 @@ impl<Layer: ApplicationLayer> Session<Layer> {
c.reset_init_gcm(CanonicalHeader::make(remote_session_id, PACKET_TYPE_DATA, counter.to_u32()).as_bytes());
// Send first N-1 fragments of N total fragments.
let last_fragment_size;
if packet_len > mtu_buffer.len() {
let mut header: [u8; 16] = mtu_buffer[..HEADER_SIZE].try_into().unwrap();
let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE;
@ -346,15 +347,18 @@ impl<Layer: ApplicationLayer> Session<Layer> {
break;
}
}
packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
last_fragment_size = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
} else {
last_fragment_size = packet_len;
}
// Send final fragment (or only fragment if no fragmentation was needed)
let gcm_tag_idx = data.len() + HEADER_SIZE;
c.crypt(data, &mut mtu_buffer[HEADER_SIZE..gcm_tag_idx]);
mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish_encrypt());
let payload_end = data.len() + HEADER_SIZE;
c.crypt(data, &mut mtu_buffer[HEADER_SIZE..payload_end]);
let gcm_tag = c.finish_encrypt();
mtu_buffer[payload_end..last_fragment_size].copy_from_slice(&gcm_tag);
set_header_check_code(mtu_buffer, &self.header_check_cipher);
send(&mut mtu_buffer[..packet_len]);
send(&mut mtu_buffer[..last_fragment_size]);
// Check reusable AES-GCM instance back into pool.
sym_key.return_send_cipher(c);
@ -648,12 +652,14 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
session_key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
let payload_end = last_fragment.len() - AES_GCM_TAG_SIZE;
c.crypt(
&last_fragment[HEADER_SIZE..(last_fragment.len() - AES_GCM_TAG_SIZE)],
&last_fragment[HEADER_SIZE..payload_end],
&mut data_buf[current_frag_data_start..data_len],
);
let aead_authentication_ok = c.finish_decrypt(&last_fragment[(last_fragment.len() - AES_GCM_TAG_SIZE)..]);
let gcm_tag = &last_fragment[payload_end..];
let aead_authentication_ok = c.finish_decrypt(gcm_tag);
session_key.return_receive_cipher(c);
if aead_authentication_ok {
@ -724,10 +730,14 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
match packet_type {
PACKET_TYPE_INITIAL_KEY_OFFER => {
// alice (remote) -> bob (local)
////////////////////////////////////////////////////////////////
// packet decoding for noise initial key offer
// -> e, es, s, ss
////////////////////////////////////////////////////////////////
if kex_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE) {
return Err(Error::InvalidPacket);
}
let plaintext_end = HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE;
let payload_end = kex_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE);
let aes_gcm_tag_end = kex_packet_len - (HMAC_SIZE + HMAC_SIZE);
let hmac1_end = kex_packet_len - HMAC_SIZE;
@ -754,15 +764,9 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
}
}
////////////////////////////////////////////////////////////////
// packet decoding for noise initial key offer
// -> e, es, s, ss
////////////////////////////////////////////////////////////////
// Key agreement: alice (remote) ephemeral NIST P-384 <> local static NIST P-384
let (alice_e_public, noise_es) =
P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
.and_then(|pk| host.get_local_s_keypair().agree(&pk).map(move |s| (pk, s)))
.ok_or(Error::FailedAuthentication)?;
let alice_e_public = P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..plaintext_end]).ok_or(Error::FailedAuthentication)?;
let noise_es = host.get_local_s_keypair().agree(&alice_e_public).ok_or(Error::FailedAuthentication)?;
// Initial key derivation from starting point, mixing in alice's ephemeral public and the es.
let es_key = Secret(hmac_sha512(&hmac_sha512(&INITIAL_KEY, alice_e_public.as_bytes()), noise_es.as_bytes()));
@ -773,14 +777,15 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
false,
);
c.reset_init_gcm(canonical_header_bytes);
c.crypt_in_place(&mut kex_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
if !c.finish_decrypt(&kex_packet[payload_end..aes_gcm_tag_end]) {
c.crypt_in_place(&mut kex_packet[plaintext_end..payload_end]);
let gcm_tag = &kex_packet[payload_end..aes_gcm_tag_end];
if !c.finish_decrypt(gcm_tag) {
return Err(Error::FailedAuthentication);
}
// Parse payload and get alice's session ID, alice's public blob, metadata, and (if present) Alice's Kyber1024 public.
let (offer_id, alice_session_id, alice_s_public_blob, alice_metadata, alice_hk_public_raw, alice_ratchet_key_fingerprint) =
parse_key_offer_after_header(&kex_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..kex_packet_len], packet_type)?;
parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?;
// We either have a session, in which case they should have supplied a ratchet key fingerprint, or
// we don't and they should not have supplied one.
@ -927,7 +932,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?;
idx = safe_write_all(&mut reply_buf, idx, bob_e_keypair.public_key_bytes())?;
let end_of_plaintext = idx;
let plaintext_end = idx;
idx = safe_write_all(&mut reply_buf, idx, offer_id)?;
idx = safe_write_all(&mut reply_buf, idx, &session.id.0.to_le_bytes()[..SESSION_ID_SIZE])?;
@ -945,18 +950,18 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
} else {
idx = safe_write_all(&mut reply_buf, idx, &[0x00])?;
}
let end_of_ciphertext = idx;
let payload_end = idx;
create_packet_header(
&mut reply_buf,
end_of_ciphertext,
payload_end,
mtu,
PACKET_TYPE_KEY_COUNTER_OFFER,
alice_session_id.into(),
reply_counter,
)?;
let reply_canonical_header =
CanonicalHeader::make(alice_session_id.into(), PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter.to_u32());
CanonicalHeader::make(alice_session_id.into(), PACKET_TYPE_KEY_COUNTER_OFFER, reply_counter.to_u32());
// Encrypt reply packet using final Noise_IK key BEFORE mixing hybrid or ratcheting, since the other side
// must decrypt before doing these things.
@ -965,11 +970,11 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
true,
);
c.reset_init_gcm(reply_canonical_header.as_bytes());
c.crypt_in_place(&mut reply_buf[end_of_plaintext..end_of_ciphertext]);
let c = c.finish_encrypt();
idx = safe_write_all(&mut reply_buf, idx, &c)?;
c.crypt_in_place(&mut reply_buf[plaintext_end..payload_end]);
let gcm_tag = c.finish_encrypt();
idx = safe_write_all(&mut reply_buf, idx, &gcm_tag)?;
let aes_gcm_tag_end = idx;
// Mix ratchet key from previous session key (if any) and Kyber1024 hybrid shared key (if any).
let mut session_key = noise_ik_key;
if let Some(ratchet_key) = ratchet_key {
@ -986,9 +991,10 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
let hmac = hmac_sha384_2(
kbkdf512(session_key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
reply_canonical_header.as_bytes(),
&reply_buf[HEADER_SIZE..idx],
&reply_buf[HEADER_SIZE..aes_gcm_tag_end],
);
idx = safe_write_all(&mut reply_buf, idx, &hmac)?;
let packet_end = idx;
let session_key = SessionKey::new(session_key, Role::Bob, current_time, reply_counter, ratchet_count + 1, hybrid_kk.is_some());
@ -1000,7 +1006,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
// Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it.
send_with_fragmentation(send, &mut reply_buf[..idx], mtu, &session.header_check_cipher);
send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher);
if new_session.is_some() {
return Ok(ReceiveResult::OkNewSession(new_session.unwrap()));
@ -1011,10 +1017,15 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
PACKET_TYPE_KEY_COUNTER_OFFER => {
// bob (remote) -> alice (local)
////////////////////////////////////////////////////////////////
// packet decoding for noise key counter offer
// <- e, ee, se
////////////////////////////////////////////////////////////////
if kex_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) {
return Err(Error::InvalidPacket);
}
let plaintext_end = HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE;
let payload_end = kex_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE);
let aes_gcm_tag_end = kex_packet_len - HMAC_SIZE;
@ -1022,14 +1033,8 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
let state = session.state.read().unwrap();
if let Some(offer) = state.offer.as_ref() {
////////////////////////////////////////////////////////////////
// packet decoding for noise key counter offer
// <- e, ee, se
////////////////////////////////////////////////////////////////
let (bob_e_public, noise_ee) =
P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
.and_then(|pk| offer.alice_e_keypair.agree(&pk).map(move |s| (pk, s)))
.ok_or(Error::FailedAuthentication)?;
let bob_e_public = P384PublicKey::from_bytes(&kex_packet[(HEADER_SIZE + 1)..plaintext_end]).ok_or(Error::FailedAuthentication)?;
let noise_ee = offer.alice_e_keypair.agree(&bob_e_public).ok_or(Error::FailedAuthentication)?;
let noise_se = host
.get_local_s_keypair()
.agree(&bob_e_public)
@ -1048,15 +1053,16 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
false,
);
c.reset_init_gcm(canonical_header_bytes);
c.crypt_in_place(&mut kex_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]);
if !c.finish_decrypt(&kex_packet[payload_end..aes_gcm_tag_end]) {
c.crypt_in_place(&mut kex_packet[plaintext_end..payload_end]);
let gcm_tag = &kex_packet[payload_end..aes_gcm_tag_end];
if !c.finish_decrypt(gcm_tag) {
return Err(Error::FailedAuthentication);
}
// Alice has now completed Noise_IK with NIST P-384 and verified with GCM auth, but now for hybrid...
let (offer_id, bob_session_id, _, _, bob_hk_public_raw, bob_ratchet_key_id) = parse_key_offer_after_header(
&kex_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..kex_packet_len],
let (offer_id, bob_session_id, _, _, bob_hk_public_raw, bob_ratchet_key_id) = parse_dec_key_offer_after_header(
&kex_packet[plaintext_end..kex_packet_len],
packet_type,
)?;
@ -1114,7 +1120,8 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
let mut c = session_key.get_send_cipher(counter)?;
c.reset_init_gcm(CanonicalHeader::make(bob_session_id.into(), PACKET_TYPE_NOP, counter.to_u32()).as_bytes());
safe_write_all(&mut reply_buf, HEADER_SIZE, &c.finish_encrypt())?;
let gcm_tag = c.finish_encrypt();
safe_write_all(&mut reply_buf, HEADER_SIZE, &gcm_tag)?;
session_key.return_send_cipher(c);
set_header_check_code(&mut reply_buf, &session.header_check_cipher);
@ -1197,7 +1204,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
idx = safe_write_all(&mut packet_buf, idx, &[SESSION_PROTOCOL_VERSION])?;
//TODO: check this, the below line is supposed to be the blob, not just the key, right?
idx = safe_write_all(&mut packet_buf, idx, alice_e_keypair.public_key_bytes())?;
let end_of_plaintext = idx;
let plaintext_end = idx;
idx = safe_write_all(&mut packet_buf, idx, &id)?;
idx = safe_write_all(&mut packet_buf, idx, &alice_session_id.0.to_le_bytes()[..SESSION_ID_SIZE])?;
@ -1217,7 +1224,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
} else {
idx = safe_write_all(&mut packet_buf, idx, &[0x00])?;
}
let end_of_ciphertext = idx;
let payload_end = idx;
// Create ephemeral agreement secret.
@ -1227,7 +1234,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
));
let bob_session_id = bob_session_id.unwrap_or(SessionId::NIL);
create_packet_header(&mut packet_buf, end_of_ciphertext, mtu, PACKET_TYPE_INITIAL_KEY_OFFER, bob_session_id, counter)?;
create_packet_header(&mut packet_buf, payload_end, mtu, PACKET_TYPE_INITIAL_KEY_OFFER, bob_session_id, counter)?;
let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32());
@ -1238,10 +1245,12 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
true,
);
c.reset_init_gcm(canonical_header.as_bytes());
c.crypt_in_place(&mut packet_buf[end_of_plaintext..end_of_ciphertext]);
c.crypt_in_place(&mut packet_buf[plaintext_end..payload_end]);
c.finish_encrypt()
};
idx = safe_write_all(&mut packet_buf, idx, &gcm_tag)?;
let aes_gcm_tag_end = idx;
// Mix in static secret.
let ss_key = Secret(hmac_sha512(es_key.as_bytes(), noise_ss.as_bytes()));
@ -1251,20 +1260,22 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
let hmac = hmac_sha384_2(
kbkdf512(ss_key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(),
canonical_header.as_bytes(),
&packet_buf[HEADER_SIZE..idx],
&packet_buf[HEADER_SIZE..aes_gcm_tag_end],
);
idx = safe_write_all(&mut packet_buf, idx, &hmac)?;
let hmac1_end = idx;
// Add secondary HMAC to verify that the caller knows the recipient's full static public identity.
let hmac2 = hmac_sha384_2(bob_s_public_blob_hash, canonical_header.as_bytes(), &packet_buf[HEADER_SIZE..idx]);
let hmac2 = hmac_sha384_2(bob_s_public_blob_hash, canonical_header.as_bytes(), &packet_buf[HEADER_SIZE..hmac1_end]);
idx = safe_write_all(&mut packet_buf, idx, &hmac2)?;
let packet_end = idx;
if let Some(header_check_cipher) = header_check_cipher {
send_with_fragmentation(send, &mut packet_buf[..idx], mtu, header_check_cipher);
send_with_fragmentation(send, &mut packet_buf[..packet_end], mtu, header_check_cipher);
} else {
send_with_fragmentation(
send,
&mut packet_buf[..idx],
&mut packet_buf[..packet_end],
mtu,
&Aes::new(kbkdf512(&bob_s_public_blob_hash, KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>()),
);
@ -1370,7 +1381,7 @@ fn verify_header_check_code(packet: &[u8], header_check_cipher: &Aes) -> bool {
}
/// Parse KEY_OFFER and KEY_COUNTER_OFFER starting after the unencrypted public key part.
fn parse_key_offer_after_header(
fn parse_dec_key_offer_after_header(
incoming_packet: &[u8],
packet_type: u8,
) -> Result<(&[u8], SessionId, &[u8], &[u8], &[u8], Option<&[u8]>), Error> {