// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. use std::collections::LinkedList; use std::io::{Read, Write}; use std::num::NonZeroU64; use std::ops::Deref; use std::sync::atomic::{AtomicU64, Ordering}; use crate::aes::{Aes, AesGcm}; use crate::hash::{hmac_sha384, hmac_sha512, SHA384}; use crate::p384::{P384KeyPair, P384PublicKey, P384_PUBLIC_KEY_SIZE}; use crate::random; use crate::secret::Secret; use zerotier_utils::gatherarray::GatherArray; use zerotier_utils::memory; use zerotier_utils::ringbuffermap::RingBufferMap; use zerotier_utils::varint; use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard}; pub const MIN_PACKET_SIZE: usize = HEADER_SIZE; pub const MIN_MTU: usize = 1280; /// Start attempting to rekey after a key has been used to send packets this many times. const REKEY_AFTER_USES: u64 = 536870912; /// Maximum random jitter to add to rekey-after usage count. const REKEY_AFTER_USES_MAX_JITTER: u32 = 1048576; /// Hard expiration after this many uses. const EXPIRE_AFTER_USES: u64 = (u32::MAX - 1024) as u64; /// Start attempting to rekey after a key has been in use for this many milliseconds. const REKEY_AFTER_TIME_MS: i64 = 1000 * 60 * 60; // 1 hour /// Maximum random jitter to add to rekey-after time. const REKEY_AFTER_TIME_MS_MAX_JITTER: u32 = 1000 * 60 * 5; /// Rate limit for sending new offers to attempt to re-key. const OFFER_RATE_LIMIT_MS: i64 = 2000; /// Version 0: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication) const SESSION_PROTOCOL_VERSION: u8 = 0x00; // Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use const PACKET_TYPE_DATA: u8 = 0; const PACKET_TYPE_NOP: u8 = 1; const PACKET_TYPE_KEY_OFFER: u8 = 2; // "alice" const PACKET_TYPE_KEY_COUNTER_OFFER: u8 = 3; // "bob" const E1_TYPE_NONE: u8 = 0; const E1_TYPE_KYBER1024: u8 = 1; const MAX_FRAGMENTS: usize = 48; // protocol max: 63 const KEY_EXCHANGE_MAX_FRAGMENTS: usize = 2; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc. const HEADER_SIZE: usize = 16; const HEADER_CHECK_SIZE: usize = 4; const AES_GCM_TAG_SIZE: usize = 16; const AES_GCM_NONCE_SIZE: usize = 12; const AES_GCM_NONCE_START: usize = 4; const AES_GCM_NONCE_END: usize = 16; const HMAC_SIZE: usize = 48; const SESSION_ID_SIZE: usize = 6; const KEY_HISTORY_SIZE_MAX: usize = 3; const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M'; const KBKDF_KEY_USAGE_LABEL_HEADER_MAC: u8 = b'H'; const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A'; const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; /// Aribitrary starting value for key derivation chain. /// /// It doesn't matter very much what this is, but it's good for it to be unique. const KEY_DERIVATION_CHAIN_STARTING_SALT: [u8; 64] = [ // macOS command line to generate: // echo -n 'ZSSP_Noise_IKpsk2_NISTP384_?KYBER1024_AESGCM_SHA512' | shasum -a 512 | cut -d ' ' -f 1 | xxd -r -p | xxd -i 0x35, 0x6a, 0x75, 0xc0, 0xbf, 0xbe, 0xc3, 0x59, 0x70, 0x94, 0x50, 0x69, 0x4c, 0xa2, 0x08, 0x40, 0xc7, 0xdf, 0x67, 0xa8, 0x68, 0x52, 0x6e, 0xd5, 0xdd, 0x77, 0xec, 0x59, 0x6f, 0x8e, 0xa1, 0x99, 0xb4, 0x32, 0x85, 0xaf, 0x7f, 0x0d, 0xa9, 0x6c, 0x01, 0xfb, 0x72, 0x46, 0xc0, 0x09, 0x58, 0xb8, 0xe0, 0xa8, 0xcf, 0xb1, 0x58, 0x04, 0x6e, 0x32, 0xba, 0xa8, 0xb8, 0xf9, 0x0a, 0xa4, 0xbf, 0x36, ]; pub enum Error { /// The packet was addressed to an unrecognized local session UnknownLocalSessionId(SessionId), /// Packet was not well formed InvalidPacket, /// An invalid paramter was supplied to the function InvalidParameter, /// Packet failed one or more authentication (MAC) checks FailedAuthentication, /// New session was rejected by caller's supplied authentication check function NewSessionRejected, /// Rekeying failed and session secret has reached its maximum usage count MaxKeyLifetimeExceeded, /// Attempt to send using session without established key. SessionNotEstablished, /// Packet ignored by rate limiter. RateLimited, /// Other end sent a protocol version we don't support. UnknownProtocolVersion, /// Supplied data buffer is too small to receive data. DataBufferTooSmall, /// An internal error occurred. OtherError(Box), } impl From for Error { #[cold] #[inline(never)] fn from(e: std::io::Error) -> Self { Self::OtherError(Box::new(e)) } } #[cold] #[inline(never)] extern "C" fn unlikely_branch() {} impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::UnknownLocalSessionId(id) => f.write_str(format!("UnknownLocalSessionId({})", id.0.get()).as_str()), Self::InvalidPacket => f.write_str("InvalidPacket"), Self::InvalidParameter => f.write_str("InvalidParameter"), Self::FailedAuthentication => f.write_str("FailedAuthentication"), Self::NewSessionRejected => f.write_str("NewSessionRejected"), Self::MaxKeyLifetimeExceeded => f.write_str("MaxKeyLifetimeExceeded"), Self::SessionNotEstablished => f.write_str("SessionNotEstablished"), Self::RateLimited => f.write_str("RateLimited"), Self::UnknownProtocolVersion => f.write_str("UnknownProtocolVersion"), Self::DataBufferTooSmall => f.write_str("DataBufferTooSmall"), Self::OtherError(e) => f.write_str(format!("OtherError({})", e.to_string()).as_str()), } } } impl std::error::Error for Error {} impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } } pub enum ReceiveResult<'a, H: Host> { /// Packet is valid, no action needs to be taken. Ok, /// Packet is valid and a data payload was decoded and authenticated. /// /// The returned reference is to the filled parts of the data buffer supplied to receive. OkData(&'a mut [u8]), /// Packet is valid and a new session was created. /// /// The session will have already been gated by the accept_new_session() method in the Host trait. OkNewSession(Session), /// Packet apperas valid but was ignored e.g. as a duplicate. Ignored, } #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[repr(transparent)] pub struct SessionId(NonZeroU64); impl SessionId { pub const MAX_BIT_MASK: u64 = 0xffffffffffff; #[inline(always)] pub fn new_from_u64(i: u64) -> Option { debug_assert!(i <= Self::MAX_BIT_MASK); NonZeroU64::new(i).map(|i| Self(i)) } pub fn new_from_reader(r: &mut R) -> std::io::Result> { let mut tmp = [0_u8; 8]; r.read_exact(&mut tmp[..SESSION_ID_SIZE]).map(|_| NonZeroU64::new(u64::from_le_bytes(tmp)).map(|i| Self(i))) } pub fn new_random() -> Self { Self(NonZeroU64::new((random::next_u64_secure() & Self::MAX_BIT_MASK).max(1)).unwrap()) } #[inline(always)] pub fn to_bytes(&self) -> [u8; SESSION_ID_SIZE] { self.0.get().to_le_bytes()[0..6].try_into().unwrap() } } impl From for u64 { #[inline(always)] fn from(sid: SessionId) -> Self { sid.0.get() } } /// Trait to implement to integrate the session into an application. pub trait Host: Sized { /// Arbitrary object that can be associated with sessions. type AssociatedObject; /// Arbitrary object that dereferences to the session, such as Arc>. type SessionRef: Deref>; /// A buffer containing data read from the network that can be cached. /// /// This can be e.g. a pooled buffer that automatically returns itself to the pool when dropped. type IncomingPacketBuffer: AsRef<[u8]>; /// Get a reference to this host's static public key blob. /// /// This must contain a NIST P-384 public key but can contain other information. fn get_local_s_public(&self) -> &[u8]; /// Get SHA384(this host's static public key blob), included here so we don't have to calculate it each time. fn get_local_s_public_hash(&self) -> &[u8; 48]; /// Get a reference to this hosts' static public key's NIST P-384 secret key pair fn get_local_s_keypair_p384(&self) -> &P384KeyPair; /// Extract the NIST P-384 ECC public key component from a static public key blob or return None on failure. fn extract_p384_static(static_public: &[u8]) -> Option; /// Look up a local session by local ID. fn session_lookup(&self, local_session_id: SessionId) -> Option; /// Check whether a new session should be accepted. /// /// On success a tuple of local session ID, static secret, and associated object is returned. The /// static secret is whatever results from agreement between the local and remote static public /// keys. fn accept_new_session(&self, remote_static_public: &[u8], remote_metadata: &[u8]) -> Option<(SessionId, Secret<64>, Self::AssociatedObject)>; } pub struct Session { pub id: SessionId, pub associated_object: H::AssociatedObject, send_counter: Counter, psk: Secret<64>, // Arbitrary PSK provided by external code ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer header_check_cipher: Aes, // Cipher used for fast 32-bit header MAC state: RwLock, // Mutable parts of state (other than defrag buffers) remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob) remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key defrag: Mutex, 16, 4>>, } struct MutableState { remote_session_id: Option, keys: LinkedList, offer: Option, } /// State information to associate with receiving contexts such as sockets or remote paths/endpoints. pub struct ReceiveContext { initial_offer_defrag: Mutex, 1024, 128>>, incoming_init_header_check_cipher: Aes, } impl Session { #[inline] pub fn new( host: &H, mut send: SendFunction, local_session_id: SessionId, remote_s_public: &[u8], offer_metadata: &[u8], psk: &Secret<64>, associated_object: H::AssociatedObject, mtu: usize, current_time: i64, jedi: bool, ) -> Result { if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) { if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) { let send_counter = Counter::new(); let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()); let remote_s_public_hash = SHA384::hash(remote_s_public); let outgoing_init_header_check_cipher = Aes::new(kbkdf512(&remote_s_public_hash, KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()); if let Ok(offer) = EphemeralOffer::create_alice_offer( &mut send, send_counter.next(), local_session_id, None, host.get_local_s_public(), offer_metadata, &remote_s_public_p384, &remote_s_public_hash, &ss, &outgoing_init_header_check_cipher, mtu, current_time, jedi, ) { return Ok(Self { id: local_session_id, associated_object, send_counter, psk: psk.clone(), ss, header_check_cipher, state: RwLock::new(MutableState { remote_session_id: None, keys: LinkedList::new(), offer: Some(offer), }), remote_s_public_hash, remote_s_public_p384: remote_s_public_p384.as_bytes().clone(), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), }); } } } return Err(Error::InvalidParameter); } #[inline] pub fn send(&self, mut send: SendFunction, mtu_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> { debug_assert!(mtu_buffer.len() >= MIN_MTU); let state = self.state.read(); if let Some(remote_session_id) = state.remote_session_id { if let Some(key) = state.keys.front() { let mut packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; let counter = self.send_counter.next(); send_with_fragmentation_init_header(mtu_buffer, packet_len, mtu_buffer.len(), PACKET_TYPE_DATA, remote_session_id.into(), counter); let mut c = key.get_send_cipher(counter)?; c.init(mtu_buffer); if packet_len > mtu_buffer.len() { let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap(); let fragment_data_mtu = mtu_buffer.len() - HEADER_SIZE; let last_fragment_data_mtu = mtu_buffer.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); loop { debug_assert!(data.len() > last_fragment_data_mtu); c.crypt(&data[HEADER_CHECK_SIZE..fragment_data_mtu], &mut mtu_buffer[HEADER_SIZE..]); let hc = header_check(mtu_buffer, &self.header_check_cipher); mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes()); send(mtu_buffer); data = &data[fragment_data_mtu..]; debug_assert!(header[7].wrapping_shr(2) < 63); header[7] += 0x04; // increment fragment number mtu_buffer[HEADER_CHECK_SIZE..HEADER_SIZE].copy_from_slice(&header); if data.len() <= last_fragment_data_mtu { break; } } packet_len = data.len() + HEADER_SIZE + AES_GCM_TAG_SIZE; } let gcm_tag_idx = data.len() + HEADER_SIZE; c.crypt(data, &mut mtu_buffer[HEADER_SIZE..gcm_tag_idx]); let hc = header_check(mtu_buffer, &self.header_check_cipher); mtu_buffer[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes()); mtu_buffer[gcm_tag_idx..packet_len].copy_from_slice(&c.finish()); send(&mut mtu_buffer[..packet_len]); key.return_send_cipher(c); return Ok(()); } } return Err(Error::SessionNotEstablished); } #[inline] pub fn rekey_check(&self, host: &H, mut send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) { let state = self.state.upgradable_read(); if let Some(key) = state.keys.front() { if force || (key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) { if let Some(remote_s_public_p384) = P384PublicKey::from_bytes(&self.remote_s_public_p384) { if let Ok(offer) = EphemeralOffer::create_alice_offer( &mut send, self.send_counter.next(), self.id, state.remote_session_id, host.get_local_s_public(), offer_metadata, &remote_s_public_p384, &self.remote_s_public_hash, &self.ss, &self.header_check_cipher, mtu, current_time, jedi, ) { let _ = RwLockUpgradableReadGuard::upgrade(state).offer.replace(offer); } } } } } } impl ReceiveContext { #[inline] pub fn new(host: &H) -> Self { Self { initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), incoming_init_header_check_cipher: Aes::new(kbkdf512(host.get_local_s_public_hash(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()), } } #[inline] pub fn receive<'a, SendFunction: FnMut(&mut [u8])>( &self, host: &H, mut send: SendFunction, data_buf: &'a mut [u8], incoming_packet_buf: H::IncomingPacketBuffer, mtu: usize, jedi: bool, current_time: i64, ) -> Result, Error> { let incoming_packet = incoming_packet_buf.as_ref(); if incoming_packet.len() < MIN_PACKET_SIZE { unlikely_branch(); return Err(Error::InvalidPacket); } let header_0_8 = memory::u64_from_le_bytes(&incoming_packet[HEADER_CHECK_SIZE..12]); // session ID, type, frag info let counter = memory::u32_from_le_bytes(&incoming_packet[12..16]); let local_session_id = SessionId::new_from_u64(header_0_8 & SessionId::MAX_BIT_MASK); let packet_type = (header_0_8.wrapping_shr(48) as u8) & 15; let fragment_count = ((header_0_8.wrapping_shr(52) as u8) & 63).wrapping_add(1); let fragment_no = (header_0_8.wrapping_shr(58) as u8) & 63; if let Some(local_session_id) = local_session_id { if let Some(session) = host.session_lookup(local_session_id) { if memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &session.header_check_cipher) { unlikely_branch(); return Err(Error::FailedAuthentication); } if fragment_count > 1 { if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count { let mut defrag = session.defrag.lock(); let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { drop(defrag); // release lock return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time); } } else { unlikely_branch(); return Err(Error::InvalidPacket); } } else { return self.receive_complete(host, &mut send, data_buf, &[incoming_packet_buf], packet_type, Some(session), mtu, jedi, current_time); } } else { unlikely_branch(); return Err(Error::UnknownLocalSessionId(local_session_id)); } } else { unlikely_branch(); if memory::u32_from_ne_bytes(incoming_packet) != header_check(incoming_packet, &self.incoming_init_header_check_cipher) { unlikely_branch(); return Err(Error::FailedAuthentication); } let mut defrag = self.initial_offer_defrag.lock(); let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count)); if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) { drop(defrag); // release lock return self.receive_complete(host, &mut send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time); } }; return Ok(ReceiveResult::Ok); } fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>( &self, host: &H, send: &mut SendFunction, data_buf: &'a mut [u8], fragments: &[H::IncomingPacketBuffer], packet_type: u8, session: Option, mtu: usize, jedi: bool, current_time: i64, ) -> Result, Error> { debug_assert!(fragments.len() >= 1); debug_assert_eq!(PACKET_TYPE_DATA, 0); debug_assert_eq!(PACKET_TYPE_NOP, 1); if packet_type <= PACKET_TYPE_NOP { if let Some(session) = session { let state = session.state.read(); let key_count = state.keys.len(); for (key_index, key) in state.keys.iter().enumerate() { let tail = fragments.last().unwrap().as_ref(); if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { unlikely_branch(); return Err(Error::InvalidPacket); } let mut c = key.get_receive_cipher(); c.init(&fragments.first().unwrap().as_ref()[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); let mut data_len = 0; for f in fragments[..(fragments.len() - 1)].iter() { let f = f.as_ref(); debug_assert!(f.len() >= HEADER_SIZE); let current_frag_data_start = data_len; data_len += f.len() - HEADER_SIZE; if data_len > data_buf.len() { unlikely_branch(); key.return_receive_cipher(c); return Err(Error::DataBufferTooSmall); } c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); } let current_frag_data_start = data_len; data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); if data_len > data_buf.len() { unlikely_branch(); key.return_receive_cipher(c); return Err(Error::DataBufferTooSmall); } c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]); let tag = c.finish(); key.return_receive_cipher(c); if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) { // Drop obsolete keys if we had to iterate past the first key to get here. if key_index > 0 { unlikely_branch(); drop(state); let mut state = session.state.write(); if state.keys.len() == key_count { for _ in 0..key_index { let _ = state.keys.pop_front(); } } } if packet_type == PACKET_TYPE_DATA { return Ok(ReceiveResult::OkData(&mut data_buf[..data_len])); } else { unlikely_branch(); return Ok(ReceiveResult::Ok); } } } return Err(Error::FailedAuthentication); } else { unlikely_branch(); return Err(Error::SessionNotEstablished); } } else { unlikely_branch(); let mut incoming_packet_buf = [0_u8; MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS]; let mut incoming_packet_len = 0; for i in 0..fragments.len() { let mut ff = fragments[i].as_ref(); debug_assert!(ff.len() >= MIN_PACKET_SIZE); if i > 0 { ff = &ff[HEADER_SIZE..]; } let j = incoming_packet_len + ff.len(); if j > incoming_packet_buf.len() { return Err(Error::InvalidPacket); } incoming_packet_buf[incoming_packet_len..j].copy_from_slice(ff); incoming_packet_len = j; } let original_ciphertext = incoming_packet_buf.clone(); let incoming_packet = &mut incoming_packet_buf[..incoming_packet_len]; if incoming_packet_len <= HEADER_SIZE { return Err(Error::InvalidPacket); } if incoming_packet[HEADER_SIZE] != SESSION_PROTOCOL_VERSION { return Err(Error::UnknownProtocolVersion); } match packet_type { PACKET_TYPE_KEY_OFFER => { // alice (remote) -> bob (local) if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE) { return Err(Error::InvalidPacket); } let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE); let aes_gcm_tag_end = incoming_packet_len - (HMAC_SIZE + HMAC_SIZE); let hmac1_end = incoming_packet_len - HMAC_SIZE; // Check that the sender knows this host's identity before doing anything else. if !hmac_sha384(host.get_local_s_public_hash(), &incoming_packet[HEADER_CHECK_SIZE..hmac1_end]).eq(&incoming_packet[hmac1_end..]) { return Err(Error::FailedAuthentication); } // Check rate limit if this session is known. if let Some(session) = session.as_ref() { if let Some(offer) = session.state.read().offer.as_ref() { if (current_time - offer.creation_time) < OFFER_RATE_LIMIT_MS { return Err(Error::RateLimited); } } } let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) .and_then(|pk| host.get_local_s_keypair_p384().agree(&pk).map(move |s| (pk, s))) .ok_or(Error::FailedAuthentication)?; let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes())); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), false); c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); let c = c.finish(); if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) { return Err(Error::FailedAuthentication); } let (alice_session_id, alice_s_public, alice_metadata, alice_e1_public) = parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?; if let Some(session) = session.as_ref() { // Important! If there's already a session, make sure the caller is the same endpoint as that session! if !session.remote_s_public_hash.eq(&SHA384::hash(&alice_s_public)) { return Err(Error::FailedAuthentication); } } let alice_s_public_p384 = H::extract_p384_static(alice_s_public).ok_or(Error::InvalidPacket)?; let ss = host.get_local_s_keypair_p384().agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?; let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes())); if !hmac_sha384( kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end], ) .eq(&incoming_packet[aes_gcm_tag_end..hmac1_end]) { return Err(Error::FailedAuthentication); } // Alice's offer has been verified and her current key state reconstructed. let bob_e0_keypair = P384KeyPair::generate(); let e0e0 = bob_e0_keypair.agree(&alice_e0_public).ok_or(Error::FailedAuthentication)?; let se0 = bob_e0_keypair.agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?; let new_session = if session.is_some() { None } else { if let Some((new_session_id, psk, associated_object)) = host.accept_new_session(alice_s_public, alice_metadata) { let header_check_cipher = Aes::new(kbkdf512(ss.as_bytes(), KBKDF_KEY_USAGE_LABEL_HEADER_MAC).first_n::<16>()); Some(Session:: { id: new_session_id, associated_object, send_counter: Counter::new(), psk, ss, header_check_cipher, state: RwLock::new(MutableState { remote_session_id: Some(alice_session_id), keys: LinkedList::new(), offer: None, }), remote_s_public_hash: SHA384::hash(&alice_s_public), remote_s_public_p384: alice_s_public_p384.as_bytes().clone(), defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), }) } else { return Err(Error::NewSessionRejected); } }; let session_ref = session; let session = session_ref.as_ref().map_or_else(|| new_session.as_ref().unwrap(), |s| &*s); // FIPS note: the order of HMAC parameters are flipped here from the usual Noise HMAC(key, X). That's because // NIST/FIPS allows HKDF with HMAC(salt, key) and salt is allowed to be anything. This way if the PSK is not // FIPS compliant the compliance of the entire key derivation is not invalidated. Both inputs are secrets of // fixed size so this shouldn't matter cryptographically. let key = Secret(hmac_sha512( session.psk.as_bytes(), &hmac_sha512(&hmac_sha512(&hmac_sha512(key.as_bytes(), bob_e0_keypair.public_key_bytes()), e0e0.as_bytes()), se0.as_bytes()), )); // At this point we've completed Noise_IK key derivation with NIST P-384 ECDH, but see final step below... let (bob_e1_public, e1e1) = if jedi && alice_e1_public.len() > 0 { if let Ok((bob_e1_public, e1e1)) = pqc_kyber::encapsulate(alice_e1_public, &mut random::SecureRandom::default()) { (Some(bob_e1_public), Secret(e1e1)) } else { return Err(Error::FailedAuthentication); } } else { (None, Secret::default()) // use all zero Kyber secret if disabled }; let mut reply_buf = [0_u8; MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS]; let reply_counter = session.send_counter.next(); let mut reply_len = { let mut rp = &mut reply_buf[HEADER_SIZE..]; rp.write_all(&[SESSION_PROTOCOL_VERSION])?; rp.write_all(bob_e0_keypair.public_key_bytes())?; rp.write_all(&session.id.to_bytes())?; varint::write(&mut rp, 0)?; // they don't need our static public; they have it varint::write(&mut rp, 0)?; // no meta-data yet if let Some(bob_e1_public) = bob_e1_public.as_ref() { rp.write_all(&[E1_TYPE_KYBER1024])?; rp.write_all(bob_e1_public)?; } else { rp.write_all(&[E1_TYPE_NONE])?; } (MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len() }; send_with_fragmentation_init_header(&mut reply_buf, reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true); c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut reply_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..reply_len]); let c = c.finish(); reply_buf[reply_len..(reply_len + AES_GCM_TAG_SIZE)].copy_from_slice(&c); reply_len += AES_GCM_TAG_SIZE; // Normal Noise_IK is done, but we have one more step: mix in the Kyber shared secret (or all zeroes if Kyber is // disabled). We have to wait until this point because Kyber's keys are encrypted and can't be decrypted until // the P-384 exchange is done. We also flip the HMAC parameter order here for the same reason we do in the previous // key derivation step. let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes())); let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &reply_buf[HEADER_CHECK_SIZE..reply_len]); reply_buf[reply_len..(reply_len + HMAC_SIZE)].copy_from_slice(&hmac); reply_len += HMAC_SIZE; let mut state = session.state.write(); let _ = state.remote_session_id.replace(alice_session_id); add_key(&mut state.keys, SessionKey::new(key, Role::Bob, current_time, reply_counter, jedi)); drop(state); // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. send_with_fragmentation(send, &mut reply_buf[..reply_len], mtu, &session.header_check_cipher); if new_session.is_some() { return Ok(ReceiveResult::OkNewSession(new_session.unwrap())); } else { return Ok(ReceiveResult::Ok); } } PACKET_TYPE_KEY_COUNTER_OFFER => { // bob (remote) -> alice (local) if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) { return Err(Error::InvalidPacket); } let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE); let aes_gcm_tag_end = incoming_packet_len - HMAC_SIZE; if let Some(session) = session { let state = session.state.upgradable_read(); if let Some(offer) = state.offer.as_ref() { let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) .and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s))) .ok_or(Error::FailedAuthentication)?; let se0 = host.get_local_s_keypair_p384().agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?; let key = Secret(hmac_sha512( session.psk.as_bytes(), &hmac_sha512(&hmac_sha512(&hmac_sha512(offer.key.as_bytes(), bob_e0_public.as_bytes()), e0e0.as_bytes()), se0.as_bytes()), )); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), false); c.init(&incoming_packet[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); let c = c.finish(); if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) { return Err(Error::FailedAuthentication); } // Alice has now completed Noise_IK with NIST P-384 and verified with GCM auth, but now for hybrid... let (bob_session_id, _, _, bob_e1_public) = parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?; let e1e1 = if jedi && bob_e1_public.len() > 0 && offer.alice_e1_keypair.is_some() { if let Ok(e1e1) = pqc_kyber::decapsulate(bob_e1_public, &offer.alice_e1_keypair.as_ref().unwrap().secret) { Secret(e1e1) } else { return Err(Error::FailedAuthentication); } } else { Secret::default() }; let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes())); if !hmac_sha384( kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[HEADER_CHECK_SIZE..aes_gcm_tag_end], ) .eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()]) { return Err(Error::FailedAuthentication); } // Alice has now completed and validated the full hybrid exchange. let counter = session.send_counter.next(); let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi); let mut reply_buf = [0_u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; send_with_fragmentation_init_header(&mut reply_buf, HEADER_SIZE + AES_GCM_TAG_SIZE, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter); let mut c = key.get_send_cipher(counter)?; c.init(&reply_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); reply_buf[HEADER_SIZE..].copy_from_slice(&c.finish()); key.return_send_cipher(c); let hc = header_check(&reply_buf, &session.header_check_cipher); reply_buf[..HEADER_CHECK_SIZE].copy_from_slice(&hc.to_ne_bytes()); send(&mut reply_buf); let mut state = RwLockUpgradableReadGuard::upgrade(state); let _ = state.remote_session_id.replace(bob_session_id); let _ = state.offer.take(); add_key(&mut state.keys, key); return Ok(ReceiveResult::Ok); } } // Just ignore counter-offers that are out of place. They probably indicate that this side // restarted and needs to establish a new session. return Ok(ReceiveResult::Ignored); } _ => return Err(Error::InvalidPacket), } } } } struct Counter(AtomicU64); impl Counter { #[inline(always)] fn new() -> Self { Self(AtomicU64::new(random::next_u32_secure() as u64)) } #[inline(always)] fn current(&self) -> CounterValue { CounterValue(self.0.load(Ordering::SeqCst)) } #[inline(always)] fn next(&self) -> CounterValue { CounterValue(self.0.fetch_add(1, Ordering::SeqCst)) } } /// A value of the outgoing packet counter. /// /// The counter is internally 64-bit so we can more easily track usage limits without /// confusing logic to handle 32-bit wrapping. The least significant 32 bits are the /// actual counter put in the packet. #[repr(transparent)] #[derive(Copy, Clone)] struct CounterValue(u64); impl CounterValue { #[inline(always)] pub fn to_u32(&self) -> u32 { self.0 as u32 } } struct KeyLifetime { rekey_at_or_after_counter: u64, hard_expire_at_counter: u64, rekey_at_or_after_timestamp: i64, } impl KeyLifetime { fn new(current_counter: CounterValue, current_time: i64) -> Self { Self { rekey_at_or_after_counter: current_counter.0 + REKEY_AFTER_USES + (random::next_u32_secure() % REKEY_AFTER_USES_MAX_JITTER) as u64, hard_expire_at_counter: current_counter.0 + EXPIRE_AFTER_USES, rekey_at_or_after_timestamp: current_time + REKEY_AFTER_TIME_MS + (random::next_u32_secure() % REKEY_AFTER_TIME_MS_MAX_JITTER) as i64, } } #[inline(always)] fn should_rekey(&self, counter: CounterValue, current_time: i64) -> bool { counter.0 >= self.rekey_at_or_after_counter || current_time >= self.rekey_at_or_after_timestamp } #[inline(always)] fn expired(&self, counter: CounterValue) -> bool { counter.0 >= self.hard_expire_at_counter } } /// Ephemeral offer sent with KEY_OFFER and rememebered so state can be reconstructed on COUNTER_OFFER. struct EphemeralOffer { creation_time: i64, key: Secret<64>, alice_e0_keypair: P384KeyPair, alice_e1_keypair: Option, } impl EphemeralOffer { fn create_alice_offer( send: &mut SendFunction, counter: CounterValue, alice_session_id: SessionId, bob_session_id: Option, alice_s_public: &[u8], alice_metadata: &[u8], bob_s_public_p384: &P384PublicKey, bob_s_public_hash: &[u8], ss: &Secret<48>, header_check_cipher: &Aes, mtu: usize, current_time: i64, jedi: bool, ) -> Result { let alice_e0_keypair = P384KeyPair::generate(); let e0s = alice_e0_keypair.agree(bob_s_public_p384); if e0s.is_none() { return Err(Error::InvalidPacket); } let alice_e1_keypair = if jedi { Some(pqc_kyber::keypair(&mut random::SecureRandom::get())) } else { None }; const PACKET_BUF_SIZE: usize = MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS; let mut packet_buf = [0_u8; PACKET_BUF_SIZE]; let mut packet_len = { let mut p = &mut packet_buf[HEADER_SIZE..]; p.write_all(&[SESSION_PROTOCOL_VERSION])?; p.write_all(alice_e0_keypair.public_key_bytes())?; p.write_all(&alice_session_id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?; varint::write(&mut p, alice_s_public.len() as u64)?; p.write_all(alice_s_public)?; varint::write(&mut p, alice_metadata.len() as u64)?; p.write_all(alice_metadata)?; if let Some(e1kp) = alice_e1_keypair { p.write_all(&[E1_TYPE_KYBER1024])?; p.write_all(&e1kp.public)?; } else { p.write_all(&[E1_TYPE_NONE])?; } PACKET_BUF_SIZE - p.len() }; send_with_fragmentation_init_header(&mut packet_buf, packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter); let key = Secret(hmac_sha512( &hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()), e0s.unwrap().as_bytes(), )); let gcm_tag = { let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true); c.init(&packet_buf[AES_GCM_NONCE_START..AES_GCM_NONCE_END]); c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]); c.finish() }; packet_buf[packet_len..(packet_len + AES_GCM_TAG_SIZE)].copy_from_slice(&gcm_tag); packet_len += AES_GCM_TAG_SIZE; let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes())); let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &packet_buf[HEADER_CHECK_SIZE..packet_len]); packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); packet_len += HMAC_SIZE; let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[HEADER_CHECK_SIZE..packet_len]); packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); packet_len += HMAC_SIZE; send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, header_check_cipher); Ok(EphemeralOffer { creation_time: current_time, key, alice_e0_keypair, alice_e1_keypair, }) } } #[inline(always)] fn send_with_fragmentation_init_header(header: &mut [u8], packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) { let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize; debug_assert!(mtu >= MIN_MTU); debug_assert!(packet_len >= HEADER_SIZE); debug_assert!(fragment_count <= MAX_FRAGMENTS); debug_assert!(fragment_count > 0); debug_assert!(packet_type <= 0x0f); // packet type is 4 bits debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits header[HEADER_CHECK_SIZE..12].copy_from_slice(&(recipient_session_id | (packet_type as u64).wrapping_shl(48) | ((fragment_count - 1) as u64).wrapping_shl(52)).to_le_bytes()); header[12..HEADER_SIZE].copy_from_slice(&counter.to_u32().to_le_bytes()); } /// Send a packet in fragments (used for everything but DATA which has a hand-rolled version for performance). fn send_with_fragmentation(send: &mut SendFunction, packet: &mut [u8], mtu: usize, header_check_cipher: &Aes) { let packet_len = packet.len(); let mut fragment_start = 0; let mut fragment_end = packet_len.min(mtu); let mut header: [u8; HEADER_SIZE - HEADER_CHECK_SIZE] = packet[HEADER_CHECK_SIZE..HEADER_SIZE].try_into().unwrap(); loop { let hc = header_check(&packet[fragment_start..], header_check_cipher); packet[fragment_start..(fragment_start + HEADER_CHECK_SIZE)].copy_from_slice(&hc.to_ne_bytes()); send(&mut packet[fragment_start..fragment_end]); if fragment_end < packet_len { debug_assert!(header[7].wrapping_shr(2) < 63); header[7] += 0x04; // increment fragment number fragment_start = fragment_end - HEADER_SIZE; fragment_end = (fragment_start + mtu).min(packet_len); packet[(fragment_start + HEADER_CHECK_SIZE)..(fragment_start + HEADER_SIZE)].copy_from_slice(&header); } else { debug_assert_eq!(fragment_end, packet_len); break; } } } /// Compute the 32-bit header check code for a packet on receipt or right before send. fn header_check(packet: &[u8], header_check_cipher: &Aes) -> u32 { debug_assert!(packet.len() >= HEADER_SIZE); let mut header_check = 0u128.to_ne_bytes(); if packet.len() >= (16 + HEADER_CHECK_SIZE) { header_check_cipher.encrypt_block(&packet[HEADER_CHECK_SIZE..(16 + HEADER_CHECK_SIZE)], &mut header_check); } else { unlikely_branch(); header_check[..(packet.len() - HEADER_CHECK_SIZE)].copy_from_slice(&packet[HEADER_CHECK_SIZE..]); header_check_cipher.encrypt_block_in_place(&mut header_check); } memory::u32_from_ne_bytes(&header_check) } /// Add a new session key to the key list, retiring older non-active keys if necessary. fn add_key(keys: &mut LinkedList, key: SessionKey) { debug_assert!(KEY_HISTORY_SIZE_MAX >= 2); while keys.len() >= KEY_HISTORY_SIZE_MAX { let current = keys.pop_front().unwrap(); let _ = keys.pop_front(); keys.push_front(current); } keys.push_back(key); } fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Result<(SessionId, &[u8], &[u8], &[u8]), Error> { let mut p = &incoming_packet[..]; let alice_session_id = SessionId::new_from_reader(&mut p)?; if alice_session_id.is_none() { return Err(Error::InvalidPacket); } let alice_session_id = alice_session_id.unwrap(); let alice_s_public_len = varint::read(&mut p)?.0; if (p.len() as u64) < alice_s_public_len { return Err(Error::InvalidPacket); } let alice_s_public = &p[..(alice_s_public_len as usize)]; p = &p[(alice_s_public_len as usize)..]; let alice_metadata_len = varint::read(&mut p)?.0; if (p.len() as u64) < alice_metadata_len { return Err(Error::InvalidPacket); } let alice_metadata = &p[..(alice_metadata_len as usize)]; p = &p[(alice_metadata_len as usize)..]; if p.is_empty() { return Err(Error::InvalidPacket); } let alice_e1_public = match p[0] { E1_TYPE_KYBER1024 => { if packet_type == PACKET_TYPE_KEY_OFFER { if p.len() < (pqc_kyber::KYBER_PUBLICKEYBYTES + 1) { return Err(Error::InvalidPacket); } &p[1..(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)] } else { if p.len() < (pqc_kyber::KYBER_CIPHERTEXTBYTES + 1) { return Err(Error::InvalidPacket); } &p[1..(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)] } } _ => &[], }; Ok((alice_session_id, alice_s_public, alice_metadata, alice_e1_public)) } enum Role { Alice, Bob, } #[allow(unused)] struct SessionKey { lifetime: KeyLifetime, receive_key: Secret<32>, send_key: Secret<32>, receive_cipher_pool: Mutex>>, send_cipher_pool: Mutex>>, role: Role, jedi: bool, // true if kyber was used } impl SessionKey { /// Create a new symmetric shared session key and set its key expiration times, etc. fn new(key: Secret<64>, role: Role, current_time: i64, current_counter: CounterValue, jedi: bool) -> Self { let a2b: Secret<32> = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n_clone(); let b2a: Secret<32> = kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n_clone(); let (receive_key, send_key) = match role { Role::Alice => (b2a, a2b), Role::Bob => (a2b, b2a), }; Self { lifetime: KeyLifetime::new(current_counter, current_time), receive_key, send_key, receive_cipher_pool: Mutex::new(Vec::with_capacity(2)), send_cipher_pool: Mutex::new(Vec::with_capacity(2)), role, jedi, } } #[inline(always)] fn get_send_cipher(&self, counter: CounterValue) -> Result, Error> { if !self.lifetime.expired(counter) { Ok(self.send_cipher_pool.lock().pop().unwrap_or_else(|| Box::new(AesGcm::new(self.send_key.as_bytes(), true)))) } else { Err(Error::MaxKeyLifetimeExceeded) } } #[inline(always)] fn return_send_cipher(&self, c: Box) { self.send_cipher_pool.lock().push(c); } #[inline(always)] fn get_receive_cipher(&self) -> Box { self.receive_cipher_pool.lock().pop().unwrap_or_else(|| Box::new(AesGcm::new(self.receive_key.as_bytes(), false))) } #[inline(always)] fn return_receive_cipher(&self, c: Box) { self.receive_cipher_pool.lock().push(c); } } /// HMAC-SHA512 key derivation function modeled on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 12) fn kbkdf512(key: &[u8], label: u8) -> Secret<64> { Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00])) } #[cfg(test)] mod tests { use parking_lot::Mutex; use std::collections::LinkedList; use std::sync::Arc; #[allow(unused_imports)] use super::*; struct TestHost { local_s: P384KeyPair, local_s_hash: [u8; 48], psk: Secret<64>, session: Mutex>>>>, session_id_counter: Mutex, pub queue: Mutex>>, pub this_name: &'static str, pub other_name: &'static str, } impl TestHost { fn new(psk: Secret<64>, this_name: &'static str, other_name: &'static str) -> Self { let local_s = P384KeyPair::generate(); let local_s_hash = SHA384::hash(local_s.public_key_bytes()); Self { local_s, local_s_hash, psk, session: Mutex::new(None), session_id_counter: Mutex::new(random::next_u64_secure().wrapping_shr(16) | 1), queue: Mutex::new(LinkedList::new()), this_name, other_name, } } } impl Host for Box { type AssociatedObject = u32; type SessionRef = Arc>>; type IncomingPacketBuffer = Vec; fn get_local_s_public(&self) -> &[u8] { self.local_s.public_key_bytes() } fn get_local_s_public_hash(&self) -> &[u8; 48] { &self.local_s_hash } fn get_local_s_keypair_p384(&self) -> &P384KeyPair { &self.local_s } fn extract_p384_static(static_public: &[u8]) -> Option { P384PublicKey::from_bytes(static_public) } fn session_lookup(&self, local_session_id: SessionId) -> Option { self.session.lock().as_ref().and_then(|s| { if s.id == local_session_id { Some(s.clone()) } else { None } }) } fn accept_new_session(&self, _: &[u8], _: &[u8]) -> Option<(SessionId, Secret<64>, Self::AssociatedObject)> { loop { let mut new_id = self.session_id_counter.lock(); *new_id += 1; return Some((SessionId::new_from_u64(*new_id).unwrap(), self.psk.clone(), 0)); } } } #[allow(unused_variables)] #[test] fn establish_session() { let jedi = true; let mut psk: Secret<64> = Secret::default(); random::fill_bytes_secure(&mut psk.0); let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob")); let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice")); let alice_rc: Box>> = Box::new(ReceiveContext::new(&alice_host)); let bob_rc: Box>> = Box::new(ReceiveContext::new(&bob_host)); let mut data_buf = [0_u8; 4096]; //println!("zssp: size of session (bytes): {}", std::mem::size_of::>>()); let _ = alice_host.session.lock().insert(Arc::new( Session::new( &alice_host, |data| bob_host.queue.lock().push_front(data.to_vec()), SessionId::new_random(), bob_host.local_s.public_key_bytes(), &[], &psk, 1, 1280, 1, jedi, ) .unwrap(), )); let mut ts = 0; for _ in 0..256 { for host in [&alice_host, &bob_host] { let send_to_other = |data: &mut [u8]| { if std::ptr::eq(host, &alice_host) { bob_host.queue.lock().push_front(data.to_vec()); } else { alice_host.queue.lock().push_front(data.to_vec()); } }; let rc = if std::ptr::eq(host, &alice_host) { &alice_rc } else { &bob_rc }; loop { if let Some(qi) = host.queue.lock().pop_back() { let qi_len = qi.len(); ts += 1; let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, jedi, ts); if r.is_ok() { let r = r.unwrap(); match r { ReceiveResult::Ok => { println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len); } ReceiveResult::OkData(data) => { println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len()); } ReceiveResult::OkNewSession(new_session) => { println!("zssp: {} => {} ({}): OkNewSession ({})", host.other_name, host.this_name, qi_len, u64::from(new_session.id)); let mut hs = host.session.lock(); assert!(hs.is_none()); let _ = hs.insert(Arc::new(new_session)); } ReceiveResult::Ignored => { println!("zssp: {} => {} ({}): Ignored", host.other_name, host.this_name, qi_len); } } } else { println!("zssp: {} => {} ({}): error: {}", host.other_name, host.this_name, qi_len, r.err().unwrap().to_string()); panic!(); } } else { break; } } } } } }