More cleanup in session.

This commit is contained in:
Adam Ierymenko 2022-09-06 18:31:19 -04:00
parent 06573c1ea8
commit d5943f246a
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3

View file

@ -261,9 +261,10 @@ pub struct ReceiveContext<H: Host> {
}
impl<H: Host> Session<H> {
#[inline]
pub fn new<SendFunction: FnMut(&mut [u8])>(
host: &H,
send: SendFunction,
mut send: SendFunction,
local_session_id: SessionId,
remote_s_public: &[u8],
offer_metadata: &[u8],
@ -278,7 +279,7 @@ impl<H: Host> Session<H> {
let counter = Counter::new();
let remote_s_public_hash = SHA384::hash(remote_s_public);
if let Ok(offer) = EphemeralOffer::create_alice_offer(
send,
&mut send,
counter.next(),
local_session_id,
None,
@ -312,13 +313,14 @@ impl<H: Host> Session<H> {
return Err(Error::InvalidParameter);
}
pub fn rekey_check<SendFunction: FnMut(&mut [u8])>(&self, host: &H, send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) {
#[inline]
pub fn rekey_check<SendFunction: FnMut(&mut [u8])>(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) {
let state = self.state.upgradable_read();
if let Some(key) = state.keys[0].as_ref() {
if force || (key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) {
if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) {
if let Ok(offer) = EphemeralOffer::create_alice_offer(
send,
&mut send,
self.send_counter.next(),
self.id,
state.remote_session_id,
@ -340,16 +342,18 @@ impl<H: Host> Session<H> {
}
impl<H: Host> ReceiveContext<H> {
#[inline]
pub fn new() -> Self {
Self {
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
}
}
#[inline]
pub fn receive<'a, SendFunction: FnMut(&mut [u8])>(
&self,
host: &H,
send: SendFunction,
mut send: SendFunction,
data_buf: &'a mut [u8],
incoming_packet_buf: H::IncomingPacketBuffer,
mtu: usize,
@ -377,7 +381,7 @@ impl<H: Host> ReceiveContext<H> {
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time);
return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time);
}
} else {
unlikely_branch();
@ -393,7 +397,7 @@ impl<H: Host> ReceiveContext<H> {
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
}
} else {
unlikely_branch();
@ -403,7 +407,7 @@ impl<H: Host> ReceiveContext<H> {
} else {
return self.receive_complete(
host,
send,
&mut send,
data_buf,
&[incoming_packet_buf],
packet_type,
@ -420,7 +424,7 @@ impl<H: Host> ReceiveContext<H> {
fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>(
&self,
host: &H,
mut send: SendFunction,
send: &mut SendFunction,
data_buf: &'a mut [u8],
fragments: &[H::IncomingPacketBuffer],
packet_type: u8,
@ -475,8 +479,8 @@ impl<H: Host> ReceiveContext<H> {
key.return_receive_cipher(c);
if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) {
// If this succeeded with the "next" key, promote it to current.
if ki == 1 {
// Promote next key to current key on success.
unlikely_branch();
drop(state);
let mut state = session.state.write();
@ -486,6 +490,7 @@ impl<H: Host> ReceiveContext<H> {
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else {
unlikely_branch();
return Ok(ReceiveResult::Ok);
}
}
@ -517,7 +522,7 @@ impl<H: Host> ReceiveContext<H> {
let original_ciphertext = incoming_packet_buf.clone();
let incoming_packet = &mut incoming_packet_buf[..incoming_packet_len];
if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) {
if incoming_packet_len <= HEADER_SIZE {
return Err(Error::InvalidPacket);
}
if incoming_packet[HEADER_SIZE] != SESSION_PROTOCOL_VERSION {
@ -528,6 +533,9 @@ impl<H: Host> ReceiveContext<H> {
PACKET_TYPE_KEY_OFFER => {
// alice (remote) -> bob (local)
if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE) {
return Err(Error::InvalidPacket);
}
let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE);
let aes_gcm_tag_end = incoming_packet_len - (HMAC_SIZE + HMAC_SIZE);
let hmac1_end = incoming_packet_len - HMAC_SIZE;
@ -688,6 +696,9 @@ impl<H: Host> ReceiveContext<H> {
PACKET_TYPE_KEY_COUNTER_OFFER => {
// bob (remote) -> alice (local)
if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) {
return Err(Error::InvalidPacket);
}
let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE);
let aes_gcm_tag_end = incoming_packet_len - HMAC_SIZE;
@ -712,7 +723,7 @@ impl<H: Host> ReceiveContext<H> {
return Err(Error::FailedAuthentication);
}
// Alice has now completed Noise_IK for P-384 and verified with GCM auth, now for the hybrid add-on.
// Alice has now completed Noise_IK with NIST P-384 and verified with GCM auth, but now for hybrid...
let (bob_session_id, _, _, bob_e1_public) = parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?;
@ -734,36 +745,26 @@ impl<H: Host> ReceiveContext<H> {
return Err(Error::FailedAuthentication);
}
// Alice has now completed and validated the full hybrid exchange. If this is the first exchange send
// a NOP back to Bob to acknowledge that the session is open and can now be used. Otherwise just queue
// this up as the next key to be promoted to current when Bob uses it.
// Alice has now completed and validated the full hybrid exchange.
let counter = session.send_counter.next();
let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi);
let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
let header = send_with_fragmentation_init_header(HEADER_SIZE + AES_GCM_TAG_SIZE, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter);
reply_buf[..HEADER_SIZE].copy_from_slice(&header);
let mut c = key.get_send_cipher(counter)?;
c.init(&get_aes_gcm_nonce(&reply_buf));
reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish());
key.return_send_cipher(c);
send(&mut reply_buf);
let mut state = RwLockUpgradableReadGuard::upgrade(state);
let _ = state.offer.take();
let _ = state.remote_session_id.replace(bob_session_id);
if state.keys[0].is_some() {
let _ = state.keys[1].replace(SessionKey::new(key, Role::Alice, current_time, session.send_counter.current(), jedi));
} else {
let counter = session.send_counter.next();
let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi);
let mut reply_buf = [0_u8; MIN_MTU];
let dummy_data_len = (random::next_u32_secure() % (mtu - (HEADER_SIZE + AES_GCM_TAG_SIZE)) as u32) as usize;
let reply_len = dummy_data_len + HEADER_SIZE + AES_GCM_TAG_SIZE;
let header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter);
reply_buf[..HEADER_SIZE].copy_from_slice(&header);
let mut c = key.get_send_cipher(counter)?;
c.init(&get_aes_gcm_nonce(&reply_buf));
c.crypt_in_place(&mut reply_buf[HEADER_SIZE..(HEADER_SIZE + dummy_data_len)]);
reply_buf[(HEADER_SIZE + dummy_data_len)..reply_len].copy_from_slice(&c.finish());
key.return_send_cipher(c);
send(&mut reply_buf[..reply_len]);
let _ = state.keys[0].replace(key);
let _ = state.keys[1].take();
}
let _ = state.offer.take();
let _ = state.keys[0].insert(key);
return Ok(ReceiveResult::Ok);
}
@ -783,6 +784,7 @@ impl<H: Host> ReceiveContext<H> {
struct Counter(AtomicU64);
impl Counter {
#[inline(always)]
fn new() -> Self {
Self(AtomicU64::new(random::next_u32_secure() as u64))
}
@ -850,7 +852,7 @@ struct EphemeralOffer {
impl EphemeralOffer {
fn create_alice_offer<SendFunction: FnMut(&mut [u8])>(
send: SendFunction,
send: &mut SendFunction,
counter: CounterValue,
alice_session_id: SessionId,
bob_session_id: Option<SessionId>,
@ -947,14 +949,14 @@ fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_typ
debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits
// Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter
let mut header = ((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u128;
header |= recipient_session_id.wrapping_shl(16) as u128;
header |= (counter.to_u32() as u128).wrapping_shl(64);
header.to_le_bytes()[..HEADER_SIZE].try_into().unwrap()
((((((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u64) | recipient_session_id.wrapping_shl(16)) as u128) | (counter.to_u32() as u128).wrapping_shl(64)).to_le_bytes()
[..HEADER_SIZE]
.try_into()
.unwrap()
}
#[inline(always)]
fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(mut send: SendFunction, packet: &mut [u8], mtu: usize, header: &mut [u8; HEADER_SIZE]) {
fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(send: &mut SendFunction, packet: &mut [u8], mtu: usize, header: &mut [u8; HEADER_SIZE]) {
let packet_len = packet.len();
let mut fragment_start = 0;
let mut fragment_end = packet_len.min(mtu);
@ -1168,6 +1170,7 @@ mod tests {
#[allow(unused_variables)]
#[test]
fn establish_session() {
let jedi = true;
let mut psk: Secret<64> = Secret::default();
random::fill_bytes_secure(&mut psk.0);
let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob"));
@ -1188,7 +1191,7 @@ mod tests {
1,
1280,
1,
true,
jedi,
)
.unwrap(),
));
@ -1208,7 +1211,7 @@ mod tests {
if let Some(qi) = host.queue.lock().pop_back() {
let qi_len = qi.len();
ts += 1;
let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, true, ts);
let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, jedi, ts);
if r.is_ok() {
let r = r.unwrap();
match r {