diff --git a/Cargo.toml b/Cargo.toml index f7bb47033..762ea2844 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "network-hypervisor", "controller", "system-service", + "utils", ] [profile.release] diff --git a/core-crypto/Cargo.toml b/core-crypto/Cargo.toml index a6f977fe0..f04a36ccd 100644 --- a/core-crypto/Cargo.toml +++ b/core-crypto/Cargo.toml @@ -6,6 +6,7 @@ name = "zerotier-core-crypto" version = "0.1.0" [dependencies] +zerotier-utils = { path = "../utils" } ed25519-dalek = {version = "1.0.1", features = ["std", "u64_backend"], default-features = false} foreign-types = "0.3.1" lazy_static = "^1" diff --git a/core-crypto/src/lib.rs b/core-crypto/src/lib.rs index d9000403b..dcc628836 100644 --- a/core-crypto/src/lib.rs +++ b/core-crypto/src/lib.rs @@ -3,14 +3,12 @@ pub mod aes; pub mod aes_gmac_siv; pub mod hash; -pub mod hex; pub mod kbkdf; pub mod p384; pub mod poly1305; pub mod random; pub mod salsa; pub mod secret; -pub mod varint; pub mod x25519; pub mod zssp; diff --git a/core-crypto/src/salsa.rs b/core-crypto/src/salsa.rs index 5e5368879..679c4c69a 100644 --- a/core-crypto/src/salsa.rs +++ b/core-crypto/src/salsa.rs @@ -215,14 +215,12 @@ mod tests { use crate::salsa::*; const SALSA_20_TV0_KEY: [u8; 32] = [ - 0x0f, 0x62, 0xb5, 0x08, 0x5b, 0xae, 0x01, 0x54, 0xa7, 0xfa, 0x4d, 0xa0, 0xf3, 0x46, 0x99, 0xec, 0x3f, 0x92, 0xe5, 0x38, 0x8b, 0xde, 0x31, 0x84, 0xd7, 0x2a, 0x7d, 0xd0, - 0x23, 0x76, 0xc9, 0x1c, + 0x0f, 0x62, 0xb5, 0x08, 0x5b, 0xae, 0x01, 0x54, 0xa7, 0xfa, 0x4d, 0xa0, 0xf3, 0x46, 0x99, 0xec, 0x3f, 0x92, 0xe5, 0x38, 0x8b, 0xde, 0x31, 0x84, 0xd7, 0x2a, 0x7d, 0xd0, 0x23, 0x76, 0xc9, 0x1c, ]; const SALSA_20_TV0_IV: [u8; 8] = [0x28, 0x8f, 0xf6, 0x5d, 0xc4, 0x2b, 0x92, 0xf9]; const SALSA_20_TV0_KS: [u8; 64] = [ - 0x5e, 0x5e, 0x71, 0xf9, 0x01, 0x99, 0x34, 0x03, 0x04, 0xab, 0xb2, 0x2a, 0x37, 0xb6, 0x62, 0x5b, 0xf8, 0x83, 0xfb, 0x89, 0xce, 0x3b, 0x21, 0xf5, 0x4a, 0x10, 0xb8, 0x10, - 0x66, 0xef, 0x87, 0xda, 0x30, 0xb7, 0x76, 0x99, 0xaa, 0x73, 0x79, 0xda, 0x59, 0x5c, 0x77, 0xdd, 0x59, 0x54, 0x2d, 0xa2, 0x08, 0xe5, 0x95, 0x4f, 0x89, 0xe4, 0x0e, 0xb7, - 0xaa, 0x80, 0xa8, 0x4a, 0x61, 0x76, 0x66, 0x3f, + 0x5e, 0x5e, 0x71, 0xf9, 0x01, 0x99, 0x34, 0x03, 0x04, 0xab, 0xb2, 0x2a, 0x37, 0xb6, 0x62, 0x5b, 0xf8, 0x83, 0xfb, 0x89, 0xce, 0x3b, 0x21, 0xf5, 0x4a, 0x10, 0xb8, 0x10, 0x66, 0xef, 0x87, 0xda, + 0x30, 0xb7, 0x76, 0x99, 0xaa, 0x73, 0x79, 0xda, 0x59, 0x5c, 0x77, 0xdd, 0x59, 0x54, 0x2d, 0xa2, 0x08, 0xe5, 0x95, 0x4f, 0x89, 0xe4, 0x0e, 0xb7, 0xaa, 0x80, 0xa8, 0x4a, 0x61, 0x76, 0x66, 0x3f, ]; #[test] diff --git a/core-crypto/src/zssp.rs b/core-crypto/src/zssp.rs index 2364db06e..54eb4cac7 100644 --- a/core-crypto/src/zssp.rs +++ b/core-crypto/src/zssp.rs @@ -1,21 +1,24 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. -use std::io::Write; +use std::io::{Read, Write}; use std::num::NonZeroU64; use std::ops::Deref; use std::sync::atomic::{AtomicU64, Ordering}; -use crate::aes::{Aes, AesGcm}; -use crate::hash::{hmac_sha384, hmac_sha512, SHA384, SHA512}; +use crate::aes::AesGcm; +use crate::hash::{hmac_sha384, hmac_sha512, SHA384}; use crate::p384::{P384KeyPair, P384PublicKey, P384_PUBLIC_KEY_SIZE}; use crate::random; use crate::secret::Secret; -use crate::varint; + +use zerotier_utils::gatherarray::GatherArray; +use zerotier_utils::ringbuffermap::RingBufferMap; +use zerotier_utils::varint; use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard}; -/// Minimum possible packet size. Packets smaller than this are rejected. -pub const MIN_PACKET_SIZE: usize = HEADER_SIZE + 1 + AES_GCM_TAG_SIZE; +pub const MIN_PACKET_SIZE: usize = HEADER_SIZE; +pub const MIN_MTU: usize = 1280; /// Start attempting to rekey after a key has been used to send packets this many times. const REKEY_AFTER_USES: u64 = 536870912; @@ -35,23 +38,22 @@ const REKEY_AFTER_TIME_MS_MAX_JITTER: u32 = 1000 * 60 * 5; /// Rate limit for sending new offers to attempt to re-key. const OFFER_RATE_LIMIT_MS: i64 = 2000; -/// Version 1: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication) -const SESSION_PROTOCOL_VERSION: u8 = 1; +/// Version 0: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication) +const SESSION_PROTOCOL_VERSION: u8 = 0x00; const PACKET_TYPE_DATA: u8 = 0; const PACKET_TYPE_NOP: u8 = 1; const PACKET_TYPE_KEY_OFFER: u8 = 2; // "alice" const PACKET_TYPE_KEY_COUNTER_OFFER: u8 = 3; // "bob" -const GET_PACKET_TYPE_BIT_MASK: u8 = 0x1f; -const GET_PROTOCOL_VERSION_SHIFT_RIGHT: u32 = 5; - const E1_TYPE_NONE: u8 = 0; const E1_TYPE_KYBER1024: u8 = 1; -const HEADER_SIZE: usize = 11; +const MAX_FRAGMENTS: usize = 48; // protocol max: 63 +const KEY_EXCHANGE_MAX_FRAGMENTS: usize = 2; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc. +const HEADER_SIZE: usize = 12; const AES_GCM_TAG_SIZE: usize = 16; -const HMAC_SIZE: usize = 48; // HMAC-SHA384 +const HMAC_SIZE: usize = 48; const SESSION_ID_SIZE: usize = 6; /// Aribitrary starting value for key derivation chain. @@ -94,7 +96,10 @@ pub enum Error { RateLimited, /// Other end sent a protocol version we don't support. - UnknownProtocolVersion(u8), + UnknownProtocolVersion, + + /// Supplied data buffer is too small to receive data. + DataBufferTooSmall, /// An internal error occurred. OtherError(Box), @@ -108,6 +113,10 @@ impl From for Error { } } +#[cold] +#[inline(never)] +extern "C" fn unlikely_branch() {} + impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -119,7 +128,8 @@ impl std::fmt::Display for Error { Self::MaxKeyLifetimeExceeded => f.write_str("MaxKeyLifetimeExceeded"), Self::SessionNotEstablished => f.write_str("SessionNotEstablished"), Self::RateLimited => f.write_str("RateLimited"), - Self::UnknownProtocolVersion(v) => f.write_str(format!("UnknownProtocolVersion({})", v).as_str()), + Self::UnknownProtocolVersion => f.write_str("UnknownProtocolVersion"), + Self::DataBufferTooSmall => f.write_str("DataBufferTooSmall"), Self::OtherError(e) => f.write_str(format!("OtherError({})", e.to_string()).as_str()), } } @@ -133,29 +143,16 @@ impl std::fmt::Debug for Error { } } -/// Obfuscator/deobfuscator for recipient privacy masking on the wire. -pub struct Obfuscator(Aes); - -impl Obfuscator { - /// Create a new obfuscator for sending packets TO the provided static public identity. - pub fn new(recipient_static_public: &[u8]) -> Self { - Self(Aes::new(&SHA512::hash(recipient_static_public)[..32])) - } -} - -pub enum ReceiveResult<'a, H: SessionHost> { - /// Packet is valid and contained a data payload. - OkData(&'a [u8], u32), - - /// Packet is valid and the provided reply should be sent back. - OkSendReply(&'a [u8]), - - /// Packet is valid and a new session was created, also includes a reply to be sent back. - OkNewSession(Session, &'a [u8]), - +pub enum ReceiveResult<'a, H: Host> { /// Packet is valid, no action needs to be taken. Ok, + /// Packet is valid and contained a data payload. + OkData(&'a [u8]), + + /// Packet is valid and a new session was created, also includes a reply to be sent back. + OkNewSession(Session), + /// Packet appears valid but was ignored as a duplicate. Duplicate, @@ -170,7 +167,6 @@ pub struct SessionId(NonZeroU64); impl SessionId { pub const MAX_BIT_MASK: u64 = 0xffffffffffff; - #[inline(always)] pub fn new_from_bytes(b: &[u8]) -> Option { if b.len() >= 6 { let value = (u32::from_le_bytes(b[..4].try_into().unwrap()) as u64) | (u16::from_le_bytes(b[4..6].try_into().unwrap()) as u64).wrapping_shl(32); @@ -181,7 +177,12 @@ impl SessionId { return None; } - #[inline(always)] + pub fn new_from_reader(r: &mut R) -> std::io::Result> { + let mut tmp = [0_u8; SESSION_ID_SIZE]; + r.read_exact(&mut tmp)?; + Ok(Self::new_from_bytes(&tmp)) + } + pub fn new_random() -> Self { Self(NonZeroU64::new((random::next_u64_secure() & Self::MAX_BIT_MASK).max(1)).unwrap()) } @@ -212,20 +213,20 @@ impl From for u64 { } } -pub trait SessionHost: Sized { - type Buffer: AsRef<[u8]> + AsMut<[u8]> + Write; +pub trait Host: Sized { type AssociatedObject: Sized; type SessionRef: Deref>; + type IncomingPacketBuffer: AsRef<[u8]>; /// Get a reference to this host's static public key blob. fn get_local_s_public(&self) -> &[u8]; + /// Get SHA384(this host's static public key blob) + fn get_local_s_public_hash(&self) -> &[u8; 48]; + /// Get a reference to this hosts' static public key's NIST P-384 secret key pair fn get_local_s_keypair_p384(&self) -> &P384KeyPair; - /// Get an empty writable buffer, or None if none are available and the operation in progress should fail. - fn get_buffer(&self) -> Option; - /// Extract a NIST P-384 ECC public key from a static public key blob. fn extract_p384_static(static_public: &[u8]) -> Option; @@ -240,17 +241,17 @@ pub trait SessionHost: Sized { fn accept_new_session(&self, remote_static_public: &[u8], remote_metadata: &[u8]) -> Option<(SessionId, Secret<64>, Self::AssociatedObject)>; } -pub struct Session { +pub struct Session { pub id: SessionId, pub associated_object: H::AssociatedObject, send_counter: Counter, - remote_s_public_hash: [u8; 48], psk: Secret<64>, ss: Secret<48>, - outgoing_obfuscator: Obfuscator, state: RwLock, + remote_s_public_hash: [u8; 48], remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], + defrag: Mutex, 64, 32>>, } struct MutableState { @@ -259,8 +260,13 @@ struct MutableState { offer: Option, } -impl Session { - pub fn new( +/// State information to associate with receiving contexts such as sockets or remote paths/endpoints. +pub struct ReceiveContext { + initial_offer_defrag: Mutex, 1024, 256>>, +} + +impl Session { + pub fn new( host: &H, send: SendFunction, local_session_id: SessionId, @@ -274,8 +280,8 @@ impl Session { ) -> Result { if let Some(remote_s_public_p384) = H::extract_p384_static(remote_s_public) { if let Some(ss) = host.get_local_s_keypair_p384().agree(&remote_s_public_p384) { - let outgoing_obfuscator = Obfuscator::new(remote_s_public); let counter = Counter::new(); + let remote_s_public_hash = SHA384::hash(remote_s_public); if let Ok(offer) = EphemeralOffer::create_alice_offer( send, counter.next(), @@ -284,8 +290,8 @@ impl Session { host.get_local_s_public(), offer_metadata, &remote_s_public_p384, + &remote_s_public_hash, &ss, - &outgoing_obfuscator, mtu, current_time, jedi, @@ -294,16 +300,16 @@ impl Session { id: local_session_id, associated_object, send_counter: counter, - remote_s_public_hash: SHA384::hash(remote_s_public), psk: psk.clone(), ss, - outgoing_obfuscator, state: RwLock::new(MutableState { remote_session_id: None, keys: [None, None], offer: Some(offer), }), + remote_s_public_hash, remote_s_public_p384: remote_s_public_p384.as_bytes().clone(), + defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), }); } } @@ -311,7 +317,7 @@ impl Session { return Err(Error::InvalidParameter); } - pub fn rekey_check(&self, host: &H, send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) { + pub fn rekey_check(&self, host: &H, send: SendFunction, offer_metadata: &[u8], mtu: usize, current_time: i64, force: bool, jedi: bool) { let state = self.state.upgradable_read(); if let Some(key) = state.keys[0].as_ref() { if force || (key.lifetime.should_rekey(self.send_counter.current(), current_time) && state.offer.as_ref().map_or(true, |o| (current_time - o.creation_time) > OFFER_RATE_LIMIT_MS)) { @@ -324,8 +330,8 @@ impl Session { host.get_local_s_public(), offer_metadata, &remote_s_public_p384, + &self.remote_s_public_hash, &self.ss, - &self.outgoing_obfuscator, mtu, current_time, jedi, @@ -336,87 +342,100 @@ impl Session { } } } +} - /* - debug_assert!(packet_type == PACKET_TYPE_DATA || packet_type == PACKET_TYPE_NOP); - buffer[0] = packet_type; - buffer[1..7].copy_from_slice(&remote_session_id.to_le_bytes()[..SESSION_ID_SIZE]); - buffer[7..11].copy_from_slice(&counter.to_bytes()); - - let payload_end = HEADER_SIZE + data.len(); - let tag_end = payload_end + AES_GCM_TAG_SIZE; - if tag_end < MAX_PACKET_SIZE { - let mut c = key.get_send_cipher(counter)?; - buffer[11..16].fill(0); - c.init(&buffer[..16]); - c.crypt(data, &mut buffer[HEADER_SIZE..payload_end]); - buffer[payload_end..tag_end].copy_from_slice(&c.finish()); - key.return_send_cipher(c); - - outgoing_obfuscator.0.encrypt_block_in_place(&mut buffer[..16]); - - Ok(tag_end) - } else { - unlikely_branch(); - Err(Error::InvalidParameter) - } - */ - - /* - pub fn send(&self, data: &[u8]) -> Result { - let state = self.state.read(); - if let Some(key) = state.keys[0].as_ref() { - if let Some(remote_session_id) = state.remote_session_id { - //let data_len = assemble_and_armor_DATA(buffer, data, PACKET_TYPE_DATA, u64::from(remote_session_id), self.send_counter.next(), &key, &self.outgoing_obfuscator)?; - Ok(&buffer[..data_len]) - } else { - unlikely_branch(); - Err(Error::SessionNotEstablished) - } - } else { - unlikely_branch(); - Err(Error::SessionNotEstablished) +impl ReceiveContext { + pub fn new() -> Self { + Self { + initial_offer_defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), } } - */ - /* - /// Receive a packet from the network and take the appropriate action. - /// - /// Check ReceiveResult to see if it includes data or a reply packet. - pub fn receive< - 'a, - ExtractP384PublicKeyFunction: FnOnce(&[u8; STATIC_PUBLIC_SIZE]) -> Option, - SessionLookupFunction: FnOnce(SessionId) -> Option, - NewSessionAuthenticatorFunction: FnOnce(&[u8; STATIC_PUBLIC_SIZE]) -> Option<(SessionId, Secret<64>, O)>, - S: std::ops::Deref>, - const MAX_PACKET_SIZE: usize, - const STATIC_PUBLIC_SIZE: usize, - >( - incoming_packet: &[u8], - buffer: &'a mut [u8; MAX_PACKET_SIZE], - local_s_keypair_p384: &P384KeyPair, - incoming_obfuscator: &Obfuscator, - extract_p384_static_public: ExtractP384PublicKeyFunction, - session_lookup: SessionLookupFunction, - new_session_auth: NewSessionAuthenticatorFunction, - current_time: i64, + pub fn receive<'a, SendFunction: FnMut(&mut [u8])>( + &self, + host: &H, + send: SendFunction, + data_buf: &'a mut [u8], + incoming_packet_buf: H::IncomingPacketBuffer, + mtu: usize, jedi: bool, - ) -> Result, Error> { - debug_assert!(MAX_PACKET_SIZE >= (64 + STATIC_PUBLIC_SIZE + P384_PUBLIC_KEY_SIZE + pqc_kyber::KYBER_PUBLICKEYBYTES)); + current_time: i64, + ) -> Result, Error> { + let incoming_packet = incoming_packet_buf.as_ref(); - if incoming_packet.len() > MAX_PACKET_SIZE || incoming_packet.len() <= MIN_PACKET_SIZE { + if incoming_packet.len() < MIN_PACKET_SIZE { unlikely_branch(); return Err(Error::InvalidPacket); } - incoming_obfuscator.0.decrypt_block(&incoming_packet[..16], &mut buffer[..16]); - let mut packet_type = buffer[0]; - let continued = (packet_type & PACKET_FLAG_CONTINUED) != 0; - packet_type &= PACKET_TYPE_MASK; + let type_and_frag_info = u16::from_le_bytes(incoming_packet[0..2].try_into().unwrap()); + let local_session_id = SessionId::new_from_bytes(&incoming_packet[2..8]); + let counter = u32::from_le_bytes(incoming_packet[8..12].try_into().unwrap()); + let packet_type = (type_and_frag_info as u8) & 15; + let fragment_count = type_and_frag_info.wrapping_shr(4) & 63; + let fragment_no = type_and_frag_info.wrapping_shr(10); // & 63 not needed - let local_session_id = SessionId::new_from_bytes(&buffer[1..7]); - let session = local_session_id.and_then(|sid| session_lookup(sid)); + if fragment_count > 1 { + if let Some(local_session_id) = local_session_id { + if fragment_count < (MAX_FRAGMENTS as u16) && fragment_no < fragment_count { + if let Some(session) = host.session_lookup(local_session_id) { + let mut defrag = session.defrag.lock(); + let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32)); + if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) { + drop(defrag); // release lock + return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time); + } + } else { + unlikely_branch(); + return Err(Error::UnknownLocalSessionId(local_session_id)); + } + } else { + unlikely_branch(); + return Err(Error::InvalidPacket); + } + } else { + if fragment_count < (KEY_EXCHANGE_MAX_FRAGMENTS as u16) && fragment_no < fragment_count { + let mut defrag = self.initial_offer_defrag.lock(); + let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32)); + if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) { + drop(defrag); // release lock + return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time); + } + } else { + unlikely_branch(); + return Err(Error::InvalidPacket); + } + } + } else { + return self.receive_complete( + host, + send, + data_buf, + &[incoming_packet_buf], + packet_type, + local_session_id.and_then(|lsid| host.session_lookup(lsid)), + mtu, + jedi, + current_time, + ); + } + + return Ok(ReceiveResult::Ok); + } + + fn receive_complete<'a, SendFunction: FnMut(&mut [u8])>( + &self, + host: &H, + mut send: SendFunction, + data_buf: &'a mut [u8], + fragments: &[H::IncomingPacketBuffer], + packet_type: u8, + session: Option, + mtu: usize, + jedi: bool, + current_time: i64, + ) -> Result, Error> { + debug_assert!(fragments.len() >= 1); debug_assert_eq!(PACKET_TYPE_DATA, 0); debug_assert_eq!(PACKET_TYPE_NOP, 1); @@ -425,16 +444,53 @@ impl Session { let state = session.state.read(); for ki in 0..2 { if let Some(key) = state.keys[ki].as_ref() { - let nonce = get_aes_gcm_nonce(buffer); + let head = fragments.first().unwrap().as_ref(); + debug_assert!(head.len() >= MIN_PACKET_SIZE); + let mut c = key.get_receive_cipher(); - c.init(&nonce); - c.crypt_in_place(&mut buffer[HEADER_SIZE..16]); - let data_len = incoming_packet.len() - AES_GCM_TAG_SIZE; - c.crypt(&incoming_packet[16..data_len], &mut buffer[16..data_len]); + c.init(&get_aes_gcm_nonce(head)); + + let mut data_len = head.len() - HEADER_SIZE; + if data_len > data_buf.len() { + unlikely_branch(); + key.return_receive_cipher(c); + return Err(Error::DataBufferTooSmall); + } + c.crypt(&head[HEADER_SIZE..], &mut data_buf[..data_len]); + + for fi in 1..(fragments.len() - 1) { + let f = fragments[fi].as_ref(); + debug_assert!(f.len() >= HEADER_SIZE); + let current_frag_data_start = data_len; + data_len += f.len() - HEADER_SIZE; + if data_len > data_buf.len() { + unlikely_branch(); + key.return_receive_cipher(c); + return Err(Error::DataBufferTooSmall); + } + c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); + } + + let tail = fragments.last().unwrap().as_ref(); + if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) { + unlikely_branch(); + key.return_receive_cipher(c); + return Err(Error::InvalidPacket); + } + let tail_data_len = tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE); + let current_frag_data_start = data_len; + data_len += tail_data_len; + if data_len > data_buf.len() { + unlikely_branch(); + key.return_receive_cipher(c); + return Err(Error::DataBufferTooSmall); + } + c.crypt(&tail[HEADER_SIZE..tail_data_len], &mut data_buf[current_frag_data_start..data_len]); + let tag = c.finish(); key.return_receive_cipher(c); - if tag.eq(&incoming_packet[data_len..]) { + if tag.eq(&tail[(tail.len() - AES_GCM_TAG_SIZE)..]) { if ki == 1 { // Promote next key to current key on success. unlikely_branch(); @@ -444,7 +500,7 @@ impl Session { } if packet_type == PACKET_TYPE_DATA { - return Ok(ReceiveResult::OkData(&buffer[HEADER_SIZE..data_len], u32::from_le_bytes(nonce[7..11].try_into().unwrap()))); + return Ok(ReceiveResult::OkData(&data_buf[..data_len])); } else { return Ok(ReceiveResult::Ok); } @@ -454,38 +510,51 @@ impl Session { return Err(Error::FailedAuthentication); } else { unlikely_branch(); - if let Some(local_session_id) = local_session_id { - return Err(Error::UnknownLocalSessionId(local_session_id)); - } else { - return Err(Error::InvalidPacket); - } + return Err(Error::SessionNotEstablished); } } else { unlikely_branch(); - if local_session_id.is_some() && session.is_none() { - return Err(Error::UnknownLocalSessionId(local_session_id.unwrap())); - } - - if incoming_packet.len() > (HEADER_SIZE + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) { - for i in (16..64).step_by(16) { - let j = i + 16; - incoming_obfuscator.0.decrypt_block(&incoming_packet[i..j], &mut buffer[i..j]); - for k in i..j { - buffer[k] ^= incoming_packet[k - 16]; - } + let mut incoming_packet_buf = [0_u8; MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS]; + let mut incoming_packet_len = 0; + for i in 0..fragments.len() { + let mut ff = fragments[i].as_ref(); + debug_assert!(ff.len() >= MIN_PACKET_SIZE); + if i > 0 { + ff = &ff[HEADER_SIZE..]; } - buffer[64..incoming_packet.len()].copy_from_slice(&incoming_packet[64..]); - } else { + let j = incoming_packet_len + ff.len(); + if j > incoming_packet_buf.len() { + return Err(Error::InvalidPacket); + } + incoming_packet_buf[incoming_packet_len..j].copy_from_slice(ff); + incoming_packet_len = j; + } + let original_ciphertext = incoming_packet_buf.clone(); + let incoming_packet = &mut incoming_packet_buf[..incoming_packet_len]; + + if incoming_packet_len < (HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE + AES_GCM_TAG_SIZE + HMAC_SIZE) { return Err(Error::InvalidPacket); } - let payload_end = incoming_packet.len() - (AES_GCM_TAG_SIZE + HMAC_SIZE); - let aes_gcm_tag_end = incoming_packet.len() - HMAC_SIZE; + if incoming_packet[HEADER_SIZE] != SESSION_PROTOCOL_VERSION { + return Err(Error::UnknownProtocolVersion); + } + + let local_s_keypair_p384 = host.get_local_s_keypair_p384(); match packet_type { PACKET_TYPE_KEY_OFFER => { // alice (remote) -> bob (local) + let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE + HMAC_SIZE); + let aes_gcm_tag_end = incoming_packet_len - (HMAC_SIZE + HMAC_SIZE); + let hmac1_end = incoming_packet_len - HMAC_SIZE; + + // Check that the sender knows this host's identity before doing anything else. + if !hmac_sha384(host.get_local_s_public_hash(), &incoming_packet[..hmac1_end]).eq(&incoming_packet[hmac1_end..]) { + return Err(Error::FailedAuthentication); + } + // Check rate limit if this session is known. if let Some(session) = session.as_ref() { if let Some(offer) = session.state.read().offer.as_ref() { @@ -495,36 +564,35 @@ impl Session { } } - let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&buffer[HEADER_SIZE..HEADER_SIZE + P384_PUBLIC_KEY_SIZE]) + let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) .and_then(|pk| local_s_keypair_p384.agree(&pk).map(move |s| (pk, s))) .ok_or(Error::FailedAuthentication)?; let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes())); - let original_ciphertext = buffer.clone(); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), false); - c.init(&get_aes_gcm_nonce(buffer)); - c.crypt_in_place(&mut buffer[(HEADER_SIZE + P384_PUBLIC_KEY_SIZE)..payload_end]); + c.init(&get_aes_gcm_nonce(incoming_packet)); + c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); let c = c.finish(); - if !c.eq(&buffer[payload_end..aes_gcm_tag_end]) { + if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) { return Err(Error::FailedAuthentication); } - let (alice_session_id, alice_s_public, alice_e1_public) = parse_KEY_OFFER_after_header(&buffer[(HEADER_SIZE + P384_PUBLIC_KEY_SIZE)..payload_end])?; + let (alice_session_id, alice_s_public, alice_metadata, alice_e1_public) = parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?; - // Important! Check to make sure the caller's public identity matches the one for this session. if let Some(session) = session.as_ref() { + // Important! If there's already a session, make sure the caller is the same endpoint as that session! if !session.remote_s_public_hash.eq(&SHA384::hash(&alice_s_public)) { return Err(Error::FailedAuthentication); } } - let alice_s_public_p384 = extract_p384_static_public(&alice_s_public).ok_or(Error::InvalidPacket)?; + let alice_s_public_p384 = H::extract_p384_static(alice_s_public).ok_or(Error::InvalidPacket)?; let ss = local_s_keypair_p384.agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?; let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes())); - if !hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[..aes_gcm_tag_end]).eq(&buffer[aes_gcm_tag_end..incoming_packet.len()]) { + if !hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[..aes_gcm_tag_end]).eq(&incoming_packet[aes_gcm_tag_end..hmac1_end]) { return Err(Error::FailedAuthentication); } @@ -537,21 +605,21 @@ impl Session { let new_session = if session.is_some() { None } else { - if let Some((local_session_id, psk, associated_object)) = new_session_auth(&alice_s_public) { - Some(Session:: { - id: local_session_id, + if let Some((new_session_id, psk, associated_object)) = host.accept_new_session(alice_s_public, alice_metadata) { + Some(Session:: { + id: new_session_id, associated_object, send_counter: Counter::new(), - remote_s_public_hash: SHA384::hash(&alice_s_public), psk, ss, - outgoing_obfuscator: Obfuscator::new(&alice_s_public), state: RwLock::new(MutableState { remote_session_id: Some(alice_session_id), keys: [None, None], offer: None, }), + remote_s_public_hash: SHA384::hash(&alice_s_public), remote_s_public_p384: alice_s_public_p384.as_bytes().clone(), + defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)), }) } else { return Err(Error::NewSessionRejected); @@ -564,12 +632,15 @@ impl Session { // NIST/FIPS allows HKDF with HMAC(salt, key) and salt is allowed to be anything. This way if the PSK is not // FIPS compliant the compliance of the entire key derivation is not invalidated. Both inputs are secrets of // fixed size so this shouldn't matter cryptographically. - let key = Secret(hmac_sha512(session.psk.as_bytes(), &hmac_sha512(&hmac_sha512(&hmac_sha512(key.as_bytes(), bob_e0_keypair.public_key_bytes()), e0e0.as_bytes()), se0.as_bytes()))); + let key = Secret(hmac_sha512( + session.psk.as_bytes(), + &hmac_sha512(&hmac_sha512(&hmac_sha512(key.as_bytes(), bob_e0_keypair.public_key_bytes()), e0e0.as_bytes()), se0.as_bytes()), + )); // At this point we've completed Noise_IK key derivation with NIST P-384 ECDH, but see final step below... - let (bob_e1_public, e1e1) = if jedi && alice_e1_public.is_some() { - if let Ok((bob_e1_public, e1e1)) = pqc_kyber::encapsulate(alice_e1_public.as_ref().unwrap(), &mut random::SecureRandom::default()) { + let (bob_e1_public, e1e1) = if jedi && alice_e1_public.len() > 0 { + if let Ok((bob_e1_public, e1e1)) = pqc_kyber::encapsulate(alice_e1_public, &mut random::SecureRandom::default()) { (Some(bob_e1_public), Secret(e1e1)) } else { return Err(Error::FailedAuthentication); @@ -578,15 +649,34 @@ impl Session { (None, Secret::default()) // use all zero Kyber secret if disabled }; - let counter = session.send_counter.next(); - let mut reply_size = assemble_KEY_COUNTER_OFFER(buffer, counter, alice_session_id, bob_e0_keypair.public_key(), session.id, bob_e1_public.as_ref()); + let mut reply_buf = [0_u8; MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS]; + let reply_counter = session.send_counter.next(); + let mut reply_len = { + let mut rp = &mut reply_buf[HEADER_SIZE..]; + + rp.write_all(&[SESSION_PROTOCOL_VERSION])?; + rp.write_all(bob_e0_keypair.public_key_bytes())?; + + rp.write_all(&session.id.to_bytes())?; + varint::write(&mut rp, 0)?; // they don't need our static public; they have it + varint::write(&mut rp, 0)?; // no meta-data yet + if let Some(bob_e1_public) = bob_e1_public.as_ref() { + rp.write_all(&[E1_TYPE_KYBER1024])?; + rp.write_all(bob_e1_public)?; + } else { + rp.write_all(&[E1_TYPE_NONE])?; + } + (MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len() + }; + let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_OFFER, session.id.into(), reply_counter); + reply_buf[..HEADER_SIZE].copy_from_slice(&header); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true); - c.init(&get_aes_gcm_nonce(buffer)); - c.crypt_in_place(&mut buffer[(HEADER_SIZE + P384_PUBLIC_KEY_SIZE)..reply_size]); + c.init(&get_aes_gcm_nonce(&header)); + c.crypt_in_place(&mut reply_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..reply_len]); let c = c.finish(); - buffer[reply_size..(reply_size + AES_GCM_TAG_SIZE)].copy_from_slice(&c); - reply_size += AES_GCM_TAG_SIZE; + reply_buf[reply_len..(reply_len + AES_GCM_TAG_SIZE)].copy_from_slice(&c); + reply_len += AES_GCM_TAG_SIZE; // Normal Noise_IK is done, but we have one more step: mix in the Kyber shared secret (or all zeroes if Kyber is // disabled). We have to wait until this point because Kyber's keys are encrypted and can't be decrypted until @@ -594,57 +684,58 @@ impl Session { // key derivation step. let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes())); - let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &buffer[..reply_size]); - buffer[reply_size..reply_size + HMAC_SIZE].copy_from_slice(&hmac); - reply_size += HMAC_SIZE; + let hmac = hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &reply_buf[..reply_len]); + reply_buf[reply_len..(reply_len + HMAC_SIZE)].copy_from_slice(&hmac); + reply_len += HMAC_SIZE; let mut state = session.state.write(); let _ = state.remote_session_id.replace(alice_session_id); - state.keys[1].replace(SessionKey::new(key, Role::Bob, current_time, counter, jedi)); + state.keys[1].replace(SessionKey::new(key, Role::Bob, current_time, reply_counter, jedi)); drop(state); // Bob now has final key state for this exchange. Yay! Now reply to Alice so she can construct it. - session.outgoing_obfuscator.0.encrypt_block_in_place(&mut buffer[0..16]); - for i in (16..64).step_by(16) { - let j = i + 16; - for k in i..j { - buffer[k] ^= buffer[k - 16]; - } - session.outgoing_obfuscator.0.encrypt_block_in_place(&mut buffer[i..j]); + send_with_fragmentation(send, &mut reply_buf[..reply_len], mtu, &mut header); + if new_session.is_some() { + return Ok(ReceiveResult::OkNewSession(new_session.unwrap())); + } else { + return Ok(ReceiveResult::Ok); } - - return new_session.map_or_else(|| Ok(ReceiveResult::OkSendReply(&buffer[..reply_size])), |ns| Ok(ReceiveResult::OkNewSession(ns, &buffer[..reply_size]))); } PACKET_TYPE_KEY_COUNTER_OFFER => { // bob (remote) -> alice (local) + let payload_end = incoming_packet_len - (AES_GCM_TAG_SIZE + HMAC_SIZE); + let aes_gcm_tag_end = incoming_packet_len - HMAC_SIZE; + if let Some(session) = session { let state = session.state.upgradable_read(); if let Some(offer) = state.offer.as_ref() { - let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&buffer[HEADER_SIZE..(HEADER_SIZE + P384_PUBLIC_KEY_SIZE)]) + let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) .and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s))) .ok_or(Error::FailedAuthentication)?; let se0 = local_s_keypair_p384.agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?; - let key = Secret(hmac_sha512(session.psk.as_bytes(), &hmac_sha512(&hmac_sha512(&hmac_sha512(offer.key.as_bytes(), bob_e0_public.as_bytes()), e0e0.as_bytes()), se0.as_bytes()))); + let key = Secret(hmac_sha512( + session.psk.as_bytes(), + &hmac_sha512(&hmac_sha512(&hmac_sha512(offer.key.as_bytes(), bob_e0_public.as_bytes()), e0e0.as_bytes()), se0.as_bytes()), + )); - let original_ciphertext = buffer.clone(); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), false); - c.init(&get_aes_gcm_nonce(buffer)); - c.crypt_in_place(&mut buffer[(HEADER_SIZE + P384_PUBLIC_KEY_SIZE)..payload_end]); + c.init(&get_aes_gcm_nonce(incoming_packet)); + c.crypt_in_place(&mut incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..payload_end]); let c = c.finish(); - if !c.eq(&buffer[payload_end..aes_gcm_tag_end]) { + if !c.eq(&incoming_packet[payload_end..aes_gcm_tag_end]) { return Err(Error::FailedAuthentication); } // Alice has now completed Noise_IK for P-384 and verified with GCM auth, now for the hybrid add-on. - let (bob_session_id, bob_e1_public) = parse_KEY_COUNTER_OFFER_after_header(&buffer[(HEADER_SIZE + P384_PUBLIC_KEY_SIZE)..payload_end])?; + let (bob_session_id, _, _, bob_e1_public) = parse_key_offer_after_header(&incoming_packet[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..], packet_type)?; - let e1e1 = if jedi && bob_e1_public.is_some() && offer.alice_e1_keypair.is_some() { - if let Ok(e1e1) = pqc_kyber::decapsulate(bob_e1_public.as_ref().unwrap(), &offer.alice_e1_keypair.as_ref().unwrap().secret) { + let e1e1 = if jedi && bob_e1_public.len() > 0 && offer.alice_e1_keypair.is_some() { + if let Ok(e1e1) = pqc_kyber::decapsulate(bob_e1_public, &offer.alice_e1_keypair.as_ref().unwrap().secret) { Secret(e1e1) } else { return Err(Error::FailedAuthentication); @@ -655,7 +746,9 @@ impl Session { let key = Secret(hmac_sha512(e1e1.as_bytes(), key.as_bytes())); - if !hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[..aes_gcm_tag_end]).eq(&buffer[aes_gcm_tag_end..incoming_packet.len()]) { + if !hmac_sha384(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_HMAC).first_n::<48>(), &original_ciphertext[..aes_gcm_tag_end]) + .eq(&incoming_packet[aes_gcm_tag_end..incoming_packet.len()]) + { return Err(Error::FailedAuthentication); } @@ -668,21 +761,29 @@ impl Session { let _ = state.remote_session_id.replace(bob_session_id); if state.keys[0].is_some() { let _ = state.keys[1].replace(SessionKey::new(key, Role::Alice, current_time, session.send_counter.current(), jedi)); - return Ok(ReceiveResult::Ok); } else { let counter = session.send_counter.next(); let key = SessionKey::new(key, Role::Alice, current_time, counter, jedi); - let dummy_data_len = (random::next_u32_secure() % (MAX_PACKET_SIZE - (HEADER_SIZE + AES_GCM_TAG_SIZE)) as u32) as usize; - let mut dummy_data = [0_u8; MAX_PACKET_SIZE]; - random::fill_bytes_secure(&mut dummy_data[..dummy_data_len]); - let nop_len = assemble_and_armor_DATA(buffer, &dummy_data[..dummy_data_len], PACKET_TYPE_NOP, u64::from(bob_session_id), counter, &key, &session.outgoing_obfuscator)?; + let mut reply_buf = [0_u8; MIN_MTU]; + let dummy_data_len = (random::next_u32_secure() % (mtu - (HEADER_SIZE + AES_GCM_TAG_SIZE)) as u32) as usize; + let reply_len = dummy_data_len + HEADER_SIZE + AES_GCM_TAG_SIZE; + let header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_NOP, bob_session_id.into(), counter); + reply_buf[..HEADER_SIZE].copy_from_slice(&header); + + let mut c = key.get_send_cipher(counter)?; + c.init(&get_aes_gcm_nonce(&reply_buf)); + c.crypt_in_place(&mut reply_buf[HEADER_SIZE..(HEADER_SIZE + dummy_data_len)]); + reply_buf[(HEADER_SIZE + dummy_data_len)..reply_len].copy_from_slice(&c.finish()); + key.return_send_cipher(c); + + send(&mut reply_buf[..reply_len]); let _ = state.keys[0].replace(key); let _ = state.keys[1].take(); - - return Ok(ReceiveResult::OkSendReply(&buffer[..nop_len])); } + + return Ok(ReceiveResult::Ok); } } @@ -695,15 +796,13 @@ impl Session { } } } - */ } -#[repr(transparent)] struct Counter(AtomicU64); impl Counter { fn new() -> Self { - Self(AtomicU64::new(0)) + Self(AtomicU64::new(random::next_u32_secure() as u64)) } #[inline(always)] @@ -728,13 +827,8 @@ struct CounterValue(u64); impl CounterValue { #[inline(always)] - pub fn to_bytes(&self) -> [u8; 4] { - (self.0 as u32).to_le_bytes() - } - - #[inline(always)] - pub fn lsb(&self) -> u8 { - self.0 as u8 + pub fn to_u32(&self) -> u32 { + self.0 as u32 } } @@ -773,16 +867,16 @@ struct EphemeralOffer { } impl EphemeralOffer { - fn create_alice_offer( - mut send: SendFunction, + fn create_alice_offer( + send: SendFunction, counter: CounterValue, alice_session_id: SessionId, bob_session_id: Option, alice_s_public: &[u8], alice_metadata: &[u8], bob_s_public_p384: &P384PublicKey, + bob_s_public_hash: &[u8], ss: &Secret<48>, - outgoing_obfuscator: &Obfuscator, mtu: usize, current_time: i64, jedi: bool, @@ -799,18 +893,12 @@ impl EphemeralOffer { None }; - let bob_session_id_bytes = bob_session_id.map_or(0_u64, |i| i.into()).to_le_bytes(); - - const PACKET_BUF_SIZE: usize = 3072; + const PACKET_BUF_SIZE: usize = MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS; let mut packet_buf = [0_u8; PACKET_BUF_SIZE]; - let mut packet_len = { - let mut p = &mut packet_buf[..]; - - p.write_all(&[PACKET_TYPE_KEY_OFFER])?; - p.write_all(&bob_session_id_bytes[..SESSION_ID_SIZE])?; - p.write_all(&counter.to_bytes())?; + let mut p = &mut packet_buf[HEADER_SIZE..]; + p.write_all(&[SESSION_PROTOCOL_VERSION])?; p.write_all(alice_e0_keypair.public_key_bytes())?; p.write_all(&alice_session_id.0.get().to_le_bytes()[..SESSION_ID_SIZE])?; @@ -828,9 +916,8 @@ impl EphemeralOffer { PACKET_BUF_SIZE - p.len() }; - if packet_len > mtu { - packet_buf[0] |= PACKET_FLAG_CONTINUED; - } + let mut header = send_with_fragmentation_init_header(packet_len, mtu, PACKET_TYPE_KEY_OFFER, bob_session_id.map_or(0_u64, |i| i.into()), counter); + packet_buf[..HEADER_SIZE].copy_from_slice(&header); let key = Secret(hmac_sha512( &hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_keypair.public_key_bytes()), @@ -840,7 +927,7 @@ impl EphemeralOffer { let gcm_tag = { let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB).first_n::<32>(), true); c.init(&get_aes_gcm_nonce(&packet_buf)); - c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + P384_PUBLIC_KEY_SIZE)..packet_len]); + c.crypt_in_place(&mut packet_buf[(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)..packet_len]); c.finish() }; packet_buf[packet_len..(packet_len + AES_GCM_TAG_SIZE)].copy_from_slice(&gcm_tag); @@ -852,12 +939,11 @@ impl EphemeralOffer { packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); packet_len += HMAC_SIZE; - cbc_obfuscate_first_64(outgoing_obfuscator, &mut packet_buf); - if packet_len > mtu { - send_fragmented(send, &mut packet_buf[..packet_len], counter, mtu, &bob_session_id_bytes[..SESSION_ID_SIZE], outgoing_obfuscator)?; - } else { - send(&packet_buf[..packet_len]); - } + let hmac = hmac_sha384(bob_s_public_hash, &packet_buf[..packet_len]); + packet_buf[packet_len..(packet_len + HMAC_SIZE)].copy_from_slice(&hmac); + packet_len += HMAC_SIZE; + + send_with_fragmentation(send, &mut packet_buf[..packet_len], mtu, &mut header); Ok(EphemeralOffer { creation_time: current_time, @@ -868,92 +954,81 @@ impl EphemeralOffer { } } -/// Send a packet that must be fragmented. -/// -/// The packet MUST have its CONTINUED flag set in its header. This isn't used -/// for unfragmented packets. Those are just sent directly. -/// -/// The packet should be obfuscated as normal. This handles obfuscation of -/// fragments after the head. The contents of 'packet' are partly overwritten. -fn send_fragmented( - mut send: SendFunction, - packet: &mut [u8], - counter: CounterValue, - mtu: usize, - remote_session_id_bytes: &[u8], - outgoing_obfuscator: &Obfuscator, -) -> std::io::Result<()> { +#[inline(always)] +fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) -> [u8; 12] { + let fragment_count = (packet_len / mtu) + (((packet_len % mtu) != 0) as usize); + debug_assert!(mtu >= MIN_MTU); + debug_assert!(packet_len >= HEADER_SIZE); + debug_assert!(fragment_count <= MAX_FRAGMENTS); + debug_assert!(packet_type <= 0x0f); // packet type is 4 bits + debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits + + // Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter + let mut header = (fragment_count.wrapping_shl(4) | (packet_type as usize)) as u128; + header |= recipient_session_id.wrapping_shl(16) as u128; + header |= (counter.to_u32() as u128).wrapping_shl(64); + header.to_le_bytes()[..HEADER_SIZE].try_into().unwrap() +} + +#[inline(always)] +fn send_with_fragmentation(mut send: SendFunction, packet: &mut [u8], mtu: usize, header: &mut [u8; HEADER_SIZE]) { let packet_len = packet.len(); - debug_assert!(packet_len >= MIN_PACKET_SIZE); - debug_assert!(mtu > MIN_PACKET_SIZE); - - let frag_len_max = ((packet_len as f64) / ((packet_len as f64) / ((mtu - HEADER_SIZE) as f64)).ceil()).ceil() as usize; - debug_assert!(frag_len_max > 0); - let mut frag_len = packet_len.min(frag_len_max); - debug_assert!(frag_len > 0); - - send(&packet[..frag_len]); - - let frag0_tail = [packet[frag_len - 2], packet[frag_len - 1]]; - - let mut next_frag_start = frag_len; - let mut frag_no = 1_u8; - while next_frag_start < packet_len { - debug_assert!(next_frag_start > HEADER_SIZE); - frag_len = (packet_len - next_frag_start).min(frag_len_max); - debug_assert!(frag_len > MIN_PACKET_SIZE); - let frag_end = next_frag_start + frag_len; - debug_assert!(frag_end <= packet_len); - - next_frag_start -= HEADER_SIZE; - let mut frag_hdr = &mut packet[next_frag_start..]; - frag_hdr.write_all(&[if frag_end == packet_len { - PACKET_TYPE_CONTINUATION + let mut fragment_start = 0; + let mut fragment_end = packet_len.min(mtu); + loop { + send(&mut packet[fragment_start..fragment_end]); + if fragment_end < packet_len { + fragment_start = fragment_end - HEADER_SIZE; + fragment_end = (fragment_start + mtu).min(packet_len); + header[1] += 0x04; // increment fragment number, at bit 2 in byte 1 since type/fragment u16 is little-endian + packet[fragment_start..(fragment_start + HEADER_SIZE)].copy_from_slice(header); } else { - PACKET_TYPE_CONTINUATION | PACKET_FLAG_CONTINUED - }])?; - frag_hdr.write_all(&remote_session_id_bytes)?; - frag_hdr.write_all(&[counter.lsb(), frag_no])?; - frag_no += 1; - frag_hdr.write_all(&frag0_tail)?; - - outgoing_obfuscator.0.encrypt_block_in_place(&mut packet[next_frag_start..(next_frag_start + 16)]); - send(&packet[next_frag_start..frag_end]); - - next_frag_start = frag_end; - } - - Ok(()) -} - -/// Obfuscate the first 64 bytes of a packet (used for key exchanges). -fn cbc_obfuscate_first_64(ob: &Obfuscator, data: &mut [u8]) { - ob.0.encrypt_block_in_place(&mut data[0..16]); - let mut i = 16; - while i < 64 { - let j = i + 16; - for k in i..j { - data[k] ^= data[k - 16]; + debug_assert_eq!(fragment_end, packet_len); + break; } - ob.0.encrypt_block_in_place(&mut data[i..j]); - i = j; } } -/// Deobfuscate the last 48 bytes of a packet (used for key exchanges). -/// -/// This is used when decoding key exchange packets. The first 16 bytes are always -/// deobfuscated, so this assumes that's already been done and finishes. -fn cbc_debofuscate_16_to_64(ob: &Obfuscator, input: &[u8], output: &mut [u8]) { - let mut i = 16; - while i < 64 { - let j = i + 16; - ob.0.decrypt_block(&input[i..j], &mut output[i..j]); - for k in i..j { - output[k] ^= input[k - 16]; - } - i = j; +fn parse_key_offer_after_header(incoming_packet: &[u8], packet_type: u8) -> Result<(SessionId, &[u8], &[u8], &[u8]), Error> { + let mut p = &incoming_packet[..]; + let alice_session_id = SessionId::new_from_reader(&mut p)?; + if alice_session_id.is_none() { + return Err(Error::InvalidPacket); } + let alice_session_id = alice_session_id.unwrap(); + let alice_s_public_len = varint::read(&mut p)?.0; + if (p.len() as u64) < alice_s_public_len { + return Err(Error::InvalidPacket); + } + let alice_s_public = &p[..(alice_s_public_len as usize)]; + p = &p[(alice_s_public_len as usize)..]; + let alice_metadata_len = varint::read(&mut p)?.0; + if (p.len() as u64) < alice_metadata_len { + return Err(Error::InvalidPacket); + } + let alice_metadata = &p[..(alice_metadata_len as usize)]; + p = &p[(alice_metadata_len as usize)..]; + if p.is_empty() { + return Err(Error::InvalidPacket); + } + let alice_e1_public = match p[0] { + E1_TYPE_KYBER1024 => { + if packet_type == PACKET_TYPE_KEY_OFFER { + if p.len() < (pqc_kyber::KYBER_PUBLICKEYBYTES + 1) { + return Err(Error::InvalidPacket); + } + &p[1..(pqc_kyber::KYBER_PUBLICKEYBYTES + 1)] + } else { + if p.len() < (pqc_kyber::KYBER_CIPHERTEXTBYTES + 1) { + return Err(Error::InvalidPacket); + } + &p[1..(pqc_kyber::KYBER_CIPHERTEXTBYTES + 1)] + } + } + _ => &[], + }; + + Ok((alice_session_id, alice_s_public, alice_metadata, alice_e1_public)) } enum Role { @@ -1017,112 +1092,15 @@ impl SessionKey { } } -/* -#[allow(non_snake_case)] -fn parse_KEY_OFFER_after_header(mut b: &[u8]) -> Result<(SessionId, [u8; STATIC_PUBLIC_SIZE], Option<[u8; pqc_kyber::KYBER_PUBLICKEYBYTES]>), Error> { - if b.len() >= SESSION_ID_SIZE { - let alice_session_id = SessionId::new_from_bytes(b).ok_or(Error::InvalidPacket)?; - b = &b[SESSION_ID_SIZE..]; - if b.len() >= STATIC_PUBLIC_SIZE { - let alice_s_public: [u8; STATIC_PUBLIC_SIZE] = b[..STATIC_PUBLIC_SIZE].try_into().unwrap(); - b = &b[STATIC_PUBLIC_SIZE..]; - if b.len() >= 1 { - let e1_type = b[0]; - b = &b[1..]; - let alice_e1_public = if e1_type == E1_TYPE_KYBER1024 { - if b.len() >= pqc_kyber::KYBER_PUBLICKEYBYTES { - let k: [u8; pqc_kyber::KYBER_PUBLICKEYBYTES] = b[..pqc_kyber::KYBER_PUBLICKEYBYTES].try_into().unwrap(); - b = &b[pqc_kyber::KYBER_PUBLICKEYBYTES..]; - Some(k) - } else { - return Err(Error::InvalidPacket); - } - } else { - None - }; - if b.len() >= 2 { - return Ok((alice_session_id, alice_s_public, alice_e1_public)); - } - } - } - } - return Err(Error::InvalidPacket); -} - -#[allow(non_snake_case)] -fn assemble_KEY_COUNTER_OFFER( - buffer: &mut [u8; MAX_PACKET_SIZE], - counter: CounterValue, - alice_session_id: SessionId, - bob_e0_public: &P384PublicKey, - bob_session_id: SessionId, - bob_e1_public: Option<&[u8; pqc_kyber::KYBER_CIPHERTEXTBYTES]>, -) -> usize { - buffer[0] = PACKET_TYPE_KEY_COUNTER_OFFER; - alice_session_id.copy_to(&mut buffer[1..7]); - buffer[7..11].copy_from_slice(&counter.to_bytes()); - let mut b = &mut buffer[HEADER_SIZE..]; - - b[..P384_PUBLIC_KEY_SIZE].copy_from_slice(bob_e0_public.as_bytes()); - b = &mut b[P384_PUBLIC_KEY_SIZE..]; - - bob_session_id.copy_to(b); - b = &mut b[SESSION_ID_SIZE..]; - - if let Some(k) = bob_e1_public { - b[0] = E1_TYPE_KYBER1024; - b[1..1 + pqc_kyber::KYBER_CIPHERTEXTBYTES].copy_from_slice(k); - b = &mut b[1 + pqc_kyber::KYBER_CIPHERTEXTBYTES..]; - } else { - b[0] = E1_TYPE_NONE; - b = &mut b[1..]; - } - - b[0] = 0; - b[1] = 0; // reserved for future use - b = &mut b[2..]; - - MAX_PACKET_SIZE - b.len() -} - -#[allow(non_snake_case)] -fn parse_KEY_COUNTER_OFFER_after_header(mut b: &[u8]) -> Result<(SessionId, Option<[u8; pqc_kyber::KYBER_CIPHERTEXTBYTES]>), Error> { - if b.len() >= SESSION_ID_SIZE { - let bob_session_id = SessionId::new_from_bytes(b).ok_or(Error::InvalidPacket)?; - b = &b[SESSION_ID_SIZE..]; - if b.len() >= 1 { - let e1_type = b[0]; - b = &b[1..]; - let bob_e1_public = if e1_type == E1_TYPE_KYBER1024 { - if b.len() >= pqc_kyber::KYBER_CIPHERTEXTBYTES { - let k: [u8; pqc_kyber::KYBER_CIPHERTEXTBYTES] = b[..pqc_kyber::KYBER_CIPHERTEXTBYTES].try_into().unwrap(); - b = &b[pqc_kyber::KYBER_CIPHERTEXTBYTES..]; - Some(k) - } else { - return Err(Error::InvalidPacket); - } - } else { - None - }; - if b.len() >= 1 && b[0] == 0 { - return Ok((bob_session_id, bob_e1_public)); - } - } - } - return Err(Error::InvalidPacket); -} - -*/ - /// HMAC-SHA512 key derivation function modeled on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 12) fn kbkdf512(key: &[u8], label: u8) -> Secret<64> { Secret(hmac_sha512(key, &[0, 0, 0, 0, b'Z', b'T', label, 0, 0, 0, 0, 0x02, 0x00])) } #[inline(always)] -fn get_aes_gcm_nonce(deobfuscated_packet: &[u8]) -> [u8; 16] { - let mut tmp = 0_u128.to_ne_bytes(); - tmp[..HEADER_SIZE].copy_from_slice(&deobfuscated_packet[..HEADER_SIZE]); +fn get_aes_gcm_nonce(packet: &[u8]) -> [u8; 16] { + let mut tmp = 0u128.to_ne_bytes(); + tmp[..HEADER_SIZE].copy_from_slice(&packet[..HEADER_SIZE]); tmp } diff --git a/network-hypervisor/Cargo.toml b/network-hypervisor/Cargo.toml index e0548cba2..2be745cdc 100644 --- a/network-hypervisor/Cargo.toml +++ b/network-hypervisor/Cargo.toml @@ -11,6 +11,7 @@ debug_events = [] [dependencies] zerotier-core-crypto = { path = "../core-crypto" } +zerotier-utils = { path = "../utils" } async-trait = "^0" base64 = "^0" lz4_flex = { version = "^0", features = ["safe-encode", "safe-decode", "checked-decode"] } diff --git a/utils/Cargo.toml b/utils/Cargo.toml new file mode 100644 index 000000000..0cb95ea73 --- /dev/null +++ b/utils/Cargo.toml @@ -0,0 +1,8 @@ +[package] +authors = ["ZeroTier, Inc. "] +edition = "2021" +license = "MPL-2.0" +name = "zerotier-utils" +version = "0.1.0" + +[dependencies] diff --git a/utils/rustfmt.toml b/utils/rustfmt.toml new file mode 120000 index 000000000..39f97b043 --- /dev/null +++ b/utils/rustfmt.toml @@ -0,0 +1 @@ +../rustfmt.toml \ No newline at end of file diff --git a/utils/src/arrayvec.rs b/utils/src/arrayvec.rs new file mode 100644 index 000000000..ee089d27d --- /dev/null +++ b/utils/src/arrayvec.rs @@ -0,0 +1,106 @@ +// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. + +use std::mem::{size_of, MaybeUninit}; +use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut}; + +/// A simple vector backed by a static sized array with no memory allocations and no overhead construction. +pub struct ArrayVec { + pub(crate) a: [MaybeUninit; C], + pub(crate) s: usize, +} + +impl ArrayVec { + #[inline(always)] + pub fn new() -> Self { + assert_eq!(size_of::<[T; C]>(), size_of::<[MaybeUninit; C]>()); + Self { a: unsafe { MaybeUninit::uninit().assume_init() }, s: 0 } + } + + #[inline(always)] + pub fn push(&mut self, v: T) { + if self.s < C { + let i = self.s; + unsafe { self.a.get_unchecked_mut(i).write(v) }; + self.s = i + 1; + } else { + panic!(); + } + } + + #[inline(always)] + pub fn try_push(&mut self, v: T) -> bool { + if self.s < C { + let i = self.s; + unsafe { self.a.get_unchecked_mut(i).write(v) }; + self.s = i + 1; + true + } else { + false + } + } + + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.s == 0 + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.s + } + + #[inline(always)] + pub fn pop(&mut self) -> Option { + if self.s > 0 { + let i = self.s - 1; + debug_assert!(i < C); + self.s = i; + Some(unsafe { self.a.get_unchecked(i).assume_init_read() }) + } else { + None + } + } +} + +impl Drop for ArrayVec { + #[inline(always)] + fn drop(&mut self) { + for i in 0..self.s { + unsafe { self.a.get_unchecked_mut(i).assume_init_drop() }; + } + } +} + +impl AsRef<[T]> for ArrayVec { + #[inline(always)] + fn as_ref(&self) -> &[T] { + unsafe { &*slice_from_raw_parts(self.a.as_ptr().cast(), self.s) } + } +} + +impl AsMut<[T]> for ArrayVec { + #[inline(always)] + fn as_mut(&mut self) -> &mut [T] { + unsafe { &mut *slice_from_raw_parts_mut(self.a.as_mut_ptr().cast(), self.s) } + } +} + +#[cfg(test)] +mod tests { + use super::ArrayVec; + + #[test] + fn array_vec() { + let mut v = ArrayVec::::new(); + for i in 0..128 { + v.push(i); + } + assert_eq!(v.len(), 128); + assert!(!v.try_push(1000)); + assert_eq!(v.len(), 128); + for _ in 0..128 { + assert!(v.pop().is_some()); + } + assert!(v.pop().is_none()); + } +} diff --git a/utils/src/gatherarray.rs b/utils/src/gatherarray.rs new file mode 100644 index 000000000..4cf93a230 --- /dev/null +++ b/utils/src/gatherarray.rs @@ -0,0 +1,92 @@ +// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. + +use std::mem::{size_of, MaybeUninit}; +use std::ptr::copy_nonoverlapping; + +use crate::arrayvec::ArrayVec; + +/// A fixed sized array of items to be gathered with fast check logic to return when complete. +/// +/// This supports a maximum capacity of 64 and will panic if created with a larger value for C. +pub struct GatherArray { + a: [MaybeUninit; C], + have_bits: u64, + have_count: u32, + goal: u32, +} + +impl GatherArray { + /// Create a new gather array, which must be initialized prior to use. + #[inline(always)] + pub fn new(goal: u32) -> Self { + assert!(C <= 64); + assert!(goal <= (C as u32)); + assert_eq!(size_of::<[T; C]>(), size_of::<[MaybeUninit; C]>()); + Self { + a: unsafe { MaybeUninit::uninit().assume_init() }, + have_bits: 0, + have_count: 0, + goal, + } + } + + /// Add an item to the array if we don't have this index anymore, returning complete array if all parts are here. + #[inline(always)] + pub fn add(&mut self, index: u32, value: T) -> Option> { + if index < self.goal { + let mut have = self.have_bits; + let got = 1u64.wrapping_shl(index); + if (have & got) == 0 { + have |= got; + self.have_bits = have; + let count = self.have_count + 1; + self.have_count = count; + let goal = self.goal as usize; + unsafe { + self.a.get_unchecked_mut(index as usize).write(value); + if (self.have_count as usize) == goal { + debug_assert_eq!(0xffffffffffffffffu64.wrapping_shr(64 - goal as u32), have); + let mut tmp = ArrayVec::new(); + copy_nonoverlapping(self.a.as_ptr().cast::(), tmp.a.as_mut_ptr().cast::(), size_of::>() * goal); + tmp.s = goal; + self.goal = 0; + return Some(tmp); + } + } + } + } + return None; + } +} + +impl Drop for GatherArray { + #[inline(always)] + fn drop(&mut self) { + let have = self.have_bits; + for i in 0..self.goal { + if (have & 1u64.wrapping_shl(i)) != 0 { + unsafe { self.a.get_unchecked_mut(i as usize).assume_init_drop() }; + } + } + self.goal = 0; + } +} + +#[cfg(test)] +mod tests { + use super::GatherArray; + + #[test] + fn gather_array() { + for goal in 2..64 { + let mut m = GatherArray::::new(goal); + for x in 0..(goal - 1) { + assert!(m.add(x, x).is_none()); + } + let r = m.add(goal - 1, goal - 1).unwrap(); + for x in 0..goal { + assert_eq!(r.as_ref()[x as usize], x); + } + } + } +} diff --git a/core-crypto/src/hex.rs b/utils/src/hex.rs similarity index 100% rename from core-crypto/src/hex.rs rename to utils/src/hex.rs diff --git a/utils/src/lib.rs b/utils/src/lib.rs new file mode 100644 index 000000000..d134532dc --- /dev/null +++ b/utils/src/lib.rs @@ -0,0 +1,5 @@ +pub mod arrayvec; +pub mod gatherarray; +pub mod hex; +pub mod ringbuffermap; +pub mod varint; diff --git a/utils/src/ringbuffermap.rs b/utils/src/ringbuffermap.rs new file mode 100644 index 000000000..79508ad10 --- /dev/null +++ b/utils/src/ringbuffermap.rs @@ -0,0 +1,255 @@ +// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md. + +use std::hash::{Hash, Hasher}; + +use std::mem::MaybeUninit; + +#[inline(always)] +fn xorshift64(mut x: u64) -> u64 { + x ^= x.wrapping_shl(13); + x ^= x.wrapping_shr(7); + x ^= x.wrapping_shl(17); + x +} + +struct XorShiftHasher(u64); + +impl XorShiftHasher { + #[inline(always)] + fn new(salt: u32) -> Self { + Self(salt as u64) + } +} + +impl Hasher for XorShiftHasher { + #[inline(always)] + fn finish(&self) -> u64 { + self.0 + } + + #[inline(always)] + fn write(&mut self, mut bytes: &[u8]) { + let mut x = self.0; + while bytes.len() >= 8 { + x = xorshift64(x.wrapping_add(u64::from_ne_bytes(unsafe { *bytes.as_ptr().cast::<[u8; 8]>() }))); + bytes = &bytes[8..]; + } + while bytes.len() >= 4 { + x = xorshift64(x.wrapping_add(u32::from_ne_bytes(unsafe { *bytes.as_ptr().cast::<[u8; 4]>() }) as u64)); + bytes = &bytes[4..]; + } + for b in bytes.iter() { + x = xorshift64(x.wrapping_add(*b as u64)); + } + self.0 = x; + } + + #[inline(always)] + fn write_isize(&mut self, i: isize) { + self.0 = xorshift64(self.0.wrapping_add(i as u64)); + } + + #[inline(always)] + fn write_usize(&mut self, i: usize) { + self.0 = xorshift64(self.0.wrapping_add(i as u64)); + } + + #[inline(always)] + fn write_i32(&mut self, i: i32) { + self.0 = xorshift64(self.0.wrapping_add(i as u64)); + } + + #[inline(always)] + fn write_u32(&mut self, i: u32) { + self.0 = xorshift64(self.0.wrapping_add(i as u64)); + } + + #[inline(always)] + fn write_i64(&mut self, i: i64) { + self.0 = xorshift64(self.0.wrapping_add(i as u64)); + } + + #[inline(always)] + fn write_u64(&mut self, i: u64) { + self.0 = xorshift64(self.0.wrapping_add(i)); + } +} + +struct Entry { + key: MaybeUninit, + value: MaybeUninit, + bucket: i32, // which bucket is this in? -1 for none + next: i32, // next item in bucket's linked list, -1 for none + prev: i32, // previous entry to permit deletion of old entries from bucket lists +} + +/// A hybrid between a circular buffer and a map. +/// +/// The map has a finite capacity. If a new entry is added and there's no more room the oldest +/// entry is removed and overwritten. The same could be achieved by pairing a circular buffer +/// with a HashMap but that would be less efficient. This requires no memory allocations unless +/// the K or V types allocate memory and occupies a fixed amount of memory. +/// +/// This is pretty basic and doesn't have a remove function. Old entries just roll off. This +/// only contains what is needed elsewhere in the project. +/// +/// The C template parameter is the total capacity while the B parameter is the number of +/// buckets in the hash table. +pub struct RingBufferMap { + entries: [Entry; C], + buckets: [i32; B], + entry_ptr: u32, + salt: u32, +} + +impl RingBufferMap { + #[inline] + pub fn new(salt: u32) -> Self { + Self { + entries: std::array::from_fn(|_| Entry:: { + key: MaybeUninit::uninit(), + value: MaybeUninit::uninit(), + bucket: -1, + next: -1, + prev: -1, + }), + buckets: [-1; B], + entry_ptr: 0, + salt, + } + } + + #[inline] + pub fn get(&self, key: &K) -> Option<&V> { + let mut h = XorShiftHasher::new(self.salt); + key.hash(&mut h); + let mut e = self.buckets[(h.finish() as usize) % B]; + while e >= 0 { + let ee = &self.entries[e as usize]; + debug_assert!(ee.bucket >= 0); + if unsafe { ee.key.assume_init_ref().eq(key) } { + return Some(unsafe { &ee.value.assume_init_ref() }); + } + e = ee.next; + } + return None; + } + + /// Get an entry, creating if not present. + #[inline] + pub fn get_or_create_mut V>(&mut self, key: &K, create: CF) -> &mut V { + let mut h = XorShiftHasher::new(self.salt); + key.hash(&mut h); + let bucket = (h.finish() as usize) % B; + + let mut e = self.buckets[bucket]; + while e >= 0 { + unsafe { + let e_ptr = &mut *self.entries.as_mut_ptr().add(e as usize); + debug_assert!(e_ptr.bucket >= 0); + if e_ptr.key.assume_init_ref().eq(key) { + return e_ptr.value.assume_init_mut(); + } + e = e_ptr.next; + } + } + + return self.internal_add(bucket, key.clone(), create()); + } + + /// Set a value or create a new entry if not found. + #[inline] + pub fn set(&mut self, key: K, value: V) { + let mut h = XorShiftHasher::new(self.salt); + key.hash(&mut h); + let bucket = (h.finish() as usize) % B; + + let mut e = self.buckets[bucket]; + while e >= 0 { + let e_ptr = &mut self.entries[e as usize]; + debug_assert!(e_ptr.bucket >= 0); + if unsafe { e_ptr.key.assume_init_ref().eq(&key) } { + unsafe { *e_ptr.value.assume_init_mut() = value }; + return; + } + e = e_ptr.next; + } + + self.internal_add(bucket, key, value); + } + + #[inline] + fn internal_add(&mut self, bucket: usize, key: K, value: V) -> &mut V { + let e = (self.entry_ptr as usize) % C; + self.entry_ptr = self.entry_ptr.wrapping_add(1); + let e_ptr = unsafe { &mut *self.entries.as_mut_ptr().add(e) }; + + if e_ptr.bucket >= 0 { + if e_ptr.prev >= 0 { + self.entries[e_ptr.prev as usize].next = e_ptr.next; + } else { + self.buckets[e_ptr.bucket as usize] = e_ptr.next; + } + unsafe { + e_ptr.key.assume_init_drop(); + e_ptr.value.assume_init_drop(); + } + } + + e_ptr.key.write(key); + e_ptr.value.write(value); + e_ptr.bucket = bucket as i32; + e_ptr.next = self.buckets[bucket]; + if e_ptr.next >= 0 { + self.entries[e_ptr.next as usize].prev = e as i32; + } + self.buckets[bucket] = e as i32; + e_ptr.prev = -1; + unsafe { e_ptr.value.assume_init_mut() } + } +} + +impl Drop for RingBufferMap { + #[inline] + fn drop(&mut self) { + for e in self.entries.iter_mut() { + if e.bucket >= 0 { + unsafe { + e.key.assume_init_drop(); + e.value.assume_init_drop(); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::RingBufferMap; + + #[test] + fn finite_map() { + let mut m = RingBufferMap::::new(1); + for i in 0..64 { + m.set(i, i); + } + for i in 0..64 { + assert_eq!(*m.get(&i).unwrap(), i); + } + + for i in 0..256 { + m.set(i, i); + } + for i in 0..128 { + assert!(m.get(&i).is_none()); + } + for i in 128..256 { + assert_eq!(*m.get(&i).unwrap(), i); + } + + m.set(1000, 1000); + assert!(m.get(&128).is_none()); + assert_eq!(*m.get(&129).unwrap(), 129); + assert_eq!(*m.get(&1000).unwrap(), 1000); + } +} diff --git a/core-crypto/src/varint.rs b/utils/src/varint.rs similarity index 100% rename from core-crypto/src/varint.rs rename to utils/src/varint.rs