diff --git a/zssp/src/error.rs b/zssp/src/error.rs index 5fd35c9b4..a8448bc35 100644 --- a/zssp/src/error.rs +++ b/zssp/src/error.rs @@ -51,7 +51,7 @@ pub enum Error { impl From for Error { #[inline(always)] fn from(_: std::io::Error) -> Self { - Self::InvalidPacket + Self::UnexpectedBufferOverrun } } diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 57a01d55f..89bdf490e 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -946,26 +946,29 @@ impl Context { let mut rw = &mut reply_buffer[reply_len..]; rw.write_all(&gcm.finish_encrypt())?; - if let Some(md) = outgoing_offer.metadata.as_ref() { - assert!(md.len() <= (u16::MAX as usize)); - rw.write_all(&(md.len() as u16).to_le_bytes())?; - enc_start = MAX_NOISE_HANDSHAKE_SIZE - rw.len(); - rw.write_all(md.as_ref())?; - } else { - rw.write_all(&[0u8, 0u8])?; // no meta-data - enc_start = MAX_NOISE_HANDSHAKE_SIZE - rw.len(); - } + let metadata = outgoing_offer.metadata.as_ref().map_or(&[][..0], |md| md.as_slice()); + + assert!(metadata.len() <= (u16::MAX as usize)); + rw.write_all(&(metadata.len() as u16).to_le_bytes())?; reply_len = MAX_NOISE_HANDSHAKE_SIZE - rw.len(); + let noise_h_next = mix_hash(&noise_h_next, &reply_buffer[HEADER_SIZE..reply_len]); + let mut rw = &mut reply_buffer[reply_len..]; + + enc_start = reply_len; + rw.write_all(metadata.as_ref())?; + let mut gcm = AesGcm::new( kbkdf::(noise_es_ee_se_hk_psk.as_bytes()).as_bytes(), true, ); gcm.reset_init_gcm(&reply_message_nonce); gcm.aad(&noise_h_next); + reply_len = MAX_NOISE_HANDSHAKE_SIZE - rw.len(); gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]); - reply_buffer[reply_len..reply_len + AES_GCM_TAG_SIZE].copy_from_slice(&gcm.finish_encrypt()); - reply_len += AES_GCM_TAG_SIZE; + let mut rw = &mut reply_buffer[reply_len..]; + rw.write_all(&gcm.finish_encrypt())?; + reply_len = MAX_NOISE_HANDSHAKE_SIZE - rw.len(); drop(state); { @@ -1027,6 +1030,13 @@ impl Context { let mut r = PktReader(pkt_assembled, HEADER_SIZE + 1); let alice_static_public_blob_size = r.read_u16()? as usize; + + let ciphertext_up_to_metadata_size = r.1 + alice_static_public_blob_size + AES_GCM_TAG_SIZE + 2; + if r.0.len() < ciphertext_up_to_metadata_size { + return Err(Error::InvalidPacket); + } + let noise_h_next = mix_hash(&incoming.noise_h, &r.0[HEADER_SIZE..ciphertext_up_to_metadata_size]); + let alice_static_public_blob = r.read_decrypt_auth( alice_static_public_blob_size, kbkdf::(&hmac_sha512( @@ -1064,7 +1074,7 @@ impl Context { let alice_meta_data = r.read_decrypt_auth( alice_meta_data_size, kbkdf::(noise_es_ee_se_hk_psk.as_bytes()), - &incoming.noise_h, + &noise_h_next, &incoming_message_nonce, )?; if alice_meta_data.len() > data_buf.len() {