Session works again, and some optimization.

This commit is contained in:
Adam Ierymenko 2022-09-06 18:03:31 -04:00
parent 3770fcdc83
commit 06573c1ea8
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3
5 changed files with 374 additions and 261 deletions

View file

@ -12,6 +12,7 @@ use crate::random;
use crate::secret::Secret; use crate::secret::Secret;
use zerotier_utils::gatherarray::GatherArray; use zerotier_utils::gatherarray::GatherArray;
use zerotier_utils::memory;
use zerotier_utils::ringbuffermap::RingBufferMap; use zerotier_utils::ringbuffermap::RingBufferMap;
use zerotier_utils::varint; use zerotier_utils::varint;
@ -41,6 +42,7 @@ const OFFER_RATE_LIMIT_MS: i64 = 2000;
/// Version 0: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication) /// Version 0: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication)
const SESSION_PROTOCOL_VERSION: u8 = 0x00; const SESSION_PROTOCOL_VERSION: u8 = 0x00;
// Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use
const PACKET_TYPE_DATA: u8 = 0; const PACKET_TYPE_DATA: u8 = 0;
const PACKET_TYPE_NOP: u8 = 1; const PACKET_TYPE_NOP: u8 = 1;
const PACKET_TYPE_KEY_OFFER: u8 = 2; // "alice" const PACKET_TYPE_KEY_OFFER: u8 = 2; // "alice"
@ -56,6 +58,10 @@ const AES_GCM_TAG_SIZE: usize = 16;
const HMAC_SIZE: usize = 48; const HMAC_SIZE: usize = 48;
const SESSION_ID_SIZE: usize = 6; const SESSION_ID_SIZE: usize = 6;
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';
/// Aribitrary starting value for key derivation chain. /// Aribitrary starting value for key derivation chain.
/// ///
/// It doesn't matter very much what this is, but it's good for it to be unique. /// It doesn't matter very much what this is, but it's good for it to be unique.
@ -66,10 +72,6 @@ const KEY_DERIVATION_CHAIN_STARTING_SALT: [u8; 64] = [
0xb4, 0x32, 0x85, 0xaf, 0x7f, 0x0d, 0xa9, 0x6c, 0x01, 0xfb, 0x72, 0x46, 0xc0, 0x09, 0x58, 0xb8, 0xe0, 0xa8, 0xcf, 0xb1, 0x58, 0x04, 0x6e, 0x32, 0xba, 0xa8, 0xb8, 0xf9, 0x0a, 0xa4, 0xbf, 0x36, 0xb4, 0x32, 0x85, 0xaf, 0x7f, 0x0d, 0xa9, 0x6c, 0x01, 0xfb, 0x72, 0x46, 0xc0, 0x09, 0x58, 0xb8, 0xe0, 0xa8, 0xcf, 0xb1, 0x58, 0x04, 0x6e, 0x32, 0xba, 0xa8, 0xb8, 0xf9, 0x0a, 0xa4, 0xbf, 0x36,
]; ];
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';
pub enum Error { pub enum Error {
/// The packet was addressed to an unrecognized local session /// The packet was addressed to an unrecognized local session
UnknownLocalSessionId(SessionId), UnknownLocalSessionId(SessionId),
@ -147,16 +149,17 @@ pub enum ReceiveResult<'a, H: Host> {
/// Packet is valid, no action needs to be taken. /// Packet is valid, no action needs to be taken.
Ok, Ok,
/// Packet is valid and contained a data payload. /// Packet is valid and a data payload was decoded and authenticated.
OkData(&'a [u8]), ///
/// The returned reference is to the filled parts of the data buffer supplied to receive.
OkData(&'a mut [u8]),
/// Packet is valid and a new session was created, also includes a reply to be sent back. /// Packet is valid and a new session was created.
///
/// The session will have already been gated by the accept_new_session() method in the Host trait.
OkNewSession(Session<H>), OkNewSession(Session<H>),
/// Packet appears valid but was ignored as a duplicate. /// Packet apperas valid but was ignored e.g. as a duplicate.
Duplicate,
/// Packet apperas valid but was ignored for another reason.
Ignored, Ignored,
} }
@ -167,20 +170,15 @@ pub struct SessionId(NonZeroU64);
impl SessionId { impl SessionId {
pub const MAX_BIT_MASK: u64 = 0xffffffffffff; pub const MAX_BIT_MASK: u64 = 0xffffffffffff;
pub fn new_from_bytes(b: &[u8]) -> Option<SessionId> { #[inline(always)]
if b.len() >= 6 { pub fn new_from_u64(i: u64) -> Option<SessionId> {
let value = (u32::from_le_bytes(b[..4].try_into().unwrap()) as u64) | (u16::from_le_bytes(b[4..6].try_into().unwrap()) as u64).wrapping_shl(32); debug_assert!(i <= Self::MAX_BIT_MASK);
if value > 0 && value <= Self::MAX_BIT_MASK { NonZeroU64::new(i).map(|i| Self(i))
return Some(Self(NonZeroU64::new(value).unwrap()));
}
}
return None;
} }
pub fn new_from_reader<R: Read>(r: &mut R) -> std::io::Result<Option<SessionId>> { pub fn new_from_reader<R: Read>(r: &mut R) -> std::io::Result<Option<SessionId>> {
let mut tmp = [0_u8; SESSION_ID_SIZE]; let mut tmp = [0_u8; 8];
r.read_exact(&mut tmp)?; r.read_exact(&mut tmp[..SESSION_ID_SIZE]).map(|_| NonZeroU64::new(u64::from_le_bytes(tmp)).map(|i| Self(i)))
Ok(Self::new_from_bytes(&tmp))
} }
pub fn new_random() -> Self { pub fn new_random() -> Self {
@ -193,19 +191,6 @@ impl SessionId {
} }
} }
impl TryFrom<u64> for SessionId {
type Error = self::Error;
#[inline(always)]
fn try_from(value: u64) -> Result<Self, Self::Error> {
if value > 0 && value <= Self::MAX_BIT_MASK {
Ok(Self(NonZeroU64::new(value).unwrap()))
} else {
Err(Error::InvalidParameter)
}
}
}
impl From<SessionId> for u64 { impl From<SessionId> for u64 {
#[inline(always)] #[inline(always)]
fn from(sid: SessionId) -> Self { fn from(sid: SessionId) -> Self {
@ -213,21 +198,31 @@ impl From<SessionId> for u64 {
} }
} }
/// Trait to implement to integrate the session into an application.
pub trait Host: Sized { pub trait Host: Sized {
type AssociatedObject: Sized; /// Arbitrary object that can be associated with sessions.
type AssociatedObject;
/// Arbitrary object that dereferences to the session, such as Arc<Session<Self>>.
type SessionRef: Deref<Target = Session<Self>>; type SessionRef: Deref<Target = Session<Self>>;
/// A buffer containing data read from the network that can be cached.
///
/// This can be e.g. a pooled buffer that automatically returns itself to the pool when dropped.
type IncomingPacketBuffer: AsRef<[u8]>; type IncomingPacketBuffer: AsRef<[u8]>;
/// Get a reference to this host's static public key blob. /// Get a reference to this host's static public key blob.
///
/// This must contain a NIST P-384 public key but can contain other information.
fn get_local_s_public(&self) -> &[u8]; fn get_local_s_public(&self) -> &[u8];
/// Get SHA384(this host's static public key blob) /// Get SHA384(this host's static public key blob), included here so we don't have to calculate it each time.
fn get_local_s_public_hash(&self) -> &[u8; 48]; fn get_local_s_public_hash(&self) -> &[u8; 48];
/// Get a reference to this hosts' static public key's NIST P-384 secret key pair /// Get a reference to this hosts' static public key's NIST P-384 secret key pair
fn get_local_s_keypair_p384(&self) -> &P384KeyPair; fn get_local_s_keypair_p384(&self) -> &P384KeyPair;
/// Extract a NIST P-384 ECC public key from a static public key blob. /// Extract the NIST P-384 ECC public key component from a static public key blob or return None on failure.
fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey>; fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey>;
/// Look up a local session by local ID. /// Look up a local session by local ID.
@ -246,23 +241,23 @@ pub struct Session<H: Host> {
pub associated_object: H::AssociatedObject, pub associated_object: H::AssociatedObject,
send_counter: Counter, send_counter: Counter,
psk: Secret<64>, psk: Secret<64>, // Arbitrary PSK provided by external code
ss: Secret<48>, ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer
state: RwLock<MutableState>, state: RwLock<MutableState>, // Mutable parts of state (other than defrag buffers)
remote_s_public_hash: [u8; 48], remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob)
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, MAX_FRAGMENTS>, 64, 32>>, defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, MAX_FRAGMENTS>, 16, 4>>,
} }
struct MutableState { struct MutableState {
remote_session_id: Option<SessionId>, remote_session_id: Option<SessionId>,
keys: [Option<SessionKey>; 2], // current, next keys: [Option<SessionKey>; 2], // current, next (promoted to current on successful decrypt)
offer: Option<EphemeralOffer>, offer: Option<EphemeralOffer>,
} }
/// State information to associate with receiving contexts such as sockets or remote paths/endpoints. /// State information to associate with receiving contexts such as sockets or remote paths/endpoints.
pub struct ReceiveContext<H: Host> { pub struct ReceiveContext<H: Host> {
initial_offer_defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, KEY_EXCHANGE_MAX_FRAGMENTS>, 1024, 256>>, initial_offer_defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, KEY_EXCHANGE_MAX_FRAGMENTS>, 1024, 128>>,
} }
impl<H: Host> Session<H> { impl<H: Host> Session<H> {
@ -362,26 +357,25 @@ impl<H: Host> ReceiveContext<H> {
current_time: i64, current_time: i64,
) -> Result<ReceiveResult<'a, H>, Error> { ) -> Result<ReceiveResult<'a, H>, Error> {
let incoming_packet = incoming_packet_buf.as_ref(); let incoming_packet = incoming_packet_buf.as_ref();
if incoming_packet.len() < MIN_PACKET_SIZE { if incoming_packet.len() < MIN_PACKET_SIZE {
unlikely_branch(); unlikely_branch();
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
let type_and_frag_info = u16::from_le_bytes(incoming_packet[0..2].try_into().unwrap()); let header_0_8 = memory::u64_from_le_bytes(incoming_packet); // type, frag info, session ID
let local_session_id = SessionId::new_from_bytes(&incoming_packet[2..8]); let counter = memory::u32_from_le_bytes(&incoming_packet[8..]);
let counter = u32::from_le_bytes(incoming_packet[8..12].try_into().unwrap()); let local_session_id = SessionId::new_from_u64(header_0_8.wrapping_shr(16));
let packet_type = (type_and_frag_info as u8) & 15; let packet_type = (header_0_8 as u8) & 15;
let fragment_count = type_and_frag_info.wrapping_shr(4) & 63; let fragment_count = ((header_0_8.wrapping_shr(4) as u8) & 63).wrapping_add(1);
let fragment_no = type_and_frag_info.wrapping_shr(10); // & 63 not needed let fragment_no = (header_0_8.wrapping_shr(10) as u8) & 63;
if fragment_count > 1 { if fragment_count > 1 {
if let Some(local_session_id) = local_session_id { if let Some(local_session_id) = local_session_id {
if fragment_count < (MAX_FRAGMENTS as u16) && fragment_no < fragment_count { if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
if let Some(session) = host.session_lookup(local_session_id) { if let Some(session) = host.session_lookup(local_session_id) {
let mut defrag = session.defrag.lock(); let mut defrag = session.defrag.lock();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32)); let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) { if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock drop(defrag); // release lock
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time); return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time);
} }
@ -394,10 +388,10 @@ impl<H: Host> ReceiveContext<H> {
return Err(Error::InvalidPacket); return Err(Error::InvalidPacket);
} }
} else { } else {
if fragment_count < (KEY_EXCHANGE_MAX_FRAGMENTS as u16) && fragment_no < fragment_count { if fragment_count <= (KEY_EXCHANGE_MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
let mut defrag = self.initial_offer_defrag.lock(); let mut defrag = self.initial_offer_defrag.lock();
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32)); let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) { if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
drop(defrag); // release lock drop(defrag); // release lock
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time); return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
} }
@ -444,22 +438,19 @@ impl<H: Host> ReceiveContext<H> {
let state = session.state.read(); let state = session.state.read();
for ki in 0..2 { for ki in 0..2 {
if let Some(key) = state.keys[ki].as_ref() { if let Some(key) = state.keys[ki].as_ref() {
let head = fragments.first().unwrap().as_ref(); let tail = fragments.last().unwrap().as_ref();
debug_assert!(head.len() >= MIN_PACKET_SIZE); if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
return Err(Error::InvalidPacket);
}
let mut c = key.get_receive_cipher(); let mut c = key.get_receive_cipher();
c.init(&get_aes_gcm_nonce(head)); c.init(&get_aes_gcm_nonce(fragments.first().unwrap().as_ref()));
let mut data_len = head.len() - HEADER_SIZE; let mut data_len = 0;
if data_len > data_buf.len() {
unlikely_branch();
key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall);
}
c.crypt(&head[HEADER_SIZE..], &mut data_buf[..data_len]);
for fi in 1..(fragments.len() - 1) { for f in fragments[..(fragments.len() - 1)].iter() {
let f = fragments[fi].as_ref(); let f = f.as_ref();
debug_assert!(f.len() >= HEADER_SIZE); debug_assert!(f.len() >= HEADER_SIZE);
let current_frag_data_start = data_len; let current_frag_data_start = data_len;
data_len += f.len() - HEADER_SIZE; data_len += f.len() - HEADER_SIZE;
@ -471,21 +462,14 @@ impl<H: Host> ReceiveContext<H> {
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]); c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
} }
let tail = fragments.last().unwrap().as_ref();
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
unlikely_branch();
key.return_receive_cipher(c);
return Err(Error::InvalidPacket);
}
let tail_data_len = tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
let current_frag_data_start = data_len; let current_frag_data_start = data_len;
data_len += tail_data_len; data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
if data_len > data_buf.len() { if data_len > data_buf.len() {
unlikely_branch(); unlikely_branch();
key.return_receive_cipher(c); key.return_receive_cipher(c);
return Err(Error::DataBufferTooSmall); return Err(Error::DataBufferTooSmall);
} }
c.crypt(&tail[HEADER_SIZE..tail_data_len], &mut data_buf[current_frag_data_start..data_len]); c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]);
let tag = c.finish(); let tag = c.finish();
key.return_receive_cipher(c); key.return_receive_cipher(c);
@ -500,7 +484,7 @@ impl<H: Host> ReceiveContext<H> {
} }
if packet_type == PACKET_TYPE_DATA { if packet_type == PACKET_TYPE_DATA {
return Ok(ReceiveResult::OkData(&data_buf[..data_len])); return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
} else { } else {
return Ok(ReceiveResult::Ok); return Ok(ReceiveResult::Ok);
} }
@ -540,8 +524,6 @@ impl<H: Host> ReceiveContext<H> {
return Err(Error::UnknownProtocolVersion); return Err(Error::UnknownProtocolVersion);
} }
let local_s_keypair_p384 = host.get_local_s_keypair_p384();
match packet_type { match packet_type {
PACKET_TYPE_KEY_OFFER => { PACKET_TYPE_KEY_OFFER => {
// alice (remote) -> bob (local) // alice (remote) -> bob (local)
@ -565,7 +547,7 @@ impl<H: Host> ReceiveContext<H> {
} }
let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
.and_then(|pk| local_s_keypair_p384.agree(&pk).map(move |s| (pk, s))) .and_then(|pk| host.get_local_s_keypair_p384().agree(&pk).map(move |s| (pk, s)))
.ok_or(Error::FailedAuthentication)?; .ok_or(Error::FailedAuthentication)?;
let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes())); let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes()));
@ -588,7 +570,7 @@ impl<H: Host> ReceiveContext<H> {
} }
let alice_s_public_p384 = H::extract_p384_static(alice_s_public).ok_or(Error::InvalidPacket)?; let alice_s_public_p384 = H::extract_p384_static(alice_s_public).ok_or(Error::InvalidPacket)?;
let ss = local_s_keypair_p384.agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?; let ss = host.get_local_s_keypair_p384().agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?;
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes())); let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
@ -668,7 +650,7 @@ impl<H: Host> ReceiveContext<H> {
} }
(MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len() (MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len()
}; };
let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_OFFER, session.id.into(), reply_counter); let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter);
reply_buf[..HEADER_SIZE].copy_from_slice(&header); reply_buf[..HEADER_SIZE].copy_from_slice(&header);
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true); let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true);
@ -715,7 +697,7 @@ impl<H: Host> ReceiveContext<H> {
let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)]) let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
.and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s))) .and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s)))
.ok_or(Error::FailedAuthentication)?; .ok_or(Error::FailedAuthentication)?;
let se0 = local_s_keypair_p384.agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?; let se0 = host.get_local_s_keypair_p384().agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?;
let key = Secret(hmac_sha512( let key = Secret(hmac_sha512(
session.psk.as_bytes(), session.psk.as_bytes(),
@ -956,15 +938,16 @@ impl EphemeralOffer {
#[inline(always)] #[inline(always)]
fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) -> [u8; 12] { fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) -> [u8; 12] {
let fragment_count = (packet_len / mtu) + (((packet_len % mtu) != 0) as usize); let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize;
debug_assert!(mtu >= MIN_MTU); debug_assert!(mtu >= MIN_MTU);
debug_assert!(packet_len >= HEADER_SIZE); debug_assert!(packet_len >= HEADER_SIZE);
debug_assert!(fragment_count <= MAX_FRAGMENTS); debug_assert!(fragment_count <= MAX_FRAGMENTS);
debug_assert!(fragment_count > 0);
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits
// Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter // Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter
let mut header = (fragment_count.wrapping_shl(4) | (packet_type as usize)) as u128; let mut header = ((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u128;
header |= recipient_session_id.wrapping_shl(16) as u128; header |= recipient_session_id.wrapping_shl(16) as u128;
header |= (counter.to_u32() as u128).wrapping_shl(64); header |= (counter.to_u32() as u128).wrapping_shl(64);
header.to_le_bytes()[..HEADER_SIZE].try_into().unwrap() header.to_le_bytes()[..HEADER_SIZE].try_into().unwrap()
@ -980,7 +963,8 @@ fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(mut send: SendFunctio
if fragment_end < packet_len { if fragment_end < packet_len {
fragment_start = fragment_end - HEADER_SIZE; fragment_start = fragment_end - HEADER_SIZE;
fragment_end = (fragment_start + mtu).min(packet_len); fragment_end = (fragment_start + mtu).min(packet_len);
header[1] += 0x04; // increment fragment number, at bit 2 in byte 1 since type/fragment u16 is little-endian debug_assert!(header[1].wrapping_shr(2) < 63);
header[1] += 0x04; // increment fragment number in least significant 6 bits of byte 1
packet[fragment_start..(fragment_start + HEADER_SIZE)].copy_from_slice(header); packet[fragment_start..(fragment_start + HEADER_SIZE)].copy_from_slice(header);
} else { } else {
debug_assert_eq!(fragment_end, packet_len); debug_assert_eq!(fragment_end, packet_len);
@ -1104,171 +1088,154 @@ fn get_aes_gcm_nonce(packet: &[u8]) -> [u8; 16] {
tmp tmp
} }
#[cfg(test)]
mod tests {}
/*
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::rc::Rc; use parking_lot::Mutex;
use std::collections::LinkedList;
use std::sync::Arc;
#[allow(unused_imports)] #[allow(unused_imports)]
use super::*; use super::*;
struct TestHost {
local_s: P384KeyPair,
local_s_hash: [u8; 48],
psk: Secret<64>,
session: Mutex<Option<Arc<Session<Box<TestHost>>>>>,
session_id_counter: Mutex<u64>,
pub queue: Mutex<LinkedList<Vec<u8>>>,
pub this_name: &'static str,
pub other_name: &'static str,
}
impl TestHost {
fn new(psk: Secret<64>, this_name: &'static str, other_name: &'static str) -> Self {
let local_s = P384KeyPair::generate();
let local_s_hash = SHA384::hash(local_s.public_key_bytes());
Self {
local_s,
local_s_hash,
psk,
session: Mutex::new(None),
session_id_counter: Mutex::new(random::next_u64_secure().wrapping_shr(16) | 1),
queue: Mutex::new(LinkedList::new()),
this_name,
other_name,
}
}
}
impl Host for Box<TestHost> {
type AssociatedObject = u32;
type SessionRef = Arc<Session<Box<TestHost>>>;
type IncomingPacketBuffer = Vec<u8>;
fn get_local_s_public(&self) -> &[u8] {
self.local_s.public_key_bytes()
}
fn get_local_s_public_hash(&self) -> &[u8; 48] {
&self.local_s_hash
}
fn get_local_s_keypair_p384(&self) -> &P384KeyPair {
&self.local_s
}
fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey> {
P384PublicKey::from_bytes(static_public)
}
fn session_lookup(&self, local_session_id: SessionId) -> Option<Self::SessionRef> {
self.session.lock().as_ref().and_then(|s| {
if s.id == local_session_id {
Some(s.clone())
} else {
None
}
})
}
fn accept_new_session(&self, _: &[u8], _: &[u8]) -> Option<(SessionId, Secret<64>, Self::AssociatedObject)> {
loop {
let mut new_id = self.session_id_counter.lock();
*new_id += 1;
return Some((SessionId::new_from_u64(*new_id).unwrap(), self.psk.clone(), 0));
}
}
}
#[allow(unused_variables)]
#[test] #[test]
fn alice_bob() { fn establish_session() {
let psk: Secret<64> = Secret::default(); let mut psk: Secret<64> = Secret::default();
let mut a_buffer = [0_u8; 1500]; random::fill_bytes_secure(&mut psk.0);
let mut b_buffer = [0_u8; 1500]; let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob"));
let alice_static_keypair = P384KeyPair::generate(); let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice"));
let bob_static_keypair = P384KeyPair::generate(); let rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new());
let outgoing_obfuscator_to_alice = Obfuscator::new(alice_static_keypair.public_key_bytes()); let mut data_buf = [0_u8; 4096];
let outgoing_obfuscator_to_bob = Obfuscator::new(bob_static_keypair.public_key_bytes());
let mut from_alice: Vec<Vec<u8>> = Vec::new(); //println!("zssp: size of session (bytes): {}", std::mem::size_of::<Session<Box<TestHost>>>());
let mut from_bob: Vec<Vec<u8>> = Vec::new();
// Session TO Bob, on Alice's side. let _ = alice_host.session.lock().insert(Arc::new(
let (alice, packet) = Session::new( Session::new(
&mut a_buffer, &alice_host,
|data| bob_host.queue.lock().push_front(data.to_vec()),
SessionId::new_random(), SessionId::new_random(),
alice_static_keypair.public_key_bytes(), bob_host.local_s.public_key_bytes(),
&alice_static_keypair, &[],
bob_static_keypair.public_key_bytes(),
bob_static_keypair.public_key(),
&psk, &psk,
0, 1,
0, 1280,
1,
true, true,
) )
.unwrap(); .unwrap(),
let alice = Rc::new(alice); ));
from_alice.push(packet.to_vec());
// Session FROM Alice, on Bob's side.
let mut bob: Option<Rc<Session<u32>>> = None;
let mut ts = 0;
for _ in 0..256 { for _ in 0..256 {
while !from_alice.is_empty() || !from_bob.is_empty() { for host in [&alice_host, &bob_host] {
if let Some(packet) = from_alice.pop() { let send_to_other = |data: &mut [u8]| {
let r = Session::receive( if std::ptr::eq(host, &alice_host) {
packet.as_slice(), bob_host.queue.lock().push_front(data.to_vec());
&mut b_buffer,
&bob_static_keypair,
&outgoing_obfuscator_to_bob,
|p: &[u8; P384_PUBLIC_KEY_SIZE]| P384PublicKey::from_bytes(p),
|sid| {
println!("[noise] [bob] session ID: {}", u64::from(sid));
if let Some(bob) = bob.as_ref() {
if sid == bob.id {
Some(bob.clone())
} else { } else {
None alice_host.queue.lock().push_front(data.to_vec());
} }
} else { };
None
} loop {
}, if let Some(qi) = host.queue.lock().pop_back() {
|_: &[u8; P384_PUBLIC_KEY_SIZE]| { let qi_len = qi.len();
if bob.is_none() { ts += 1;
Some((SessionId::new_random(), psk.clone(), 0)) let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, true, ts);
} else { if r.is_ok() {
panic!("[noise] [bob] Bob received a second new session request from Alice"); let r = r.unwrap();
}
},
0,
true,
);
if let Ok(r) = r {
match r { match r {
ReceiveResult::OkData(data, counter) => {
println!("[noise] [bob] DATA len=={} counter=={}", data.len(), counter);
}
ReceiveResult::OkSendReply(p) => {
println!("[noise] [bob] OK (reply: {} bytes)", p.len());
from_bob.push(p.to_vec());
}
ReceiveResult::OkNewSession(ns, p) => {
if bob.is_some() {
panic!("[noise] [bob] attempt to create new session on Bob's side when he already has one");
}
let id: u64 = ns.id.into();
let _ = bob.replace(Rc::new(ns));
from_bob.push(p.to_vec());
println!("[noise] [bob] NEW SESSION {}", id);
}
ReceiveResult::Ok => { ReceiveResult::Ok => {
println!("[noise] [bob] OK"); println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len);
} }
ReceiveResult::Duplicate => { ReceiveResult::OkData(data) => {
println!("[noise] [bob] duplicate packet"); println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len());
}
ReceiveResult::OkNewSession(new_session) => {
println!("zssp: {} => {} ({}): OkNewSession ({})", host.other_name, host.this_name, qi_len, u64::from(new_session.id));
let mut hs = host.session.lock();
assert!(hs.is_none());
let _ = hs.insert(Arc::new(new_session));
} }
ReceiveResult::Ignored => { ReceiveResult::Ignored => {
println!("[noise] [bob] ignored packet"); println!("zssp: {} => {} ({}): Ignored", host.other_name, host.this_name, qi_len);
} }
} }
} else { } else {
println!("ERROR (bob): {}", r.err().unwrap().to_string()); println!("zssp: {} => {}: error: {}", host.other_name, host.this_name, r.err().unwrap().to_string());
panic!();
}
}
if let Some(packet) = from_bob.pop() {
let r = Session::receive(
packet.as_slice(),
&mut b_buffer,
&alice_static_keypair,
&outgoing_obfuscator_to_alice,
|p: &[u8; P384_PUBLIC_KEY_SIZE]| P384PublicKey::from_bytes(p),
|sid| {
println!("[noise] [alice] session ID: {}", u64::from(sid));
if sid == alice.id {
Some(alice.clone())
} else {
panic!("[noise] [alice] received from Bob addressed to unknown session ID, not Alice");
}
},
|_: &[u8; P384_PUBLIC_KEY_SIZE]| {
panic!("[noise] [alice] Alice received an unexpected new session request from Bob");
},
0,
true,
);
if let Ok(r) = r {
match r {
ReceiveResult::OkData(data, counter) => {
println!("[noise] [alice] DATA len=={} counter=={}", data.len(), counter);
}
ReceiveResult::OkSendReply(p) => {
println!("[noise] [alice] OK (reply: {} bytes)", p.len());
from_alice.push(p.to_vec());
}
ReceiveResult::OkNewSession(_, _) => {
panic!("[noise] [alice] attempt to create new session on Alice's side; Bob should not initiate");
}
ReceiveResult::Ok => {
println!("[noise] [alice] OK");
}
ReceiveResult::Duplicate => {
println!("[noise] [alice] duplicate packet");
}
ReceiveResult::Ignored => {
println!("[noise] [alice] ignored packet");
}
} }
} else { } else {
println!("ERROR (alice): {}", r.err().unwrap().to_string()); break;
panic!(); }
} }
}
}
if (random::next_u32_secure() & 1) == 0 {
from_alice.push(alice.send(&mut a_buffer, &[0_u8; 16]).unwrap().to_vec());
} else if bob.is_some() {
from_bob.push(bob.as_ref().unwrap().send(&mut b_buffer, &[0_u8; 16]).unwrap().to_vec());
} }
} }
} }
} }
*/

View file

@ -11,16 +11,16 @@ use crate::arrayvec::ArrayVec;
pub struct GatherArray<T, const C: usize> { pub struct GatherArray<T, const C: usize> {
a: [MaybeUninit<T>; C], a: [MaybeUninit<T>; C],
have_bits: u64, have_bits: u64,
have_count: u32, have_count: u8,
goal: u32, goal: u8,
} }
impl<T, const C: usize> GatherArray<T, C> { impl<T, const C: usize> GatherArray<T, C> {
/// Create a new gather array, which must be initialized prior to use. /// Create a new gather array, which must be initialized prior to use.
#[inline(always)] #[inline(always)]
pub fn new(goal: u32) -> Self { pub fn new(goal: u8) -> Self {
assert!(C <= 64); assert!(C <= 64);
assert!(goal <= (C as u32)); assert!(goal <= (C as u8));
assert_eq!(size_of::<[T; C]>(), size_of::<[MaybeUninit<T>; C]>()); assert_eq!(size_of::<[T; C]>(), size_of::<[MaybeUninit<T>; C]>());
Self { Self {
a: unsafe { MaybeUninit::uninit().assume_init() }, a: unsafe { MaybeUninit::uninit().assume_init() },
@ -32,10 +32,10 @@ impl<T, const C: usize> GatherArray<T, C> {
/// Add an item to the array if we don't have this index anymore, returning complete array if all parts are here. /// Add an item to the array if we don't have this index anymore, returning complete array if all parts are here.
#[inline(always)] #[inline(always)]
pub fn add(&mut self, index: u32, value: T) -> Option<ArrayVec<T, C>> { pub fn add(&mut self, index: u8, value: T) -> Option<ArrayVec<T, C>> {
if index < self.goal { if index < self.goal {
let mut have = self.have_bits; let mut have = self.have_bits;
let got = 1u64.wrapping_shl(index); let got = 1u64.wrapping_shl(index as u32);
if (have & got) == 0 { if (have & got) == 0 {
have |= got; have |= got;
self.have_bits = have; self.have_bits = have;
@ -64,7 +64,7 @@ impl<T, const C: usize> Drop for GatherArray<T, C> {
fn drop(&mut self) { fn drop(&mut self) {
let have = self.have_bits; let have = self.have_bits;
for i in 0..self.goal { for i in 0..self.goal {
if (have & 1u64.wrapping_shl(i)) != 0 { if (have & 1u64.wrapping_shl(i as u32)) != 0 {
unsafe { self.a.get_unchecked_mut(i as usize).assume_init_drop() }; unsafe { self.a.get_unchecked_mut(i as usize).assume_init_drop() };
} }
} }
@ -78,8 +78,8 @@ mod tests {
#[test] #[test]
fn gather_array() { fn gather_array() {
for goal in 2..64 { for goal in 2u8..64u8 {
let mut m = GatherArray::<u32, 64>::new(goal); let mut m = GatherArray::<u8, 64>::new(goal);
for x in 0..(goal - 1) { for x in 0..(goal - 1) {
assert!(m.add(x, x).is_none()); assert!(m.add(x, x).is_none());
} }

View file

@ -1,5 +1,6 @@
pub mod arrayvec; pub mod arrayvec;
pub mod gatherarray; pub mod gatherarray;
pub mod hex; pub mod hex;
pub mod memory;
pub mod ringbuffermap; pub mod ringbuffermap;
pub mod varint; pub mod varint;

143
utils/src/memory.rs Normal file
View file

@ -0,0 +1,143 @@
// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md.
#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
#[allow(unused)]
mod fast_int_memory_access {
#[inline(always)]
pub fn u64_from_le_bytes(b: &[u8]) -> u64 {
assert!(b.len() >= 8);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn u32_from_le_bytes(b: &[u8]) -> u32 {
assert!(b.len() >= 4);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn u16_from_le_bytes(b: &[u8]) -> u16 {
assert!(b.len() >= 2);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn i64_from_le_bytes(b: &[u8]) -> i64 {
assert!(b.len() >= 8);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn i32_from_le_bytes(b: &[u8]) -> i32 {
assert!(b.len() >= 4);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn i16_from_le_bytes(b: &[u8]) -> i16 {
assert!(b.len() >= 2);
unsafe { *b.as_ptr().cast() }
}
#[inline(always)]
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
assert!(b.len() >= 8);
unsafe { *b.as_ptr().cast::<u64>() }.swap_bytes()
}
#[inline(always)]
pub fn u32_from_be_bytes(b: &[u8]) -> u32 {
assert!(b.len() >= 4);
unsafe { *b.as_ptr().cast::<u32>() }.swap_bytes()
}
#[inline(always)]
pub fn u16_from_be_bytes(b: &[u8]) -> u16 {
assert!(b.len() >= 2);
unsafe { *b.as_ptr().cast::<u16>() }.swap_bytes()
}
#[inline(always)]
pub fn i64_from_be_bytes(b: &[u8]) -> i64 {
assert!(b.len() >= 8);
unsafe { *b.as_ptr().cast::<i64>() }.swap_bytes()
}
#[inline(always)]
pub fn i32_from_be_bytes(b: &[u8]) -> i32 {
assert!(b.len() >= 4);
unsafe { *b.as_ptr().cast::<i32>() }.swap_bytes()
}
#[inline(always)]
pub fn i16_from_be_bytes(b: &[u8]) -> i16 {
assert!(b.len() >= 2);
unsafe { *b.as_ptr().cast::<i16>() }.swap_bytes()
}
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
#[allow(unused)]
mod fast_int_memory_access {
#[inline(always)]
pub fn u64_from_le_bytes(b: &[u8]) -> u64 {
u64::from_le_bytes(b[..8].try_into().unwrap())
}
#[inline(always)]
pub fn u32_from_le_bytes(b: &[u8]) -> u32 {
u32::from_le_bytes(b[..4].try_into().unwrap())
}
#[inline(always)]
pub fn u16_from_le_bytes(b: &[u8]) -> u16 {
u16::from_le_bytes(b[..2].try_into().unwrap())
}
#[inline(always)]
pub fn i64_from_le_bytes(b: &[u8]) -> i64 {
i64::from_le_bytes(b[..8].try_into().unwrap())
}
#[inline(always)]
pub fn i32_from_le_bytes(b: &[u8]) -> i32 {
i32::from_le_bytes(b[..4].try_into().unwrap())
}
#[inline(always)]
pub fn i16_from_le_bytes(b: &[u8]) -> i16 {
i16::from_le_bytes(b[..2].try_into().unwrap())
}
#[inline(always)]
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
u64::from_be_bytes(b[..8].try_into().unwrap())
}
#[inline(always)]
pub fn u32_from_be_bytes(b: &[u8]) -> u32 {
u32::from_be_bytes(b[..4].try_into().unwrap())
}
#[inline(always)]
pub fn u16_from_be_bytes(b: &[u8]) -> u16 {
u16::from_be_bytes(b[..2].try_into().unwrap())
}
#[inline(always)]
pub fn i64_from_be_bytes(b: &[u8]) -> i64 {
i64::from_be_bytes(b[..8].try_into().unwrap())
}
#[inline(always)]
pub fn i32_from_be_bytes(b: &[u8]) -> i32 {
i32::from_be_bytes(b[..4].try_into().unwrap())
}
#[inline(always)]
pub fn i16_from_be_bytes(b: &[u8]) -> i16 {
i16::from_be_bytes(b[..2].try_into().unwrap())
}
}
pub use fast_int_memory_access::*;

View file

@ -106,13 +106,15 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
#[inline] #[inline]
pub fn new(salt: u32) -> Self { pub fn new(salt: u32) -> Self {
Self { Self {
entries: std::array::from_fn(|_| Entry::<K, V> { entries: {
key: MaybeUninit::uninit(), let mut entries: [Entry<K, V>; C] = unsafe { MaybeUninit::uninit().assume_init() };
value: MaybeUninit::uninit(), for e in entries.iter_mut() {
bucket: -1, e.bucket = -1;
next: -1, e.next = -1;
prev: -1, e.prev = -1;
}), }
entries
},
buckets: [-1; B], buckets: [-1; B],
entry_ptr: 0, entry_ptr: 0,
salt, salt,