Simplify some packet building code.

This commit is contained in:
Adam Ierymenko 2023-03-08 15:03:27 -05:00
parent 94b3e208e7
commit cd6d8d36b0
2 changed files with 19 additions and 24 deletions

View file

@ -47,14 +47,6 @@ pub enum Error {
UnexpectedBufferOverrun,
}
// An I/O error in the parser means an invalid packet.
impl From<std::io::Error> for Error {
#[inline(always)]
fn from(_: std::io::Error) -> Self {
Self::UnexpectedBufferOverrun
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {

View file

@ -10,7 +10,6 @@
// FIPS compliant Noise_XK with Jedi powers (Kyber1024) and built-in attack-resistant large payload (fragmentation) support.
use std::collections::{HashMap, HashSet};
use std::io::Write;
use std::num::NonZeroU64;
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock, Weak};
@ -923,15 +922,14 @@ impl<Application: ApplicationLayer> Context<Application> {
// up forward secrecy. Also return Bob's opaque note.
let mut reply_buffer = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
reply_buffer[HEADER_SIZE] = SESSION_PROTOCOL_VERSION;
let mut rw = &mut reply_buffer[HEADER_SIZE + 1..];
let mut reply_len = HEADER_SIZE + 1;
let alice_s_public_blob = app.get_local_s_public_blob();
assert!(alice_s_public_blob.len() <= (u16::MAX as usize));
rw.write_all(&(alice_s_public_blob.len() as u16).to_le_bytes())?;
let mut enc_start = MAX_NOISE_HANDSHAKE_SIZE - rw.len();
rw.write_all(alice_s_public_blob)?;
reply_len = append_to_slice(&mut reply_buffer, reply_len, &(alice_s_public_blob.len() as u16).to_le_bytes())?;
let mut enc_start = reply_len;
reply_len = append_to_slice(&mut reply_buffer, reply_len, alice_s_public_blob)?;
let mut reply_len = MAX_NOISE_HANDSHAKE_SIZE - rw.len();
let mut gcm = AesGcm::new(
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_ENCRYPTION>(&hmac_sha512(
noise_es_ee.as_bytes(),
@ -943,20 +941,17 @@ impl<Application: ApplicationLayer> Context<Application> {
gcm.reset_init_gcm(&reply_message_nonce);
gcm.aad(&noise_h_next);
gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]);
let mut rw = &mut reply_buffer[reply_len..];
rw.write_all(&gcm.finish_encrypt())?;
reply_len = append_to_slice(&mut reply_buffer, reply_len, &gcm.finish_encrypt())?;
let metadata = outgoing_offer.metadata.as_ref().map_or(&[][..0], |md| md.as_slice());
assert!(metadata.len() <= (u16::MAX as usize));
rw.write_all(&(metadata.len() as u16).to_le_bytes())?;
reply_len = append_to_slice(&mut reply_buffer, reply_len, &(metadata.len() as u16).to_le_bytes())?;
reply_len = MAX_NOISE_HANDSHAKE_SIZE - rw.len();
let noise_h_next = mix_hash(&noise_h_next, &reply_buffer[HEADER_SIZE..reply_len]);
let mut rw = &mut reply_buffer[reply_len..];
enc_start = reply_len;
rw.write_all(metadata.as_ref())?;
reply_len = append_to_slice(&mut reply_buffer, reply_len, metadata)?;
let mut gcm = AesGcm::new(
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_ENCRYPTION>(noise_es_ee_se_hk_psk.as_bytes()).as_bytes(),
@ -964,11 +959,8 @@ impl<Application: ApplicationLayer> Context<Application> {
);
gcm.reset_init_gcm(&reply_message_nonce);
gcm.aad(&noise_h_next);
reply_len = MAX_NOISE_HANDSHAKE_SIZE - rw.len();
gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]);
let mut rw = &mut reply_buffer[reply_len..];
rw.write_all(&gcm.finish_encrypt())?;
reply_len = MAX_NOISE_HANDSHAKE_SIZE - rw.len();
reply_len = append_to_slice(&mut reply_buffer, reply_len, &gcm.finish_encrypt())?;
drop(state);
{
@ -1665,6 +1657,17 @@ impl<'a> PktReader<'a> {
}
}
/// Helper function to append to a slice when we still want to be able to look back at it.
fn append_to_slice(s: &mut [u8], p: usize, d: &[u8]) -> Result<usize, Error> {
let tmp = p + d.len();
if tmp <= s.len() {
s[p..tmp].copy_from_slice(d);
Ok(tmp)
} else {
Err(Error::UnexpectedBufferOverrun)
}
}
/// MixHash to update 'h' during negotiation.
fn mix_hash(h: &[u8; SHA384_HASH_SIZE], m: &[u8]) -> [u8; SHA384_HASH_SIZE] {
let mut hasher = SHA384::new();