got all tests to pass

This commit is contained in:
mamoniot 2022-12-27 21:09:01 -05:00
parent 8d1efcdffa
commit 5d72aabe17
3 changed files with 109 additions and 76 deletions

View file

@ -17,9 +17,10 @@ impl Counter {
// helps randomize packet contents a bit. // helps randomize packet contents a bit.
Self(AtomicU32::new(1u32)) Self(AtomicU32::new(1u32))
} }
#[inline(always)] #[inline(always)]
pub fn reset_after_initial_offer(&self) { pub fn reset_for_initial_offer(&self) {
self.0.store(2u32, Ordering::SeqCst); self.0.store(1u32, Ordering::SeqCst);
} }
/// Get the value most recently used to send a packet. /// Get the value most recently used to send a packet.
@ -63,7 +64,8 @@ impl CounterWindow {
pub fn new_invalid() -> Self { pub fn new_invalid() -> Self {
Self(std::array::from_fn(|_| AtomicU32::new(u32::MAX))) Self(std::array::from_fn(|_| AtomicU32::new(u32::MAX)))
} }
pub fn reset_after_initial_offer(&self) { pub fn reset_for_initial_offer(&self) {
let o = true;
for i in 0..COUNTER_MAX_ALLOWED_OOO { for i in 0..COUNTER_MAX_ALLOWED_OOO {
self.0[i].store(0, Ordering::SeqCst) self.0[i].store(0, Ordering::SeqCst)
} }

View file

@ -254,7 +254,7 @@ mod tests {
w.message_received(c); w.message_received(c);
continue; continue;
} else { } else {
w.reset_after_initial_offer(); w.reset_for_initial_offer();
counter = 1u32; counter = 1u32;
history = Vec::new(); history = Vec::new();
continue; continue;

View file

@ -112,12 +112,12 @@ pub struct Session<Application: ApplicationLayer> {
/// An arbitrary application defined object associated with each session /// An arbitrary application defined object associated with each session
pub application_data: Application::Data, pub application_data: Application::Data,
send_counter: Counter, // Outgoing packet counter and nonce state header_check_cipher: Aes, // Cipher used for header check codes (not Noise related)
receive_window: [CounterWindow; 2], // Receive window for anti-replay and deduplication receive_windows: [CounterWindow; 2], // Receive window for anti-replay and deduplication
send_counters: [Counter; 2], // Outgoing packet counter and nonce state
state: RwLock<SessionMutableState>, // Mutable parts of state (other than defrag buffers)
psk: Secret<64>, // Arbitrary PSK provided by external code psk: Secret<64>, // Arbitrary PSK provided by external code
noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key noise_ss: Secret<48>, // Static raw shared ECDH NIST P-384 key
header_check_cipher: Aes, // Cipher used for header check codes (not Noise related)
state: RwLock<SessionMutableState>, // Mutable parts of state (other than defrag buffers)
remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob) remote_s_public_blob_hash: [u8; 48], // SHA384(remote static public key blob)
remote_s_public_p384_bytes: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key remote_s_public_p384_bytes: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
@ -127,7 +127,7 @@ pub struct Session<Application: ApplicationLayer> {
struct SessionMutableState { struct SessionMutableState {
remote_session_id: Option<SessionId>, // The other side's 48-bit session ID remote_session_id: Option<SessionId>, // The other side's 48-bit session ID
session_keys: [Option<SessionKey>; 2], // Buffers to store current, next, and last active key session_keys: [Option<SessionKey>; 2], // Buffers to store current, next, and last active key
cur_session_key_id: bool, // Pointer used for keys[] circular buffer cur_session_key_id: bool, // Pointer used for keys[] circular buffer
offer: Option<EphemeralOffer>, // Most recent ephemeral offer sent to remote offer: Option<EphemeralOffer>, // Most recent ephemeral offer sent to remote
last_remote_offer: i64, // Time of most recent ephemeral offer (ms) last_remote_offer: i64, // Time of most recent ephemeral offer (ms)
} }
@ -236,6 +236,7 @@ impl<Application: ApplicationLayer> Session<Application> {
&mut send, &mut send,
send_counter.next(), send_counter.next(),
false, false,
false,
local_session_id, local_session_id,
None, None,
app.get_local_s_public_blob(), app.get_local_s_public_blob(),
@ -254,11 +255,9 @@ impl<Application: ApplicationLayer> Session<Application> {
return Ok(Self { return Ok(Self {
id: local_session_id, id: local_session_id,
application_data, application_data,
send_counter,
receive_window: [CounterWindow::new(), CounterWindow::new_invalid()],
psk: psk.clone(),
noise_ss,
header_check_cipher, header_check_cipher,
receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()],
send_counters: [send_counter, Counter::new()],
state: RwLock::new(SessionMutableState { state: RwLock::new(SessionMutableState {
remote_session_id: None, remote_session_id: None,
session_keys: [None, None], session_keys: [None, None],
@ -266,6 +265,8 @@ impl<Application: ApplicationLayer> Session<Application> {
offer, offer,
last_remote_offer: i64::MIN, last_remote_offer: i64::MIN,
}), }),
psk: psk.clone(),
noise_ss,
remote_s_public_blob_hash: bob_s_public_blob_hash, remote_s_public_blob_hash: bob_s_public_blob_hash,
remote_s_public_p384_bytes: bob_s_public.as_bytes().clone(), remote_s_public_p384_bytes: bob_s_public.as_bytes().clone(),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
@ -296,7 +297,7 @@ impl<Application: ApplicationLayer> Session<Application> {
let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; let packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE;
// This outgoing packet's nonce counter value. // This outgoing packet's nonce counter value.
let counter = self.send_counter.next(); let counter = self.send_counters[state.cur_session_key_id as usize].next();
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
// packet encoding for post-noise transport // packet encoding for post-noise transport
@ -403,10 +404,11 @@ impl<Application: ApplicationLayer> Session<Application> {
force_rekey: bool, force_rekey: bool,
) { ) {
let state = self.state.read().unwrap(); let state = self.state.read().unwrap();
let current_key_id = state.cur_session_key_id;
if (force_rekey if (force_rekey
|| state.session_keys[state.cur_session_key_id as usize] || state.session_keys[state.cur_session_key_id as usize]
.as_ref() .as_ref()
.map_or(true, |key| key.lifetime.should_rekey(self.send_counter.current(), current_time))) .map_or(true, |key| key.lifetime.should_rekey(self.send_counters[current_key_id as usize].current(), current_time)))
&& state && state
.offer .offer
.as_ref() .as_ref()
@ -416,8 +418,9 @@ impl<Application: ApplicationLayer> Session<Application> {
let mut offer = None; let mut offer = None;
if send_ephemeral_offer( if send_ephemeral_offer(
&mut send, &mut send,
CounterValue::get_initial_offer_counter(), self.send_counters[current_key_id as usize].next(),
!state.cur_session_key_id, current_key_id,
!current_key_id,
self.id, self.id,
state.remote_session_id, state.remote_session_id,
app.get_local_s_public_blob(), app.get_local_s_public_blob(),
@ -425,7 +428,7 @@ impl<Application: ApplicationLayer> Session<Application> {
&remote_s_public, &remote_s_public,
&self.remote_s_public_blob_hash, &self.remote_s_public_blob_hash,
&self.noise_ss, &self.noise_ss,
state.session_keys[state.cur_session_key_id as usize].as_ref(), state.session_keys[current_key_id as usize].as_ref(),
if state.remote_session_id.is_some() { if state.remote_session_id.is_some() {
Some(&self.header_check_cipher) Some(&self.header_check_cipher)
} else { } else {
@ -492,12 +495,13 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
{ {
if let Some(session) = app.lookup_session(local_session_id) { if let Some(session) = app.lookup_session(local_session_id) {
if verify_header_check_code(incoming_packet, &session.header_check_cipher) { if verify_header_check_code(incoming_packet, &session.header_check_cipher) {
if session.receive_window[key_id as usize].message_received(counter) { if session.receive_windows[key_id as usize].message_received(counter) {
let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter); let canonical_header = CanonicalHeader::make(local_session_id, packet_type, counter);
if fragment_count > 1 { if fragment_count > 1 {
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
let mut defrag = session.defrag.lock().unwrap(); let mut defrag = session.defrag.lock().unwrap();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); // by using the counter + the key_id as the key we can prevent packet collisions, this only works if defrag hashes
let fragment_gather_array = defrag.get_or_create_mut(&raw_counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock drop(defrag); // release lock
return self.receive_complete( return self.receive_complete(
@ -672,7 +676,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
session_key.return_receive_cipher(c); session_key.return_receive_cipher(c);
if aead_authentication_ok { if aead_authentication_ok {
if session.receive_window[key_id as usize].message_authenticated(counter) { if session.receive_windows[key_id as usize].message_authenticated(counter) {
if packet_type == PACKET_TYPE_DATA { if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else { } else {
@ -832,7 +836,7 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
// Perform checks and match ratchet key if there's an existing session, or gate (via host) and // Perform checks and match ratchet key if there's an existing session, or gate (via host) and
// then create new sessions. // then create new sessions.
let (new_session, ratchet_key, last_ratchet_count) = if let Some(session) = session.as_ref() { let (new_session, reply_counter, new_key_id, ratchet_key, last_ratchet_count) = if let Some(session) = session.as_ref() {
// Existing session identity must match the one in this offer. // Existing session identity must match the one in this offer.
if !secure_eq(&session.remote_s_public_blob_hash, &SHA384::hash(&alice_s_public_blob)) { if !secure_eq(&session.remote_s_public_blob_hash, &SHA384::hash(&alice_s_public_blob)) {
return Err(Error::FailedAuthentication); return Err(Error::FailedAuthentication);
@ -843,10 +847,12 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
let mut ratchet_key = None; let mut ratchet_key = None;
let mut last_ratchet_count = 0; let mut last_ratchet_count = 0;
let state = session.state.read().unwrap(); let state = session.state.read().unwrap();
if state.cur_session_key_id == key_id { //key_id here is the key id of the key being rekeyed and replaced
return Ok(ReceiveResult::Ignored); // alice is requesting to overwrite the current key, reject it //it must be equal to the current session key, and not the previous session key
if state.cur_session_key_id != key_id {
return Ok(ReceiveResult::Ignored);
} }
if let Some(k) = state.session_keys[state.cur_session_key_id as usize].as_ref() { if let Some(k) = state.session_keys[key_id as usize].as_ref() {
if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) { if public_fingerprint_of_secret(k.ratchet_key.as_bytes())[..16].eq(alice_ratchet_key_fingerprint) {
ratchet_key = Some(k.ratchet_key.clone()); ratchet_key = Some(k.ratchet_key.clone());
last_ratchet_count = k.ratchet_count; last_ratchet_count = k.ratchet_count;
@ -856,34 +862,41 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
return Ok(ReceiveResult::Ignored); // old packet? return Ok(ReceiveResult::Ignored); // old packet?
} }
(None, ratchet_key, last_ratchet_count) (None, session.send_counters[key_id as usize].next(), !key_id, ratchet_key, last_ratchet_count)
} else { } else {
if key_id != false {
return Ok(ReceiveResult::Ignored);
}
if let Some((new_session_id, psk, associated_object)) = if let Some((new_session_id, psk, associated_object)) =
app.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata) app.accept_new_session(self, remote_address, alice_s_public_blob, alice_metadata)
{ {
let header_check_cipher = Aes::new( let header_check_cipher = Aes::new(
kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>(), kbkdf512(noise_ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_CHECK).first_n::<HEADER_CHECK_AES_KEY_SIZE>(),
); );
let send_counter = Counter::new();
let reply_counter = send_counter.next();
( (
Some(Session::<Application> { Some(Session::<Application> {
id: new_session_id, id: new_session_id,
application_data: associated_object, application_data: associated_object,
send_counter: Counter::new(),
receive_window: [CounterWindow::new_invalid(), CounterWindow::new_invalid()],
psk,
noise_ss,
header_check_cipher, header_check_cipher,
receive_windows: [CounterWindow::new(), CounterWindow::new_invalid()],
send_counters: [send_counter, Counter::new()],
state: RwLock::new(SessionMutableState { state: RwLock::new(SessionMutableState {
remote_session_id: Some(alice_session_id), remote_session_id: Some(alice_session_id),
session_keys: [None, None], session_keys: [None, None],//this is the only value which will be writen later
cur_session_key_id: key_id, cur_session_key_id: false,
offer: None, offer: None,
last_remote_offer: current_time, last_remote_offer: current_time,
}), }),
psk,
noise_ss,
remote_s_public_blob_hash: SHA384::hash(&alice_s_public_blob), remote_s_public_blob_hash: SHA384::hash(&alice_s_public_blob),
remote_s_public_p384_bytes: alice_s_public.as_bytes().clone(), remote_s_public_p384_bytes: alice_s_public.as_bytes().clone(),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
}), }),
reply_counter,
false,
None, None,
0, 0,
) )
@ -946,7 +959,6 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
let mut reply_buf = [0_u8; KEX_BUF_LEN]; let mut reply_buf = [0_u8; KEX_BUF_LEN];
let reply_counter = CounterValue::get_initial_offer_counter();
let mut idx = HEADER_SIZE; let mut idx = HEADER_SIZE;
idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?; idx = safe_write_all(&mut reply_buf, idx, &[SESSION_PROTOCOL_VERSION])?;
@ -1016,33 +1028,40 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
); );
idx = safe_write_all(&mut reply_buf, idx, &hmac)?; idx = safe_write_all(&mut reply_buf, idx, &hmac)?;
let packet_end = idx; let packet_end = idx;
if session.receive_windows[key_id as usize].message_authenticated(counter) {
//initial key offers should only check this if this is a rekey
let session_key = SessionKey::new(
session_key,
Role::Bob,
current_time,
last_ratchet_count + 1,
hybrid_kk.is_some(),
);
let session_key = SessionKey::new( //TODO: check for correct orderings
session_key, let mut state = session.state.write().unwrap();
Role::Bob, let _ = state.session_keys[new_key_id as usize].replace(session_key);
current_time, if existing_session.is_some() {
last_ratchet_count + 1, let _ = state.remote_session_id.replace(alice_session_id);
hybrid_kk.is_some(), state.cur_session_key_id = new_key_id;
); session.send_counters[new_key_id as usize].reset_for_initial_offer();
session.receive_windows[new_key_id as usize].reset_for_initial_offer();
}
drop(state);
let mut state = session.state.write().unwrap(); // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it.
let _ = state.remote_session_id.replace(alice_session_id);
let _ = state.session_keys[key_id as usize].replace(session_key);
session.send_counter.reset_after_initial_offer();
state.cur_session_key_id = key_id;
session.receive_window[key_id as usize].reset_after_initial_offer();
session.receive_window[key_id as usize].message_authenticated(counter);
drop(state);
// Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher);
send_with_fragmentation(send, &mut reply_buf[..packet_end], mtu, &session.header_check_cipher); if let Some(new_session) = new_session {
return Ok(ReceiveResult::OkNewSession(new_session));
if let Some(new_session) = new_session { } else {
return Ok(ReceiveResult::OkNewSession(new_session)); return Ok(ReceiveResult::Ok);
}
} else { } else {
return Ok(ReceiveResult::Ok); return Ok(ReceiveResult::Ignored);
} }
} }
PACKET_TYPE_KEY_COUNTER_OFFER => { PACKET_TYPE_KEY_COUNTER_OFFER => {
@ -1096,8 +1115,9 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?; parse_dec_key_offer_after_header(&kex_packet[plaintext_end..kex_packet_len], packet_type)?;
// Check that this is a counter offer to the original offer we sent. // Check that this is a counter offer to the original offer we sent.
if !(offer.id.eq(offer_id) & (offer.key_id == key_id)) { if !(offer.id.eq(offer_id)) {
return Ok(ReceiveResult::Ignored); return Ok(ReceiveResult::Ignored);
} }
// Kyber1024 key agreement if enabled. // Kyber1024 key agreement if enabled.
@ -1138,27 +1158,36 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
) { ) {
return Err(Error::FailedAuthentication); return Err(Error::FailedAuthentication);
} }
if session.receive_windows[key_id as usize].message_authenticated(counter) {
// Alice has now completed and validated the full hybrid exchange.
// Alice has now completed and validated the full hybrid exchange. let session_key = SessionKey::new(
session_key,
Role::Alice,
current_time,
last_ratchet_count + 1,
hybrid_kk.is_some(),
);
let session_key = SessionKey::new( let new_key_id = offer.key_id;
session_key, let is_new_session = offer.ratchet_count == 0;
Role::Alice, drop(state);
current_time, //TODO: check for correct orderings
last_ratchet_count + 1, let mut state = session.state.write().unwrap();
hybrid_kk.is_some(), let _ = state.remote_session_id.replace(bob_session_id);
); let _ = state.session_keys[new_key_id as usize].replace(session_key);
if !is_new_session {
//when an brand new key offer is sent, it is sent using the new_key_id==false counter, we cannot reset it in that case.
state.cur_session_key_id = new_key_id;
session.send_counters[new_key_id as usize].reset_for_initial_offer();
session.receive_windows[new_key_id as usize].reset_for_initial_offer();
}
let _ = state.offer.take();
drop(state); return Ok(ReceiveResult::Ok);
let mut state = session.state.write().unwrap(); } else {
let _ = state.remote_session_id.replace(bob_session_id); return Ok(ReceiveResult::Ignored);
let _ = state.session_keys[key_id as usize].replace(session_key); }
session.send_counter.reset_after_initial_offer();
state.cur_session_key_id = key_id;
session.receive_window[key_id as usize].message_authenticated(counter);
let _ = state.offer.take();
return Ok(ReceiveResult::Ok);
} }
} }
@ -1174,10 +1203,12 @@ impl<Application: ApplicationLayer> ReceiveContext<Application> {
} }
/// Create an send an ephemeral offer, populating ret_ephemeral_offer on success. /// Create an send an ephemeral offer, populating ret_ephemeral_offer on success.
/// If there is no current session key set `current_key_id == new_key_id == false`
fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>( fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
send: &mut SendFunction, send: &mut SendFunction,
counter: CounterValue, counter: CounterValue,
key_id: bool, current_key_id: bool,
new_key_id: bool,
alice_session_id: SessionId, alice_session_id: SessionId,
bob_session_id: Option<SessionId>, bob_session_id: Option<SessionId>,
alice_s_public_blob: &[u8], alice_s_public_blob: &[u8],
@ -1263,7 +1294,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
PACKET_TYPE_INITIAL_KEY_OFFER, PACKET_TYPE_INITIAL_KEY_OFFER,
bob_session_id, bob_session_id,
counter, counter,
key_id, current_key_id,
)?; )?;
let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32()); let canonical_header = CanonicalHeader::make(bob_session_id, PACKET_TYPE_INITIAL_KEY_OFFER, counter.to_u32());
@ -1317,7 +1348,7 @@ fn send_ephemeral_offer<SendFunction: FnMut(&mut [u8])>(
*ret_ephemeral_offer = Some(EphemeralOffer { *ret_ephemeral_offer = Some(EphemeralOffer {
id, id,
key_id, key_id: new_key_id,
creation_time: current_time, creation_time: current_time,
ratchet_count, ratchet_count,
ratchet_key, ratchet_key,