Rework defragmentation, and it now tolerates very poor link quality pretty well.

This commit is contained in:
Adam Ierymenko 2023-03-02 19:09:31 -05:00
parent 87989ac008
commit 40945cf6c9
7 changed files with 514 additions and 293 deletions

View file

@ -38,7 +38,7 @@ impl<T, const C: usize> GatherArray<T, C> {
/// Add an item to the array if we don't have this index anymore, returning complete array if all parts are here.
#[inline(always)]
pub fn add(&mut self, index: u8, value: T) -> Option<ArrayVec<T, C>> {
pub fn add_return_when_satisfied(&mut self, index: u8, value: T) -> Option<ArrayVec<T, C>> {
if index < self.goal {
let mut have = self.have_bits;
let got = 1u64.wrapping_shl(index as u32);
@ -91,9 +91,9 @@ mod tests {
for goal in 2u8..64u8 {
let mut m = GatherArray::<u8, 64>::new(goal);
for x in 0..(goal - 1) {
assert!(m.add(x, x).is_none());
assert!(m.add_return_when_satisfied(x, x).is_none());
}
let r = m.add(goal - 1, goal - 1).unwrap();
let r = m.add_return_when_satisfied(goal - 1, goal - 1).unwrap();
for x in 0..goal {
assert_eq!(r.as_ref()[x as usize], x);
}

View file

@ -6,6 +6,8 @@
* https://www.zerotier.com/
*/
use std::hash::Hash;
use zerotier_crypto::p384::P384KeyPair;
/// Trait to implement to integrate the session into an application.
@ -65,6 +67,12 @@ pub trait ApplicationLayer: Sized {
/// for a short period of time when assembling fragmented packets on the receive path.
type IncomingPacketBuffer: AsRef<[u8]> + AsMut<[u8]>;
/// Opaque type for whatever constitutes a physical path to the application.
///
/// A physical path could be an IP address or IP plus device in the case of UDP, a socket in the
/// case of TCP, etc.
type PhysicalPath: PartialEq + Eq + Hash + Clone;
/// Get a reference to this host's static public key blob.
///
/// This must contain a NIST P-384 public key but can contain other information. In ZeroTier this

103
zssp/src/fragged.rs Normal file
View file

@ -0,0 +1,103 @@
use std::mem::{needs_drop, size_of, zeroed, MaybeUninit};
use std::ptr::slice_from_raw_parts;
/// Fast packet defragmenter
pub struct Fragged<Fragment, const MAX_FRAGMENTS: usize> {
have: u64,
counter: u64,
frags: [MaybeUninit<Fragment>; MAX_FRAGMENTS],
}
pub struct Assembled<Fragment, const MAX_FRAGMENTS: usize>([MaybeUninit<Fragment>; MAX_FRAGMENTS], usize);
impl<Fragment, const MAX_FRAGMENTS: usize> AsRef<[Fragment]> for Assembled<Fragment, MAX_FRAGMENTS> {
#[inline(always)]
fn as_ref(&self) -> &[Fragment] {
unsafe { &*slice_from_raw_parts(self.0.as_ptr().cast::<Fragment>(), self.1) }
}
}
impl<Fragment, const MAX_FRAGMENTS: usize> Drop for Assembled<Fragment, MAX_FRAGMENTS> {
#[inline(always)]
fn drop(&mut self) {
for i in 0..self.1 {
unsafe {
self.0.get_unchecked_mut(i).assume_init_drop();
}
}
}
}
impl<Fragment, const MAX_FRAGMENTS: usize> Fragged<Fragment, MAX_FRAGMENTS> {
pub fn new() -> Self {
debug_assert!(MAX_FRAGMENTS <= 64);
debug_assert_eq!(size_of::<MaybeUninit<Fragment>>(), size_of::<Fragment>());
debug_assert_eq!(
size_of::<[MaybeUninit<Fragment>; MAX_FRAGMENTS]>(),
size_of::<[Fragment; MAX_FRAGMENTS]>()
);
unsafe { zeroed() }
}
pub fn assemble(
&mut self,
counter: u64,
fragment: Fragment,
fragment_no: u8,
fragment_count: u8,
) -> Option<Assembled<Fragment, MAX_FRAGMENTS>> {
if fragment_no < fragment_count && (fragment_count as usize) <= MAX_FRAGMENTS {
debug_assert!((fragment_count as usize) <= MAX_FRAGMENTS);
debug_assert!((fragment_no as usize) < MAX_FRAGMENTS);
let mut have = self.have;
if counter != self.counter {
self.counter = counter;
if needs_drop::<Fragment>() {
let mut i = 0;
while have != 0 {
if (have & 1) != 0 {
debug_assert!(i < MAX_FRAGMENTS);
unsafe { self.frags.get_unchecked_mut(i).assume_init_drop() };
}
have = have.wrapping_shr(1);
i += 1;
}
} else {
have = 0;
}
}
unsafe {
self.frags.get_unchecked_mut(fragment_no as usize).write(fragment);
}
let want = 0xffffffffffffffffu64.wrapping_shr((64 - fragment_count) as u32);
have |= 1u64.wrapping_shl(fragment_no as u32);
if (have & want) == want {
self.have = 0;
return Some(Assembled(unsafe { std::mem::transmute_copy(&self.frags) }, fragment_count as usize));
} else {
self.have = have;
}
}
return None;
}
}
impl<Fragment, const MAX_FRAGMENTS: usize> Drop for Fragged<Fragment, MAX_FRAGMENTS> {
fn drop(&mut self) {
if needs_drop::<Fragment>() {
let mut have = self.have;
let mut i = 0;
while have != 0 {
if (have & 1) != 0 {
debug_assert!(i < MAX_FRAGMENTS);
unsafe { self.frags.get_unchecked_mut(i).assume_init_drop() };
}
have = have.wrapping_shr(1);
i += 1;
}
}
}
}

View file

@ -8,6 +8,7 @@
mod applicationlayer;
mod error;
mod fragged;
mod proto;
mod sessionid;
mod zssp;

View file

@ -1,9 +1,12 @@
use std::iter::ExactSizeIterator;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use zerotier_crypto::p384::{P384KeyPair, P384PublicKey};
use zerotier_crypto::random;
use zerotier_crypto::secret::Secret;
use zerotier_utils::hex;
use zerotier_utils::ms_monotonic;
@ -23,8 +26,8 @@ impl zssp::ApplicationLayer for TestApplication {
const RETRY_INTERVAL: i64 = 500;
type Data = ();
type IncomingPacketBuffer = Vec<u8>;
type PhysicalPath = usize;
fn get_local_s_public_blob(&self) -> &[u8] {
self.identity_key.public_key_bytes()
@ -37,6 +40,7 @@ impl zssp::ApplicationLayer for TestApplication {
fn alice_main(
run: &AtomicBool,
packet_success_rate: u32,
alice_app: &TestApplication,
bob_app: &TestApplication,
alice_out: mpsc::SyncSender<Vec<u8>>,
@ -46,7 +50,7 @@ fn alice_main(
let mut data_buf = [0u8; 65536];
let mut next_service = ms_monotonic() + 500;
let mut last_ratchet_count = 0;
let test_data = [1u8; 10000];
let test_data = [1u8; TEST_MTU * 10];
let mut up = false;
let alice_session = context
@ -71,31 +75,35 @@ fn alice_main(
loop {
let pkt = alice_in.try_recv();
if let Ok(pkt) = pkt {
//println!("bob >> alice {}", pkt.len());
match context.receive(
alice_app,
|| true,
|s_public, _| Some((P384PublicKey::from_bytes(s_public).unwrap(), Secret::default(), ())),
|_, b| {
let _ = alice_out.send(b.to_vec());
},
&mut data_buf,
pkt,
TEST_MTU,
current_time,
) {
Ok(zssp::ReceiveResult::Ok) => {
//println!("[alice] ok");
}
Ok(zssp::ReceiveResult::OkData(_, _)) => {
//println!("[alice] received {}", data.len());
}
Ok(zssp::ReceiveResult::OkNewSession(s)) => {
println!("[alice] new session {}", s.id.to_string());
}
Ok(zssp::ReceiveResult::Rejected) => {}
Err(e) => {
println!("[alice] ERROR {}", e.to_string());
if (random::xorshift64_random() as u32) <= packet_success_rate {
match context.receive(
alice_app,
|| true,
|s_public, _| Some((P384PublicKey::from_bytes(s_public).unwrap(), Secret::default(), ())),
|_, b| {
let _ = alice_out.send(b.to_vec());
},
&0,
&mut data_buf,
pkt,
TEST_MTU,
current_time,
) {
Ok(zssp::ReceiveResult::Ok) => {
//println!("[alice] ok");
}
Ok(zssp::ReceiveResult::OkData(_, _)) => {
//println!("[alice] received {}", data.len());
}
Ok(zssp::ReceiveResult::OkNewSession(s)) => {
println!("[alice] new session {}", s.id.to_string());
}
Ok(zssp::ReceiveResult::Rejected) => {}
Err(e) => {
println!("[alice] ERROR {}", e.to_string());
//run.store(false, Ordering::SeqCst);
//break;
}
}
}
} else {
@ -116,12 +124,14 @@ fn alice_main(
let _ = alice_out.send(b.to_vec());
},
&mut data_buf[..TEST_MTU],
&test_data[..1400 + ((zerotier_crypto::random::xorshift64_random() as usize) % (test_data.len() - 1400))],
&test_data[..1400 + ((random::xorshift64_random() as usize) % (test_data.len() - 1400))],
)
.is_ok());
} else {
if alice_session.established() {
up = true;
} else {
thread::sleep(Duration::from_millis(10));
}
}
@ -140,6 +150,7 @@ fn alice_main(
fn bob_main(
run: &AtomicBool,
packet_success_rate: u32,
_alice_app: &TestApplication,
bob_app: &TestApplication,
bob_out: mpsc::SyncSender<Vec<u8>>,
@ -160,42 +171,46 @@ fn bob_main(
let current_time = ms_monotonic();
if let Ok(pkt) = pkt {
//println!("alice >> bob {}", pkt.len());
match context.receive(
bob_app,
|| true,
|s_public, _| Some((P384PublicKey::from_bytes(s_public).unwrap(), Secret::default(), ())),
|_, b| {
let _ = bob_out.send(b.to_vec());
},
&mut data_buf,
pkt,
TEST_MTU,
current_time,
) {
Ok(zssp::ReceiveResult::Ok) => {
//println!("[bob] ok");
}
Ok(zssp::ReceiveResult::OkData(s, data)) => {
//println!("[bob] received {}", data.len());
assert!(s
.send(
|b| {
let _ = bob_out.send(b.to_vec());
},
&mut data_buf_2,
data.as_mut(),
)
.is_ok());
transferred += data.len() as u64 * 2; // *2 because we are also sending this many bytes back
}
Ok(zssp::ReceiveResult::OkNewSession(s)) => {
println!("[bob] new session {}", s.id.to_string());
let _ = bob_session.replace(s);
}
Ok(zssp::ReceiveResult::Rejected) => {}
Err(e) => {
println!("[bob] ERROR {}", e.to_string());
if (random::xorshift64_random() as u32) <= packet_success_rate {
match context.receive(
bob_app,
|| true,
|s_public, _| Some((P384PublicKey::from_bytes(s_public).unwrap(), Secret::default(), ())),
|_, b| {
let _ = bob_out.send(b.to_vec());
},
&0,
&mut data_buf,
pkt,
TEST_MTU,
current_time,
) {
Ok(zssp::ReceiveResult::Ok) => {
//println!("[bob] ok");
}
Ok(zssp::ReceiveResult::OkData(s, data)) => {
//println!("[bob] received {}", data.len());
assert!(s
.send(
|b| {
let _ = bob_out.send(b.to_vec());
},
&mut data_buf_2,
data.as_mut(),
)
.is_ok());
transferred += data.len() as u64 * 2; // *2 because we are also sending this many bytes back
}
Ok(zssp::ReceiveResult::OkNewSession(s)) => {
println!("[bob] new session {}", s.id.to_string());
let _ = bob_session.replace(s);
}
Ok(zssp::ReceiveResult::Rejected) => {}
Err(e) => {
println!("[bob] ERROR {}", e.to_string());
//run.store(false, Ordering::SeqCst);
//break;
}
}
}
}
@ -211,10 +226,12 @@ fn bob_main(
let speed_metric_elapsed = current_time - last_speed_metric;
if speed_metric_elapsed >= 1000 {
last_speed_metric = current_time;
println!(
"[bob] throughput: {} MiB/sec (combined input and output)",
((transferred as f64) / 1048576.0) / ((speed_metric_elapsed as f64) / 1000.0)
);
if transferred > 0 {
println!(
"[bob] throughput: {} MiB/sec (combined input and output)",
((transferred as f64) / 1048576.0) / ((speed_metric_elapsed as f64) / 1000.0)
);
}
transferred = 0;
}
@ -240,9 +257,16 @@ fn main() {
let (alice_out, bob_in) = mpsc::sync_channel::<Vec<u8>>(1024);
let (bob_out, alice_in) = mpsc::sync_channel::<Vec<u8>>(1024);
let args = std::env::args();
let packet_success_rate = if args.len() <= 1 {
u32::MAX
} else {
((u32::MAX as f64) * f64::from_str(args.last().unwrap().as_str()).unwrap()) as u32
};
thread::scope(|ts| {
let alice_thread = ts.spawn(|| alice_main(&run, &alice_app, &bob_app, alice_out, alice_in));
let bob_thread = ts.spawn(|| bob_main(&run, &alice_app, &bob_app, bob_out, bob_in));
let alice_thread = ts.spawn(|| alice_main(&run, packet_success_rate, &alice_app, &bob_app, alice_out, alice_in));
let bob_thread = ts.spawn(|| bob_main(&run, packet_success_rate, &alice_app, &bob_app, bob_out, bob_in));
thread::sleep(Duration::from_secs(60 * 10));

View file

@ -24,24 +24,29 @@ pub const MIN_TRANSPORT_MTU: usize = 128;
/// Maximum combined size of static public blob and metadata.
pub const MAX_INIT_PAYLOAD_SIZE: usize = MAX_NOISE_HANDSHAKE_SIZE - ALICE_NOISE_XK_ACK_MIN_SIZE;
/// Version 0: Noise_XK with NIST P-384 plus Kyber1024 hybrid exchange on session init.
pub(crate) const SESSION_PROTOCOL_VERSION: u8 = 0x00;
pub(crate) const COUNTER_WINDOW_MAX_OOO: usize = 16;
/// Maximum window over which packets may be reordered.
pub(crate) const COUNTER_WINDOW_MAX_OOO: usize = 32;
/// Maximum number of counter steps that the counter is allowed to skip ahead.
pub(crate) const COUNTER_WINDOW_MAX_SKIP_AHEAD: u64 = 16777216;
pub(crate) const PACKET_TYPE_DATA: u8 = 0;
pub(crate) const PACKET_TYPE_ALICE_NOISE_XK_INIT: u8 = 1;
pub(crate) const PACKET_TYPE_BOB_NOISE_XK_ACK: u8 = 2;
pub(crate) const PACKET_TYPE_ALICE_NOISE_XK_ACK: u8 = 3;
pub(crate) const PACKET_TYPE_REKEY_INIT: u8 = 4;
pub(crate) const PACKET_TYPE_REKEY_ACK: u8 = 5;
pub(crate) const PACKET_TYPE_NOP: u8 = 0;
pub(crate) const PACKET_TYPE_DATA: u8 = 1;
pub(crate) const PACKET_TYPE_ALICE_NOISE_XK_INIT: u8 = 2;
pub(crate) const PACKET_TYPE_BOB_NOISE_XK_ACK: u8 = 3;
pub(crate) const PACKET_TYPE_ALICE_NOISE_XK_ACK: u8 = 4;
pub(crate) const PACKET_TYPE_REKEY_INIT: u8 = 5;
pub(crate) const PACKET_TYPE_REKEY_ACK: u8 = 6;
pub(crate) const HEADER_SIZE: usize = 16;
pub(crate) const HEADER_PROTECT_ENCRYPT_START: usize = 6;
pub(crate) const HEADER_PROTECT_ENCRYPT_END: usize = 22;
pub(crate) const KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION: u8 = b'X'; // intermediate keys used in key exchanges
pub(crate) const KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION: u8 = b'x'; // intermediate keys used in key exchanges
pub(crate) const KBKDF_KEY_USAGE_LABEL_INIT_ENCRYPTION: u8 = b'x'; // AES-CTR encryption during initial setup
pub(crate) const KBKDF_KEY_USAGE_LABEL_INIT_AUTHENTICATION: u8 = b'X'; // HMAC-SHA384 during initial setup
pub(crate) const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A'; // AES-GCM in A->B direction
pub(crate) const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; // AES-GCM in B->A direction
pub(crate) const KBKDF_KEY_USAGE_LABEL_RATCHET: u8 = b'R'; // Key used in derivatin of next session key
@ -50,6 +55,7 @@ pub(crate) const MAX_FRAGMENTS: usize = 48; // hard protocol max: 63
pub(crate) const MAX_NOISE_HANDSHAKE_FRAGMENTS: usize = 16; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc.
pub(crate) const MAX_NOISE_HANDSHAKE_SIZE: usize = MAX_NOISE_HANDSHAKE_FRAGMENTS * MIN_TRANSPORT_MTU;
/// Size of keys used during derivation, mixing, etc. process.
pub(crate) const BASE_KEY_SIZE: usize = 64;
pub(crate) const AES_256_KEY_SIZE: usize = 32;
@ -162,7 +168,6 @@ impl RekeyAck {
pub(crate) trait ProtocolFlatBuffer {}
impl ProtocolFlatBuffer for AliceNoiseXKInit {}
impl ProtocolFlatBuffer for BobNoiseXKAck {}
//impl ProtocolFlatBuffer for NoiseXKAliceStaticAck {}
impl ProtocolFlatBuffer for RekeyInit {}
impl ProtocolFlatBuffer for RekeyAck {}

View file

@ -15,19 +15,18 @@ use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock, Weak};
use zerotier_crypto::aes::{Aes, AesCtr, AesGcm};
use zerotier_crypto::hash::{hmac_sha512, HMACSHA384, HMAC_SHA384_SIZE, SHA384, SHA384_HASH_SIZE};
use zerotier_crypto::hash::{hmac_sha512, HMACSHA384, HMAC_SHA384_SIZE, SHA384};
use zerotier_crypto::p384::{P384KeyPair, P384PublicKey, P384_ECDH_SHARED_SECRET_SIZE};
use zerotier_crypto::secret::Secret;
use zerotier_crypto::{random, secure_eq};
use zerotier_utils::arrayvec::ArrayVec;
use zerotier_utils::gatherarray::GatherArray;
use zerotier_utils::ringbuffermap::RingBufferMap;
use pqc_kyber::{KYBER_CIPHERTEXTBYTES, KYBER_SECRETKEYBYTES, KYBER_SSBYTES};
use pqc_kyber::{KYBER_SECRETKEYBYTES, KYBER_SSBYTES};
use crate::applicationlayer::ApplicationLayer;
use crate::error::Error;
use crate::fragged::Fragged;
use crate::proto::*;
use crate::sessionid::SessionId;
@ -37,7 +36,12 @@ use crate::sessionid::SessionId;
/// defragment incoming packets that are not yet associated with a session.
pub struct Context<Application: ApplicationLayer> {
max_incomplete_session_queue_size: usize,
defrag: Mutex<RingBufferMap<u64, GatherArray<Application::IncomingPacketBuffer, MAX_NOISE_HANDSHAKE_FRAGMENTS>, 256, 256>>,
defrag: Mutex<
HashMap<
(Application::PhysicalPath, u64),
Arc<Mutex<(Fragged<Application::IncomingPacketBuffer, MAX_NOISE_HANDSHAKE_FRAGMENTS>, i64)>>,
>,
>,
sessions: RwLock<SessionsById<Application>>,
}
@ -80,7 +84,7 @@ pub struct Session<Application: ApplicationLayer> {
receive_window: [AtomicU64; COUNTER_WINDOW_MAX_OOO],
header_protection_cipher: Aes,
state: RwLock<State>,
defrag: Mutex<RingBufferMap<u64, GatherArray<Application::IncomingPacketBuffer, MAX_FRAGMENTS>, 16, 16>>,
defrag: [Mutex<Fragged<Application::IncomingPacketBuffer, MAX_FRAGMENTS>>; COUNTER_WINDOW_MAX_OOO],
}
/// Most of the mutable parts of a session state.
@ -91,20 +95,16 @@ struct State {
current_offer: Offer,
}
/// State related to an incoming session not yet fully established.
struct IncomingIncompleteSession {
timestamp: i64,
request_hash: [u8; SHA384_HASH_SIZE],
alice_session_id: SessionId,
bob_session_id: SessionId,
noise_es_ee: Secret<BASE_KEY_SIZE>,
bob_hk_ciphertext: [u8; KYBER_CIPHERTEXTBYTES],
hk: Secret<KYBER_SSBYTES>,
header_protection_key: Secret<AES_HEADER_PROTECTION_KEY_SIZE>,
bob_noise_e_secret: P384KeyPair,
}
/// State related to an outgoing session attempt.
struct OutgoingSessionInit {
last_retry_time: AtomicI64,
alice_noise_e_secret: P384KeyPair,
@ -114,14 +114,19 @@ struct OutgoingSessionInit {
init_packet: [u8; AliceNoiseXKInit::SIZE],
}
/// Latest outgoing offer, either an outgoing attempt or a rekey attempt.
struct OutgoingSessionAck {
last_retry_time: AtomicI64,
ack: [u8; MAX_NOISE_HANDSHAKE_SIZE],
ack_size: usize,
}
enum Offer {
None,
NoiseXKInit(Box<OutgoingSessionInit>),
RekeyInit(P384KeyPair, [u8; RekeyInit::SIZE], AtomicI64),
NoiseXKAck(Box<OutgoingSessionAck>),
RekeyInit(P384KeyPair, i64),
}
/// An ephemeral session key with expiration info.
struct SessionKey {
ratchet_key: Secret<BASE_KEY_SIZE>, // Key used in derivation of the next session key
receive_key: Secret<AES_256_KEY_SIZE>, // Receive side AES-GCM key
@ -134,6 +139,7 @@ struct SessionKey {
expire_at_counter: u64, // Hard error when this counter value is reached or exceeded
ratchet_count: u64, // Number of rekey events
bob: bool, // Was this side "Bob" in this exchange?
confirmed: bool, // Is this key confirmed by the other side?
}
impl<Application: ApplicationLayer> Context<Application> {
@ -141,7 +147,7 @@ impl<Application: ApplicationLayer> Context<Application> {
pub fn new(max_incomplete_session_queue_size: usize) -> Self {
Self {
max_incomplete_session_queue_size,
defrag: Mutex::new(RingBufferMap::new(random::next_u32_secure())),
defrag: Mutex::new(HashMap::new()),
sessions: RwLock::new(SessionsById {
active: HashMap::with_capacity(64),
incoming: HashMap::with_capacity(64),
@ -173,19 +179,14 @@ impl<Application: ApplicationLayer> Context<Application> {
for (id, s) in sessions.active.iter() {
if let Some(session) = s.upgrade() {
let state = session.state.read().unwrap();
match &state.current_offer {
Offer::None => {
if let Some(key) = state.keys[state.current_key].as_ref() {
if key.bob
&& (current_time >= key.rekey_at_time
|| session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter)
{
drop(state);
session.initiate_rekey(|b| send(&session, b), current_time);
}
}
}
if match &state.current_offer {
Offer::None => true,
Offer::NoiseXKInit(offer) => {
// If there's an outstanding attempt to open a session, retransmit this periodically
// in case the initial packet doesn't make it. Note that we currently don't have
// retransmission for the intermediate steps, so a new session may still fail if the
// packet loss rate is huge. The application layer has its own logic to keep trying
// under those conditions.
if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
offer.last_retry_time.store(current_time, Ordering::Relaxed);
let _ = send_with_fragmentation(
@ -199,11 +200,37 @@ impl<Application: ApplicationLayer> Context<Application> {
None,
);
}
false
}
Offer::RekeyInit(_, rekey_packet, last_retry_time) => {
if last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
last_retry_time.store(current_time, Ordering::Relaxed);
send(&session, &mut (rekey_packet.clone()));
Offer::NoiseXKAck(ack) => {
// We also keep retransmitting the final ACK until we get a valid DATA or NOP packet
// from Bob, otherwise we could get a half open session.
if ack.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
ack.last_retry_time.store(current_time, Ordering::Relaxed);
let _ = send_with_fragmentation(
|b| send(&session, b),
&mut (ack.ack.clone())[..ack.ack_size],
mtu,
PACKET_TYPE_ALICE_NOISE_XK_ACK,
state.remote_session_id,
0,
2,
Some(&session.header_protection_cipher),
);
}
false
}
Offer::RekeyInit(_, last_rekey_attempt_time) => *last_rekey_attempt_time < retry_cutoff,
} {
// Check whether we need to rekey if there is no pending offer or if the last rekey
// offer was before retry_cutoff (checked in the 'match' above).
if let Some(key) = state.keys[state.current_key].as_ref() {
if key.bob
&& (current_time >= key.rekey_at_time
|| session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter)
{
drop(state);
session.initiate_rekey(|b| send(&session, b), current_time);
}
}
}
@ -297,7 +324,7 @@ impl<Application: ApplicationLayer> Context<Application> {
init_packet: [0u8; AliceNoiseXKInit::SIZE],
})),
}),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
});
sessions.active.insert(local_session_id, Arc::downgrade(&session));
@ -310,7 +337,7 @@ impl<Application: ApplicationLayer> Context<Application> {
let init_packet = if let Offer::NoiseXKInit(offer) = &mut state.current_offer {
&mut offer.init_packet
} else {
panic!();
panic!(); // should be impossible
};
let init: &mut AliceNoiseXKInit = byte_array_as_proto_buffer_mut(init_packet).unwrap();
@ -321,12 +348,12 @@ impl<Application: ApplicationLayer> Context<Application> {
init.header_protection_key = header_protection_key.0;
aes_ctr_crypt_one_time_use_key(
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION>(noise_es.as_bytes()).as_bytes(),
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_ENCRYPTION>(noise_es.as_bytes()).as_bytes(),
&mut init_packet[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START],
);
let hmac = hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es.as_bytes()).as_bytes(),
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_AUTHENTICATION>(noise_es.as_bytes()).as_bytes(),
&create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_INIT, 1),
&init_packet[HEADER_SIZE..AliceNoiseXKInit::AUTH_START],
);
@ -385,6 +412,7 @@ impl<Application: ApplicationLayer> Context<Application> {
mut check_allow_incoming_session: CheckAllowIncomingSession,
mut check_accept_session: CheckAcceptSession,
mut send: SendFunction,
source: &Application::PhysicalPath,
data_buf: &'b mut [u8],
mut incoming_packet_buf: Application::IncomingPacketBuffer,
mtu: usize,
@ -414,31 +442,27 @@ impl<Application: ApplicationLayer> Context<Application> {
if session.check_receive_window(incoming_counter) {
if fragment_count > 1 {
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
let mut defrag = session.defrag.lock().unwrap();
let fragment_gather_array = defrag.get_or_create_mut(&incoming_counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock
return self.process_complete_incoming_packet(
app,
&mut send,
&mut check_allow_incoming_session,
&mut check_accept_session,
data_buf,
incoming_counter,
assembled_packet.as_ref(),
packet_type,
Some(session),
None,
key_index,
mtu,
current_time,
);
} else {
return Ok(ReceiveResult::Ok);
}
let mut fragged = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO].lock().unwrap();
if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count)
{
drop(fragged);
return self.process_complete_incoming_packet(
app,
&mut send,
&mut check_allow_incoming_session,
&mut check_accept_session,
data_buf,
incoming_counter,
assembled_packet.as_ref(),
packet_type,
Some(session),
None,
key_index,
mtu,
current_time,
);
} else {
return Err(Error::InvalidPacket);
return Ok(ReceiveResult::Ok);
}
} else {
return self.process_complete_incoming_packet(
@ -476,10 +500,19 @@ impl<Application: ApplicationLayer> Context<Application> {
let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet);
if fragment_count > 1 {
let mut defrag = self.defrag.lock().unwrap();
let fragment_gather_array = defrag.get_or_create_mut(&incoming_counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock
let fragged_m = {
let mut defrag = self.defrag.lock().unwrap();
defrag
.entry((source.clone(), incoming_counter))
.or_insert_with(|| Arc::new(Mutex::new((Fragged::new(), current_time))))
.clone()
};
let mut fragged = fragged_m.lock().unwrap();
if let Some(assembled_packet) = fragged
.0
.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count)
{
self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter));
return self.process_complete_incoming_packet(
app,
&mut send,
@ -543,7 +576,7 @@ impl<Application: ApplicationLayer> Context<Application> {
// Generate incoming message nonce for decryption and authentication.
let incoming_message_nonce = create_message_nonce(packet_type, incoming_counter);
if packet_type == PACKET_TYPE_DATA {
if packet_type <= PACKET_TYPE_DATA {
if let Some(session) = session {
let state = session.state.read().unwrap();
if let Some(key) = state.keys[key_index].as_ref() {
@ -590,22 +623,41 @@ impl<Application: ApplicationLayer> Context<Application> {
// Update the current key to point to this key if it's newer, since having received
// a packet encrypted with it proves that the other side has successfully derived it
// as well.
if state.current_key == key_index {
if state.current_key == key_index && key.confirmed {
drop(state);
} else {
let key_created_at_counter = key.created_at_counter;
let current_key_created_at_counter = key.created_at_counter;
drop(state);
let mut state = session.state.write().unwrap();
if let Some(other_session_key) = state.keys[state.current_key].as_ref() {
if other_session_key.created_at_counter < key_created_at_counter {
if state.current_key != key_index {
if let Some(other_session_key) = state.keys[state.current_key].as_ref() {
if other_session_key.created_at_counter < current_key_created_at_counter {
state.current_key = key_index;
}
} else {
state.current_key = key_index;
}
} else {
state.current_key = key_index;
state.keys[key_index].as_mut().unwrap().confirmed = true;
}
// If we got a valid data packet from Bob, this means we can cancel any offers
// that are still oustanding for initialization.
match &state.current_offer {
Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => {
state.current_offer = Offer::None;
}
_ => {}
}
}
return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len]));
if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(session, &mut data_buf[..data_len]));
} else {
println!("nop");
}
} else {
return Err(Error::OutOfSequence);
}
@ -650,120 +702,104 @@ impl<Application: ApplicationLayer> Context<Application> {
if incoming_counter != 1 || session.is_some() {
return Err(Error::OutOfSequence);
}
// Hash the init packet so we can check to see if it's just being retransmitted. Alice may
// attempt to retransmit this packet until she receives a response.
let request_hash = SHA384::hash(&pkt_assembled);
let (alice_session_id, mut bob_session_id, noise_es_ee, bob_hk_ciphertext, header_protection_key, bob_noise_e);
if let Some(incoming) = incoming {
// If we've already seen this exact packet before, just recall the same state so we send the
// same response.
if secure_eq(&request_hash, &incoming.request_hash) {
alice_session_id = incoming.alice_session_id;
bob_session_id = incoming.bob_session_id;
noise_es_ee = incoming.noise_es_ee.clone();
bob_hk_ciphertext = incoming.bob_hk_ciphertext;
header_protection_key = incoming.header_protection_key.clone();
bob_noise_e = *incoming.bob_noise_e_secret.public_key_bytes();
} else {
return Err(Error::FailedAuthentication);
}
} else {
// Otherwise parse the packet, authenticate, generate keys, etc. and record state in an
// incoming state object until this phase of the negotiation is done.
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
let alice_noise_e = P384PublicKey::from_bytes(&pkt.alice_noise_e).ok_or(Error::FailedAuthentication)?;
let noise_es = app.get_local_s_keypair().agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?;
// Authenticate packet and also prove that Alice knows our static public key.
if !secure_eq(
&pkt.hmac_es,
&hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es.as_bytes()).as_bytes(),
&incoming_message_nonce,
&pkt_assembled[HEADER_SIZE..AliceNoiseXKInit::AUTH_START],
),
) {
return Err(Error::FailedAuthentication);
}
// Let application filter incoming connection attempt by whatever criteria it wants.
if !check_allow_incoming_session() {
return Ok(ReceiveResult::Rejected);
}
// Decrypt encrypted part of payload.
aes_ctr_crypt_one_time_use_key(
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION>(noise_es.as_bytes()).as_bytes(),
&mut pkt_assembled[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START],
);
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?;
header_protection_key = Secret(pkt.header_protection_key);
// Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create
// a Kyber ciphertext to send back to Alice.
let bob_noise_e_secret = P384KeyPair::generate();
bob_noise_e = bob_noise_e_secret.public_key_bytes().clone();
noise_es_ee = Secret(hmac_sha512(
noise_es.as_bytes(),
bob_noise_e_secret
.agree(&alice_noise_e)
.ok_or(Error::FailedAuthentication)?
.as_bytes(),
));
let (hk_ct, hk) = pqc_kyber::encapsulate(&pkt.alice_hk_public, &mut random::SecureRandom::default())
.map_err(|_| Error::FailedAuthentication)
.map(|(ct, hk)| (ct, Secret(hk)))?;
bob_hk_ciphertext = hk_ct;
let mut sessions = self.sessions.write().unwrap();
loop {
bob_session_id = SessionId::random();
if !sessions.active.contains_key(&bob_session_id) && !sessions.incoming.contains_key(&bob_session_id) {
break;
}
}
if sessions.incoming.len() >= self.max_incomplete_session_queue_size {
// If this queue is too big, we remove the latest entry and replace it. The latest
// is used because under flood conditions this is most likely to be another bogus
// entry. If we find one that is actually timed out, that one is replaced instead.
let mut newest = i64::MIN;
let mut replace_id = None;
let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
for (id, s) in sessions.incoming.iter() {
if s.timestamp <= cutoff_time {
replace_id = Some(*id);
break;
} else if s.timestamp >= newest {
newest = s.timestamp;
replace_id = Some(*id);
}
}
let _ = sessions.incoming.remove(replace_id.as_ref().unwrap());
}
// Reserve session ID on this side and record incomplete session state.
sessions.incoming.insert(
bob_session_id,
Arc::new(IncomingIncompleteSession {
timestamp: current_time,
request_hash,
alice_session_id,
bob_session_id,
noise_es_ee: noise_es_ee.clone(),
bob_hk_ciphertext,
hk,
bob_noise_e_secret,
header_protection_key: Secret(pkt.header_protection_key),
}),
);
if pkt_assembled.len() != AliceNoiseXKInit::SIZE {
return Err(Error::InvalidPacket);
}
// Otherwise parse the packet, authenticate, generate keys, etc. and record state in an
// incoming state object until this phase of the negotiation is done.
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
let alice_noise_e = P384PublicKey::from_bytes(&pkt.alice_noise_e).ok_or(Error::FailedAuthentication)?;
let noise_es = app.get_local_s_keypair().agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?;
// Authenticate packet and also prove that Alice knows our static public key.
if !secure_eq(
&pkt.hmac_es,
&hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_AUTHENTICATION>(noise_es.as_bytes()).as_bytes(),
&incoming_message_nonce,
&pkt_assembled[HEADER_SIZE..AliceNoiseXKInit::AUTH_START],
),
) {
return Err(Error::FailedAuthentication);
}
// Let application filter incoming connection attempt by whatever criteria it wants.
if !check_allow_incoming_session() {
return Ok(ReceiveResult::Rejected);
}
// Decrypt encrypted part of payload.
aes_ctr_crypt_one_time_use_key(
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_ENCRYPTION>(noise_es.as_bytes()).as_bytes(),
&mut pkt_assembled[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START],
);
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
let alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?;
let header_protection_key = Secret(pkt.header_protection_key);
// Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create
// a Kyber ciphertext to send back to Alice.
let bob_noise_e_secret = P384KeyPair::generate();
let bob_noise_e = bob_noise_e_secret.public_key_bytes().clone();
let noise_es_ee = Secret(hmac_sha512(
noise_es.as_bytes(),
bob_noise_e_secret
.agree(&alice_noise_e)
.ok_or(Error::FailedAuthentication)?
.as_bytes(),
));
let (bob_hk_ciphertext, hk) = pqc_kyber::encapsulate(&pkt.alice_hk_public, &mut random::SecureRandom::default())
.map_err(|_| Error::FailedAuthentication)
.map(|(ct, hk)| (ct, Secret(hk)))?;
let mut sessions = self.sessions.write().unwrap();
let mut bob_session_id;
loop {
bob_session_id = SessionId::random();
if !sessions.active.contains_key(&bob_session_id) && !sessions.incoming.contains_key(&bob_session_id) {
break;
}
}
if sessions.incoming.len() >= self.max_incomplete_session_queue_size {
// If this queue is too big, we remove the latest entry and replace it. The latest
// is used because under flood conditions this is most likely to be another bogus
// entry. If we find one that is actually timed out, that one is replaced instead.
let mut newest = i64::MIN;
let mut replace_id = None;
let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
for (id, s) in sessions.incoming.iter() {
if s.timestamp <= cutoff_time {
replace_id = Some(*id);
break;
} else if s.timestamp >= newest {
newest = s.timestamp;
replace_id = Some(*id);
}
}
let _ = sessions.incoming.remove(replace_id.as_ref().unwrap());
}
// Reserve session ID on this side and record incomplete session state.
sessions.incoming.insert(
bob_session_id,
Arc::new(IncomingIncompleteSession {
timestamp: current_time,
alice_session_id,
bob_session_id,
noise_es_ee: noise_es_ee.clone(),
hk,
bob_noise_e_secret,
header_protection_key: Secret(pkt.header_protection_key),
}),
);
debug_assert!(!sessions.active.contains_key(&bob_session_id));
drop(sessions);
// Create Bob's ephemeral counter-offer reply.
let mut ack_packet = [0u8; BobNoiseXKAck::SIZE];
let ack: &mut BobNoiseXKAck = byte_array_as_proto_buffer_mut(&mut ack_packet)?;
@ -774,13 +810,13 @@ impl<Application: ApplicationLayer> Context<Application> {
// Encrypt main section of reply.
aes_ctr_crypt_one_time_use_key(
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION>(noise_es_ee.as_bytes()).as_bytes(),
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_ENCRYPTION>(noise_es_ee.as_bytes()).as_bytes(),
&mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START],
);
// Add HMAC-SHA384 to reply packet.
let reply_hmac = hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es_ee.as_bytes()).as_bytes(),
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_AUTHENTICATION>(noise_es_ee.as_bytes()).as_bytes(),
&create_message_nonce(PACKET_TYPE_BOB_NOISE_XK_ACK, 1),
&ack_packet[HEADER_SIZE..BobNoiseXKAck::AUTH_START],
);
@ -812,9 +848,18 @@ impl<Application: ApplicationLayer> Context<Application> {
if incoming_counter != 1 || incoming.is_some() {
return Err(Error::OutOfSequence);
}
if pkt_assembled.len() != BobNoiseXKAck::SIZE {
return Err(Error::InvalidPacket);
}
if let Some(session) = session {
let state = session.state.read().unwrap();
// This doesn't make sense if the session is up.
if state.keys[state.current_key].is_some() {
return Err(Error::OutOfSequence);
}
if let Offer::NoiseXKInit(outgoing_offer) = &state.current_offer {
let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?;
@ -830,7 +875,7 @@ impl<Application: ApplicationLayer> Context<Application> {
));
let noise_es_ee_kex_hmac_key =
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es_ee.as_bytes());
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_AUTHENTICATION>(noise_es_ee.as_bytes());
// Authenticate Bob's reply and the validity of bob_noise_e.
if !secure_eq(
@ -846,7 +891,7 @@ impl<Application: ApplicationLayer> Context<Application> {
// Decrypt encrypted portion of message.
aes_ctr_crypt_one_time_use_key(
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION>(noise_es_ee.as_bytes()).as_bytes(),
kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_ENCRYPTION>(noise_es_ee.as_bytes()).as_bytes(),
&mut pkt_assembled[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START],
);
let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?;
@ -915,7 +960,7 @@ impl<Application: ApplicationLayer> Context<Application> {
// key exchange. Bob won't be able to do this until he decrypts and parses Alice's
// identity, so the first HMAC is to let him authenticate that first.
let hmac_es_ee_se_hk_psk = hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es_ee_se_hk_psk.as_bytes())
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_AUTHENTICATION>(noise_es_ee_se_hk_psk.as_bytes())
.as_bytes(),
&reply_message_nonce,
&reply_buffer[HEADER_SIZE..reply_len],
@ -933,9 +978,15 @@ impl<Application: ApplicationLayer> Context<Application> {
current_time,
2,
false,
false,
));
debug_assert!(state.keys[1].is_none());
state.current_key = 0;
state.current_offer = Offer::None;
state.current_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck {
last_retry_time: AtomicI64::new(current_time),
ack: reply_buffer,
ack_size: reply_len,
}));
}
send_with_fragmentation(
@ -980,18 +1031,13 @@ impl<Application: ApplicationLayer> Context<Application> {
}
if let Some(incoming) = incoming {
// Check timeout, negotiations aren't allowed to take longer than this.
if (current_time - incoming.timestamp) > Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS {
return Err(Error::UnknownLocalSessionId);
}
// Check the first HMAC to verify against the currently known noise_es_ee key, which verifies
// that this reply is part of this session.
let auth_start = pkt_assembled.len() - ALICE_NOISE_XK_ACK_AUTH_SIZE;
if !secure_eq(
&pkt_assembled[auth_start..pkt_assembled.len() - HMAC_SHA384_SIZE],
&hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(incoming.noise_es_ee.as_bytes())
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_AUTHENTICATION>(incoming.noise_es_ee.as_bytes())
.as_bytes(),
&incoming_message_nonce,
&pkt_assembled[HEADER_SIZE..auth_start],
@ -1065,7 +1111,7 @@ impl<Application: ApplicationLayer> Context<Application> {
if !secure_eq(
&pkt_assembly_buffer_copy[auth_start + HMAC_SHA384_SIZE..pkt_assembled.len()],
&hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es_ee_se_hk_psk.as_bytes())
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_INIT_AUTHENTICATION>(noise_es_ee_se_hk_psk.as_bytes())
.as_bytes(),
&incoming_message_nonce,
&pkt_assembly_buffer_copy[HEADER_SIZE..auth_start + HMAC_SHA384_SIZE],
@ -1084,22 +1130,30 @@ impl<Application: ApplicationLayer> Context<Application> {
state: RwLock::new(State {
remote_session_id: Some(incoming.alice_session_id),
keys: [
Some(SessionKey::new::<Application>(noise_es_ee_se_hk_psk, 1, current_time, 2, true)),
Some(SessionKey::new::<Application>(
noise_es_ee_se_hk_psk,
1,
current_time,
2,
true,
true,
)),
None,
],
current_key: 0,
current_offer: Offer::None,
}),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())),
});
// Promote this from an incomplete session to an established session.
{
let mut sessions = self.sessions.write().unwrap();
sessions.incoming.remove(&incoming.bob_session_id);
sessions.active.insert(incoming.bob_session_id, Arc::downgrade(&session));
}
let _ = session.send_nop(|b| send(Some(&session), b));
return Ok(ReceiveResult::OkNewSession(session));
} else {
return Err(Error::UnknownLocalSessionId);
@ -1177,6 +1231,7 @@ impl<Application: ApplicationLayer> Context<Application> {
current_time,
counter,
false,
false,
));
return Ok(ReceiveResult::Ok);
@ -1202,7 +1257,7 @@ impl<Application: ApplicationLayer> Context<Application> {
if let Some(session) = session {
let state = session.state.read().unwrap();
if let Offer::RekeyInit(alice_e_secret, _, _) = &state.current_offer {
if let Offer::RekeyInit(alice_e_secret, _) = &state.current_offer {
if let Some(key) = state.keys[key_index].as_ref() {
// Only the current "Bob" initiates rekeys and expects this ACK.
if key.bob {
@ -1234,6 +1289,7 @@ impl<Application: ApplicationLayer> Context<Application> {
current_time,
session.send_counter.load(Ordering::Acquire),
true,
true,
));
state.current_key = next_key_index; // this is an ACK so it's confirmed
state.current_offer = Offer::None;
@ -1325,6 +1381,32 @@ impl<Application: ApplicationLayer> Session<Application> {
return Err(Error::SessionNotEstablished);
}
/// Send a NOP to the other side (e.g. for keep alive).
pub fn send_nop<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction) -> Result<(), Error> {
let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(session_key) = state.keys[state.current_key].as_ref() {
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE];
let mut c = session_key.get_send_cipher(counter)?;
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter));
nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt());
session_key.return_send_cipher(c);
set_packet_header(
&mut nop,
1,
0,
PACKET_TYPE_NOP,
u64::from(remote_session_id),
state.current_key,
counter,
);
send(&mut nop);
}
}
return Err(Error::SessionNotEstablished);
}
/// Check whether this session is established.
pub fn established(&self) -> bool {
let state = self.state.read().unwrap();
@ -1381,7 +1463,7 @@ impl<Application: ApplicationLayer> Session<Application> {
.encrypt_block_in_place(&mut rekey_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
send(&mut rekey_buf);
self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, rekey_buf, AtomicI64::new(current_time));
self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, current_time);
}
}
}
@ -1448,15 +1530,11 @@ fn set_packet_header(
#[inline(always)]
fn parse_packet_header(incoming_packet: &[u8]) -> (usize, u8, u8, u8, u64) {
let raw_header_a = u16::from_le_bytes(incoming_packet[6..8].try_into().unwrap());
let key_index = (raw_header_a & 1) as usize;
let packet_type = (raw_header_a.wrapping_shr(1) & 7) as u8;
let fragment_count = ((raw_header_a.wrapping_shr(4) & 63) + 1) as u8;
let fragment_no = raw_header_a.wrapping_shr(10) as u8;
(
key_index,
packet_type,
fragment_count,
fragment_no,
(raw_header_a & 1) as usize,
(raw_header_a.wrapping_shr(1) & 7) as u8,
((raw_header_a.wrapping_shr(4) & 63) + 1) as u8,
raw_header_a.wrapping_shr(10) as u8,
u64::from_le_bytes(incoming_packet[8..16].try_into().unwrap()),
)
}
@ -1532,11 +1610,12 @@ impl SessionKey {
ratchet_count: u64,
current_time: i64,
current_counter: u64,
role_is_bob: bool,
bob: bool,
confirmed: bool,
) -> Self {
let a2b = kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB>(key.as_bytes());
let b2a = kbkdf::<AES_256_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE>(key.as_bytes());
let (receive_key, send_key) = if role_is_bob {
let (receive_key, send_key) = if bob {
(a2b, b2a)
} else {
(b2a, a2b)
@ -1557,7 +1636,8 @@ impl SessionKey {
rekey_at_counter: current_counter.checked_add(Application::REKEY_AFTER_USES).unwrap(),
expire_at_counter: current_counter.checked_add(Application::EXPIRE_AFTER_USES).unwrap(),
ratchet_count,
bob: role_is_bob,
bob,
confirmed,
}
}