diff --git a/zssp/src/constants.rs b/zssp/src/constants.rs index 5f1fbfb78..73425d6ef 100644 --- a/zssp/src/constants.rs +++ b/zssp/src/constants.rs @@ -80,9 +80,8 @@ pub(crate) const COUNTER_WINDOW_MAX_OUT_OF_ORDER: usize = 16; // Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use pub(crate) const PACKET_TYPE_DATA: u8 = 0; -pub(crate) const PACKET_TYPE_NOP: u8 = 1; -pub(crate) const PACKET_TYPE_INITIAL_KEY_OFFER: u8 = 2; // "alice" -pub(crate) const PACKET_TYPE_KEY_COUNTER_OFFER: u8 = 3; // "bob" +pub(crate) const PACKET_TYPE_INITIAL_KEY_OFFER: u8 = 1; // "alice" +pub(crate) const PACKET_TYPE_KEY_COUNTER_OFFER: u8 = 2; // "bob" // Key usage labels for sub-key derivation using NIST-style KBKDF (basically just HMAC KDF). pub(crate) const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; // HMAC-SHA384 authentication for key exchanges diff --git a/zssp/src/sessionid.rs b/zssp/src/sessionid.rs index 37a29416b..09f32bd63 100644 --- a/zssp/src/sessionid.rs +++ b/zssp/src/sessionid.rs @@ -12,7 +12,7 @@ use crate::constants::SESSION_ID_SIZE; pub struct SessionId(NonZeroU64); // stored little endian internally impl SessionId { - pub const MAX: u64 = 0x7fffffffffff; + pub const MAX: u64 = 0xffffffffffff; /// Create a new session ID, panicing if 'i' is zero or exceeds MAX. pub fn new(i: u64) -> SessionId { diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 1953a7b72..9e8410966 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -374,16 +374,14 @@ impl ReceiveContext { return Err(Error::InvalidPacket); } - let raw_local_session_id_key_index = memory::load_raw(incoming_packet); - let key_index = (u64::from_le(raw_local_session_id_key_index).wrapping_shr(47) & 1) as usize; - - if let Some(local_session_id) = SessionId::new_from_u64_le(raw_local_session_id_key_index) { + if let Some(local_session_id) = SessionId::new_from_u64_le(memory::load_raw(incoming_packet)) { if let Some(session) = app.lookup_session(local_session_id) { session .header_check_cipher .decrypt_block_in_place(&mut incoming_packet[HEADER_CHECK_ENCRYPT_START..HEADER_CHECK_ENCRYPT_END]); let raw_header_a = u16::from_le(memory::load_raw(&incoming_packet[6..])); - let packet_type = (raw_header_a & 0xf) as u8; + let key_index = (raw_header_a & 1) as usize; + let packet_type = (raw_header_a.wrapping_shr(1) & 7) as u8; let fragment_count = ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8; let fragment_no = raw_header_a.wrapping_shr(10) as u8; let counter = u64::from_le(memory::load_raw(&incoming_packet[8..])); @@ -442,7 +440,8 @@ impl ReceiveContext { self.incoming_init_header_check_cipher .decrypt_block_in_place(&mut incoming_packet[HEADER_CHECK_ENCRYPT_START..HEADER_CHECK_ENCRYPT_END]); let raw_header_a = u16::from_le(memory::load_raw(&incoming_packet[6..])); - let packet_type = (raw_header_a & 0xf) as u8; + let key_index = (raw_header_a & 1) as usize; + let packet_type = (raw_header_a.wrapping_shr(1) & 7) as u8; let fragment_count = ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8; let fragment_no = raw_header_a.wrapping_shr(10) as u8; let counter = u64::from_le(memory::load_raw(&incoming_packet[8..])); @@ -506,13 +505,8 @@ impl ReceiveContext { ) -> Result, Error> { debug_assert!(fragments.len() >= 1); - // The first 'if' below should capture both DATA and NOP but not other types. Sanity check this. - debug_assert_eq!(PACKET_TYPE_DATA, 0); - debug_assert_eq!(PACKET_TYPE_NOP, 1); - let message_nonce = create_message_nonce(packet_type, counter); - - if packet_type <= PACKET_TYPE_NOP { + if packet_type == PACKET_TYPE_DATA { if let Some(session) = session { let state = session.state.read().unwrap(); if let Some(session_key) = state.session_keys[key_index].as_ref() { @@ -581,12 +575,7 @@ impl ReceiveContext { } } - if packet_type == PACKET_TYPE_DATA { - return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); - } else { - unlikely_branch(); - return Ok(ReceiveResult::Ok); - } + return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); } else { unlikely_branch(); return Ok(ReceiveResult::Ignored); @@ -1254,17 +1243,17 @@ fn set_packet_header( debug_assert!(fragment_no < MAX_FRAGMENTS); debug_assert!(packet_type <= 0x0f); // packet type is 4 bits if fragment_count <= MAX_FRAGMENTS { - // [0-46] recipient session ID - // [47-47] ratchet count least significant bit (key index) + // [0-47] recipient session ID // -- start of header check cipher single block encrypt -- - // [48-51] packet type (0-15) + // [48-48] key index (least significant bit of ratchet count) + // [49-51] packet type (0-15) // [52-57] fragment count (1..64 - 1, so 0 means 1 fragment) // [58-63] fragment number (0..63) // [64-127] 64-bit counter memory::store_raw( (u64::from(recipient_session_id) - | (ratchet_count & 1).wrapping_shl(47) - | (packet_type as u64).wrapping_shl(48) + | (ratchet_count & 1).wrapping_shl(48) + | (packet_type as u64).wrapping_shl(49) | ((fragment_count - 1) as u64).wrapping_shl(52) | (fragment_no as u64).wrapping_shl(58)) .to_le(),