switched offer parsing to safe_read

This commit is contained in:
mamoniot 2022-12-14 13:31:30 -05:00
parent a7fdc82c5b
commit 3befaad3b0
2 changed files with 26 additions and 52 deletions

View file

@ -5,3 +5,5 @@ zssp has been cut up into several files, only the new zssp.rs file contains the
Standardized the naming conventions for security variables throughout zssp. Standardized the naming conventions for security variables throughout zssp.
Implemented a safer version of write_all for zssp to use. This has 3 benefits: it completely prevents unknown io errors, making error handling easier and self-documenting; it completely prevents src from being truncated in dest, putting in an extra barrier to prevent catastrophic key truncation; and it has slightly less performance overhead than a write_all. Implemented a safer version of write_all for zssp to use. This has 3 benefits: it completely prevents unknown io errors, making error handling easier and self-documenting; it completely prevents src from being truncated in dest, putting in an extra barrier to prevent catastrophic key truncation; and it has slightly less performance overhead than a write_all.
Fixed a possible buffer overrun panic when decoding alice_ratchet_key_fingerprint

View file

@ -189,13 +189,11 @@ pub fn varint_safe_write(dest: &mut &mut [u8], v: u64) -> Result<(), Error> {
safe_write_all(dest, &b[0..i]) safe_write_all(dest, &b[0..i])
} }
fn safe_read_exact(src: &mut &[u8], dest: &mut [u8]) -> Result<(), Error> { fn safe_read_exact<'a>(src: &mut &'a [u8], amt: usize) -> Result<&'a [u8], Error> {
let amt = dest.len();
if src.len() >= amt { if src.len() >= amt {
let (a, b) = src.split_at(amt); let (a, b) = src.split_at(amt);
dest.copy_from_slice(a);
*src = b; *src = b;
Ok(()) Ok(a)
} else { } else {
Err(Error::InvalidPacket) Err(Error::InvalidPacket)
} }
@ -806,7 +804,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
} }
// Match ratchet key fingerprint and fail if no match, which likely indicates an old offer packet. // Match ratchet key fingerprint and fail if no match, which likely indicates an old offer packet.
let alice_ratchet_key_fingerprint = alice_ratchet_key_fingerprint.as_ref().unwrap(); let alice_ratchet_key_fingerprint = alice_ratchet_key_fingerprint.unwrap();
let mut ratchet_key = None; let mut ratchet_key = None;
let mut ratchet_count = 0; let mut ratchet_count = 0;
let state = session.state.read().unwrap(); let state = session.state.read().unwrap();
@ -909,7 +907,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
safe_write_all(&mut rp, &[SESSION_PROTOCOL_VERSION])?; safe_write_all(&mut rp, &[SESSION_PROTOCOL_VERSION])?;
safe_write_all(&mut rp, bob_e_keypair.public_key_bytes())?; safe_write_all(&mut rp, bob_e_keypair.public_key_bytes())?;
safe_write_all(&mut rp, &offer_id)?; safe_write_all(&mut rp, offer_id)?;
safe_write_all(&mut rp, &session.id.0.to_le_bytes()[..SESSION_ID_SIZE])?; safe_write_all(&mut rp, &session.id.0.to_le_bytes()[..SESSION_ID_SIZE])?;
varint_safe_write(&mut rp, 0)?; // they don't need our static public; they have it varint_safe_write(&mut rp, 0)?; // they don't need our static public; they have it
varint_safe_write(&mut rp, 0)?; // no meta-data in counter-offers (could be used in the future) varint_safe_write(&mut rp, 0)?; // no meta-data in counter-offers (could be used in the future)
@ -921,7 +919,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
} }
if ratchet_key.is_some() { if ratchet_key.is_some() {
safe_write_all(&mut rp, &[0x01])?; safe_write_all(&mut rp, &[0x01])?;
safe_write_all(&mut rp, alice_ratchet_key_fingerprint.as_ref().unwrap())?; safe_write_all(&mut rp, alice_ratchet_key_fingerprint.unwrap())?;
} else { } else {
safe_write_all(&mut rp, &[0x00])?; safe_write_all(&mut rp, &[0x00])?;
} }
@ -1037,7 +1035,7 @@ impl<Layer: ApplicationLayer> ReceiveContext<Layer> {
packet_type, packet_type,
)?; )?;
if !offer.id.eq(&offer_id) { if !offer.id.eq(offer_id) {
return Ok(ReceiveResult::Ignored); return Ok(ReceiveResult::Ignored);
} }
@ -1341,49 +1339,26 @@ fn verify_header_check_code(packet: &[u8], header_check_cipher: &Aes) -> bool {
fn parse_key_offer_after_header( fn parse_key_offer_after_header(
incoming_packet: &[u8], incoming_packet: &[u8],
packet_type: u8, packet_type: u8,
) -> Result<([u8; 16], SessionId, &[u8], &[u8], &[u8], Option<[u8; 16]>), Error> { ) -> Result<(&[u8], SessionId, &[u8], &[u8], &[u8], Option<&[u8]>), Error> {
let mut p = &incoming_packet[..]; let mut p = &incoming_packet[..];
let mut offer_id = [0_u8; 16]; let offer_id = safe_read_exact(&mut p, 16)?;
safe_read_exact(&mut p, &mut offer_id)?;
let mut session_id_buf = 0_u64.to_ne_bytes(); let mut session_id_buf = 0_u64.to_ne_bytes();
safe_read_exact(&mut p, &mut session_id_buf[..SESSION_ID_SIZE])?; session_id_buf[..SESSION_ID_SIZE].copy_from_slice(safe_read_exact(&mut p, SESSION_ID_SIZE)?);
let alice_session_id = SessionId::new_from_u64(u64::from_le_bytes(session_id_buf)); let alice_session_id = SessionId::new_from_u64(u64::from_le_bytes(session_id_buf)).ok_or(Error::InvalidPacket)?;
if alice_session_id.is_none() {
return Err(Error::InvalidPacket);
}
let alice_session_id = alice_session_id.unwrap();
let alice_s_public_len = varint_safe_read(&mut p)?; let alice_s_public_len = varint_safe_read(&mut p)?;
if (p.len() as u64) < alice_s_public_len { let alice_s_public_raw = safe_read_exact(&mut p, alice_s_public_len as usize)?;
return Err(Error::InvalidPacket);
}
let alice_s_public = &p[..(alice_s_public_len as usize)];
p = &p[(alice_s_public_len as usize)..];
let alice_metadata_len = varint_safe_read(&mut p)?; let alice_metadata_len = varint_safe_read(&mut p)?;
if (p.len() as u64) < alice_metadata_len { let alice_metadata = safe_read_exact(&mut p, alice_metadata_len as usize)?;
return Err(Error::InvalidPacket);
} let alice_hk_public_raw = match safe_read_exact(&mut p, 1)?[0] {
let alice_metadata = &p[..(alice_metadata_len as usize)];
p = &p[(alice_metadata_len as usize)..];
if p.is_empty() {
return Err(Error::InvalidPacket);
}
let alice_hk_public = match p[0] {
E1_TYPE_KYBER1024 => { E1_TYPE_KYBER1024 => {
if packet_type == PACKET_TYPE_KEY_OFFER { if packet_type == PACKET_TYPE_KEY_OFFER {
if p.len() < (pqc_kyber::KYBER_PUBLICKEYBYTES + 1) { safe_read_exact(&mut p, pqc_kyber::KYBER_PUBLICKEYBYTES)?
return Err(Error::InvalidPacket);
}
let hkp = &p[1..(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)];
p = &p[(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)..];
hkp
} else { } else {
if p.len() < (pqc_kyber::KYBER_CIPHERTEXTBYTES + 1) { safe_read_exact(&mut p, pqc_kyber::KYBER_CIPHERTEXTBYTES)?
return Err(Error::InvalidPacket);
}
let hkp = &p[1..(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)];
p = &p[(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)..];
hkp
} }
} }
_ => &[], _ => &[],
@ -1391,21 +1366,18 @@ fn parse_key_offer_after_header(
if p.is_empty() { if p.is_empty() {
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
let alice_ratchet_key_fingerprint = if p[0] == 0x01 { let alice_ratchet_key_fingerprint = if safe_read_exact(&mut p, 1)?[0] == 0x01 {
if p.len() < 16 { Some(safe_read_exact(&mut p, 16)?)
return Err(Error::InvalidPacket);
}
Some(p[1..17].try_into().unwrap())
} else { } else {
None None
}; };
Ok(( Ok((
offer_id, offer_id,//always 16 bytes
alice_session_id, alice_session_id,
alice_s_public, alice_s_public_raw,
alice_metadata, alice_metadata,
alice_hk_public, alice_hk_public_raw,
alice_ratchet_key_fingerprint, alice_ratchet_key_fingerprint,//always 16 bytes
)) ))
} }