Put key index inside the encrypted part of the header.

This commit is contained in:
Adam Ierymenko 2023-01-11 19:54:04 -05:00
parent 3db9603799
commit 2479645341
3 changed files with 15 additions and 27 deletions

View file

@ -80,9 +80,8 @@ pub(crate) const COUNTER_WINDOW_MAX_OUT_OF_ORDER: usize = 16;
// Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use // Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use
pub(crate) const PACKET_TYPE_DATA: u8 = 0; pub(crate) const PACKET_TYPE_DATA: u8 = 0;
pub(crate) const PACKET_TYPE_NOP: u8 = 1; pub(crate) const PACKET_TYPE_INITIAL_KEY_OFFER: u8 = 1; // "alice"
pub(crate) const PACKET_TYPE_INITIAL_KEY_OFFER: u8 = 2; // "alice" pub(crate) const PACKET_TYPE_KEY_COUNTER_OFFER: u8 = 2; // "bob"
pub(crate) const PACKET_TYPE_KEY_COUNTER_OFFER: u8 = 3; // "bob"
// Key usage labels for sub-key derivation using NIST-style KBKDF (basically just HMAC KDF). // Key usage labels for sub-key derivation using NIST-style KBKDF (basically just HMAC KDF).
pub(crate) const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; // HMAC-SHA384 authentication for key exchanges pub(crate) const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; // HMAC-SHA384 authentication for key exchanges

View file

@ -12,7 +12,7 @@ use crate::constants::SESSION_ID_SIZE;
pub struct SessionId(NonZeroU64); // stored little endian internally pub struct SessionId(NonZeroU64); // stored little endian internally
impl SessionId { impl SessionId {
pub const MAX: u64 = 0x7fffffffffff; pub const MAX: u64 = 0xffffffffffff;
/// Create a new session ID, panicing if 'i' is zero or exceeds MAX. /// Create a new session ID, panicing if 'i' is zero or exceeds MAX.
pub fn new(i: u64) -> SessionId { pub fn new(i: u64) -> SessionId {

View file

@ -374,16 +374,14 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
let raw_local_session_id_key_index = memory::load_raw(incoming_packet); if let Some(local_session_id) = SessionId::new_from_u64_le(memory::load_raw(incoming_packet)) {
let key_index = (u64::from_le(raw_local_session_id_key_index).wrapping_shr(47) & 1) as usize;
if let Some(local_session_id) = SessionId::new_from_u64_le(raw_local_session_id_key_index) {
if let Some(session) = app.lookup_session(local_session_id) { if let Some(session) = app.lookup_session(local_session_id) {
session session
.header_check_cipher .header_check_cipher
.decrypt_block_in_place(&mut incoming_packet[HEADER_CHECK_ENCRYPT_START..HEADER_CHECK_ENCRYPT_END]); .decrypt_block_in_place(&mut incoming_packet[HEADER_CHECK_ENCRYPT_START..HEADER_CHECK_ENCRYPT_END]);
let raw_header_a = u16::from_le(memory::load_raw(&incoming_packet[6..])); let raw_header_a = u16::from_le(memory::load_raw(&incoming_packet[6..]));
let packet_type = (raw_header_a & 0xf) as u8; let key_index = (raw_header_a & 1) as usize;
let packet_type = (raw_header_a.wrapping_shr(1) & 7) as u8;
let fragment_count = ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8; let fragment_count = ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8;
let fragment_no = raw_header_a.wrapping_shr(10) as u8; let fragment_no = raw_header_a.wrapping_shr(10) as u8;
let counter = u64::from_le(memory::load_raw(&incoming_packet[8..])); let counter = u64::from_le(memory::load_raw(&incoming_packet[8..]));
@ -442,7 +440,8 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
self.incoming_init_header_check_cipher self.incoming_init_header_check_cipher
.decrypt_block_in_place(&mut incoming_packet[HEADER_CHECK_ENCRYPT_START..HEADER_CHECK_ENCRYPT_END]); .decrypt_block_in_place(&mut incoming_packet[HEADER_CHECK_ENCRYPT_START..HEADER_CHECK_ENCRYPT_END]);
let raw_header_a = u16::from_le(memory::load_raw(&incoming_packet[6..])); let raw_header_a = u16::from_le(memory::load_raw(&incoming_packet[6..]));
let packet_type = (raw_header_a & 0xf) as u8; let key_index = (raw_header_a & 1) as usize;
let packet_type = (raw_header_a.wrapping_shr(1) & 7) as u8;
let fragment_count = ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8; let fragment_count = ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8;
let fragment_no = raw_header_a.wrapping_shr(10) as u8; let fragment_no = raw_header_a.wrapping_shr(10) as u8;
let counter = u64::from_le(memory::load_raw(&incoming_packet[8..])); let counter = u64::from_le(memory::load_raw(&incoming_packet[8..]));
@ -506,13 +505,8 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
) -> Result<ReceiveResult<'a, Application>, Error> { ) -> Result<ReceiveResult<'a, Application>, Error> {
debug_assert!(fragments.len() >= 1); debug_assert!(fragments.len() >= 1);
// The first 'if' below should capture both DATA and NOP but not other types. Sanity check this.
debug_assert_eq!(PACKET_TYPE_DATA, 0);
debug_assert_eq!(PACKET_TYPE_NOP, 1);
let message_nonce = create_message_nonce(packet_type, counter); let message_nonce = create_message_nonce(packet_type, counter);
if packet_type == PACKET_TYPE_DATA {
if packet_type <= PACKET_TYPE_NOP {
if let Some(session) = session { if let Some(session) = session {
let state = session.state.read().unwrap(); let state = session.state.read().unwrap();
if let Some(session_key) = state.session_keys[key_index].as_ref() { if let Some(session_key) = state.session_keys[key_index].as_ref() {
@ -581,12 +575,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
} }
} }
if packet_type == PACKET_TYPE_DATA { return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else {
unlikely_branch();
return Ok(ReceiveResult::Ok);
}
} else { } else {
unlikely_branch(); unlikely_branch();
return Ok(ReceiveResult::Ignored); return Ok(ReceiveResult::Ignored);
@ -1254,17 +1243,17 @@ fn set_packet_header(
debug_assert!(fragment_no < MAX_FRAGMENTS); debug_assert!(fragment_no < MAX_FRAGMENTS);
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
if fragment_count <= MAX_FRAGMENTS { if fragment_count <= MAX_FRAGMENTS {
// [0-46] recipient session ID // [0-47] recipient session ID
// [47-47] ratchet count least significant bit (key index)
// -- start of header check cipher single block encrypt -- // -- start of header check cipher single block encrypt --
// [48-51] packet type (0-15) // [48-48] key index (least significant bit of ratchet count)
// [49-51] packet type (0-15)
// [52-57] fragment count (1..64 - 1, so 0 means 1 fragment) // [52-57] fragment count (1..64 - 1, so 0 means 1 fragment)
// [58-63] fragment number (0..63) // [58-63] fragment number (0..63)
// [64-127] 64-bit counter // [64-127] 64-bit counter
memory::store_raw( memory::store_raw(
(u64::from(recipient_session_id) (u64::from(recipient_session_id)
| (ratchet_count & 1).wrapping_shl(47) | (ratchet_count & 1).wrapping_shl(48)
| (packet_type as u64).wrapping_shl(48) | (packet_type as u64).wrapping_shl(49)
| ((fragment_count - 1) as u64).wrapping_shl(52) | ((fragment_count - 1) as u64).wrapping_shl(52)
| (fragment_no as u64).wrapping_shl(58)) | (fragment_no as u64).wrapping_shl(58))
.to_le(), .to_le(),