mirror of
https://github.com/zerotier/ZeroTierOne.git
synced 2025-06-08 05:23:44 +02:00
Session works again, and some optimization.
This commit is contained in:
parent
3770fcdc83
commit
06573c1ea8
5 changed files with 374 additions and 261 deletions
|
@ -12,6 +12,7 @@ use crate::random;
|
|||
use crate::secret::Secret;
|
||||
|
||||
use zerotier_utils::gatherarray::GatherArray;
|
||||
use zerotier_utils::memory;
|
||||
use zerotier_utils::ringbuffermap::RingBufferMap;
|
||||
use zerotier_utils::varint;
|
||||
|
||||
|
@ -41,6 +42,7 @@ const OFFER_RATE_LIMIT_MS: i64 = 2000;
|
|||
/// Version 0: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication)
|
||||
const SESSION_PROTOCOL_VERSION: u8 = 0x00;
|
||||
|
||||
// Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use
|
||||
const PACKET_TYPE_DATA: u8 = 0;
|
||||
const PACKET_TYPE_NOP: u8 = 1;
|
||||
const PACKET_TYPE_KEY_OFFER: u8 = 2; // "alice"
|
||||
|
@ -56,6 +58,10 @@ const AES_GCM_TAG_SIZE: usize = 16;
|
|||
const HMAC_SIZE: usize = 48;
|
||||
const SESSION_ID_SIZE: usize = 6;
|
||||
|
||||
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
|
||||
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
|
||||
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';
|
||||
|
||||
/// Aribitrary starting value for key derivation chain.
|
||||
///
|
||||
/// It doesn't matter very much what this is, but it's good for it to be unique.
|
||||
|
@ -66,10 +72,6 @@ const KEY_DERIVATION_CHAIN_STARTING_SALT: [u8; 64] = [
|
|||
0xb4, 0x32, 0x85, 0xaf, 0x7f, 0x0d, 0xa9, 0x6c, 0x01, 0xfb, 0x72, 0x46, 0xc0, 0x09, 0x58, 0xb8, 0xe0, 0xa8, 0xcf, 0xb1, 0x58, 0x04, 0x6e, 0x32, 0xba, 0xa8, 0xb8, 0xf9, 0x0a, 0xa4, 0xbf, 0x36,
|
||||
];
|
||||
|
||||
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
|
||||
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
|
||||
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';
|
||||
|
||||
pub enum Error {
|
||||
/// The packet was addressed to an unrecognized local session
|
||||
UnknownLocalSessionId(SessionId),
|
||||
|
@ -147,16 +149,17 @@ pub enum ReceiveResult<'a, H: Host> {
|
|||
/// Packet is valid, no action needs to be taken.
|
||||
Ok,
|
||||
|
||||
/// Packet is valid and contained a data payload.
|
||||
OkData(&'a [u8]),
|
||||
/// Packet is valid and a data payload was decoded and authenticated.
|
||||
///
|
||||
/// The returned reference is to the filled parts of the data buffer supplied to receive.
|
||||
OkData(&'a mut [u8]),
|
||||
|
||||
/// Packet is valid and a new session was created, also includes a reply to be sent back.
|
||||
/// Packet is valid and a new session was created.
|
||||
///
|
||||
/// The session will have already been gated by the accept_new_session() method in the Host trait.
|
||||
OkNewSession(Session<H>),
|
||||
|
||||
/// Packet appears valid but was ignored as a duplicate.
|
||||
Duplicate,
|
||||
|
||||
/// Packet apperas valid but was ignored for another reason.
|
||||
/// Packet apperas valid but was ignored e.g. as a duplicate.
|
||||
Ignored,
|
||||
}
|
||||
|
||||
|
@ -167,20 +170,15 @@ pub struct SessionId(NonZeroU64);
|
|||
impl SessionId {
|
||||
pub const MAX_BIT_MASK: u64 = 0xffffffffffff;
|
||||
|
||||
pub fn new_from_bytes(b: &[u8]) -> Option<SessionId> {
|
||||
if b.len() >= 6 {
|
||||
let value = (u32::from_le_bytes(b[..4].try_into().unwrap()) as u64) | (u16::from_le_bytes(b[4..6].try_into().unwrap()) as u64).wrapping_shl(32);
|
||||
if value > 0 && value <= Self::MAX_BIT_MASK {
|
||||
return Some(Self(NonZeroU64::new(value).unwrap()));
|
||||
}
|
||||
}
|
||||
return None;
|
||||
#[inline(always)]
|
||||
pub fn new_from_u64(i: u64) -> Option<SessionId> {
|
||||
debug_assert!(i <= Self::MAX_BIT_MASK);
|
||||
NonZeroU64::new(i).map(|i| Self(i))
|
||||
}
|
||||
|
||||
pub fn new_from_reader<R: Read>(r: &mut R) -> std::io::Result<Option<SessionId>> {
|
||||
let mut tmp = [0_u8; SESSION_ID_SIZE];
|
||||
r.read_exact(&mut tmp)?;
|
||||
Ok(Self::new_from_bytes(&tmp))
|
||||
let mut tmp = [0_u8; 8];
|
||||
r.read_exact(&mut tmp[..SESSION_ID_SIZE]).map(|_| NonZeroU64::new(u64::from_le_bytes(tmp)).map(|i| Self(i)))
|
||||
}
|
||||
|
||||
pub fn new_random() -> Self {
|
||||
|
@ -193,19 +191,6 @@ impl SessionId {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<u64> for SessionId {
|
||||
type Error = self::Error;
|
||||
|
||||
#[inline(always)]
|
||||
fn try_from(value: u64) -> Result<Self, Self::Error> {
|
||||
if value > 0 && value <= Self::MAX_BIT_MASK {
|
||||
Ok(Self(NonZeroU64::new(value).unwrap()))
|
||||
} else {
|
||||
Err(Error::InvalidParameter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SessionId> for u64 {
|
||||
#[inline(always)]
|
||||
fn from(sid: SessionId) -> Self {
|
||||
|
@ -213,21 +198,31 @@ impl From<SessionId> for u64 {
|
|||
}
|
||||
}
|
||||
|
||||
/// Trait to implement to integrate the session into an application.
|
||||
pub trait Host: Sized {
|
||||
type AssociatedObject: Sized;
|
||||
/// Arbitrary object that can be associated with sessions.
|
||||
type AssociatedObject;
|
||||
|
||||
/// Arbitrary object that dereferences to the session, such as Arc<Session<Self>>.
|
||||
type SessionRef: Deref<Target = Session<Self>>;
|
||||
|
||||
/// A buffer containing data read from the network that can be cached.
|
||||
///
|
||||
/// This can be e.g. a pooled buffer that automatically returns itself to the pool when dropped.
|
||||
type IncomingPacketBuffer: AsRef<[u8]>;
|
||||
|
||||
/// Get a reference to this host's static public key blob.
|
||||
///
|
||||
/// This must contain a NIST P-384 public key but can contain other information.
|
||||
fn get_local_s_public(&self) -> &[u8];
|
||||
|
||||
/// Get SHA384(this host's static public key blob)
|
||||
/// Get SHA384(this host's static public key blob), included here so we don't have to calculate it each time.
|
||||
fn get_local_s_public_hash(&self) -> &[u8; 48];
|
||||
|
||||
/// Get a reference to this hosts' static public key's NIST P-384 secret key pair
|
||||
fn get_local_s_keypair_p384(&self) -> &P384KeyPair;
|
||||
|
||||
/// Extract a NIST P-384 ECC public key from a static public key blob.
|
||||
/// Extract the NIST P-384 ECC public key component from a static public key blob or return None on failure.
|
||||
fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey>;
|
||||
|
||||
/// Look up a local session by local ID.
|
||||
|
@ -246,23 +241,23 @@ pub struct Session<H: Host> {
|
|||
pub associated_object: H::AssociatedObject,
|
||||
|
||||
send_counter: Counter,
|
||||
psk: Secret<64>,
|
||||
ss: Secret<48>,
|
||||
state: RwLock<MutableState>,
|
||||
remote_s_public_hash: [u8; 48],
|
||||
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE],
|
||||
defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, MAX_FRAGMENTS>, 64, 32>>,
|
||||
psk: Secret<64>, // Arbitrary PSK provided by external code
|
||||
ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer
|
||||
state: RwLock<MutableState>, // Mutable parts of state (other than defrag buffers)
|
||||
remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob)
|
||||
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
|
||||
defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, MAX_FRAGMENTS>, 16, 4>>,
|
||||
}
|
||||
|
||||
struct MutableState {
|
||||
remote_session_id: Option<SessionId>,
|
||||
keys: [Option<SessionKey>; 2], // current, next
|
||||
keys: [Option<SessionKey>; 2], // current, next (promoted to current on successful decrypt)
|
||||
offer: Option<EphemeralOffer>,
|
||||
}
|
||||
|
||||
/// State information to associate with receiving contexts such as sockets or remote paths/endpoints.
|
||||
pub struct ReceiveContext<H: Host> {
|
||||
initial_offer_defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, KEY_EXCHANGE_MAX_FRAGMENTS>, 1024, 256>>,
|
||||
initial_offer_defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, KEY_EXCHANGE_MAX_FRAGMENTS>, 1024, 128>>,
|
||||
}
|
||||
|
||||
impl<H: Host> Session<H> {
|
||||
|
@ -362,26 +357,25 @@ impl<H: Host> ReceiveContext<H> {
|
|||
current_time: i64,
|
||||
) -> Result<ReceiveResult<'a, H>, Error> {
|
||||
let incoming_packet = incoming_packet_buf.as_ref();
|
||||
|
||||
if incoming_packet.len() < MIN_PACKET_SIZE {
|
||||
unlikely_branch();
|
||||
return Err(Error::InvalidPacket);
|
||||
}
|
||||
|
||||
let type_and_frag_info = u16::from_le_bytes(incoming_packet[0..2].try_into().unwrap());
|
||||
let local_session_id = SessionId::new_from_bytes(&incoming_packet[2..8]);
|
||||
let counter = u32::from_le_bytes(incoming_packet[8..12].try_into().unwrap());
|
||||
let packet_type = (type_and_frag_info as u8) & 15;
|
||||
let fragment_count = type_and_frag_info.wrapping_shr(4) & 63;
|
||||
let fragment_no = type_and_frag_info.wrapping_shr(10); // & 63 not needed
|
||||
let header_0_8 = memory::u64_from_le_bytes(incoming_packet); // type, frag info, session ID
|
||||
let counter = memory::u32_from_le_bytes(&incoming_packet[8..]);
|
||||
let local_session_id = SessionId::new_from_u64(header_0_8.wrapping_shr(16));
|
||||
let packet_type = (header_0_8 as u8) & 15;
|
||||
let fragment_count = ((header_0_8.wrapping_shr(4) as u8) & 63).wrapping_add(1);
|
||||
let fragment_no = (header_0_8.wrapping_shr(10) as u8) & 63;
|
||||
|
||||
if fragment_count > 1 {
|
||||
if let Some(local_session_id) = local_session_id {
|
||||
if fragment_count < (MAX_FRAGMENTS as u16) && fragment_no < fragment_count {
|
||||
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
|
||||
if let Some(session) = host.session_lookup(local_session_id) {
|
||||
let mut defrag = session.defrag.lock();
|
||||
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32));
|
||||
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) {
|
||||
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
|
||||
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
|
||||
drop(defrag); // release lock
|
||||
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time);
|
||||
}
|
||||
|
@ -394,10 +388,10 @@ impl<H: Host> ReceiveContext<H> {
|
|||
return Err(Error::InvalidPacket);
|
||||
}
|
||||
} else {
|
||||
if fragment_count < (KEY_EXCHANGE_MAX_FRAGMENTS as u16) && fragment_no < fragment_count {
|
||||
if fragment_count <= (KEY_EXCHANGE_MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
|
||||
let mut defrag = self.initial_offer_defrag.lock();
|
||||
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32));
|
||||
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) {
|
||||
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
|
||||
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
|
||||
drop(defrag); // release lock
|
||||
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
|
||||
}
|
||||
|
@ -444,22 +438,19 @@ impl<H: Host> ReceiveContext<H> {
|
|||
let state = session.state.read();
|
||||
for ki in 0..2 {
|
||||
if let Some(key) = state.keys[ki].as_ref() {
|
||||
let head = fragments.first().unwrap().as_ref();
|
||||
debug_assert!(head.len() >= MIN_PACKET_SIZE);
|
||||
let tail = fragments.last().unwrap().as_ref();
|
||||
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
|
||||
unlikely_branch();
|
||||
return Err(Error::InvalidPacket);
|
||||
}
|
||||
|
||||
let mut c = key.get_receive_cipher();
|
||||
c.init(&get_aes_gcm_nonce(head));
|
||||
c.init(&get_aes_gcm_nonce(fragments.first().unwrap().as_ref()));
|
||||
|
||||
let mut data_len = head.len() - HEADER_SIZE;
|
||||
if data_len > data_buf.len() {
|
||||
unlikely_branch();
|
||||
key.return_receive_cipher(c);
|
||||
return Err(Error::DataBufferTooSmall);
|
||||
}
|
||||
c.crypt(&head[HEADER_SIZE..], &mut data_buf[..data_len]);
|
||||
let mut data_len = 0;
|
||||
|
||||
for fi in 1..(fragments.len() - 1) {
|
||||
let f = fragments[fi].as_ref();
|
||||
for f in fragments[..(fragments.len() - 1)].iter() {
|
||||
let f = f.as_ref();
|
||||
debug_assert!(f.len() >= HEADER_SIZE);
|
||||
let current_frag_data_start = data_len;
|
||||
data_len += f.len() - HEADER_SIZE;
|
||||
|
@ -471,21 +462,14 @@ impl<H: Host> ReceiveContext<H> {
|
|||
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
|
||||
}
|
||||
|
||||
let tail = fragments.last().unwrap().as_ref();
|
||||
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
|
||||
unlikely_branch();
|
||||
key.return_receive_cipher(c);
|
||||
return Err(Error::InvalidPacket);
|
||||
}
|
||||
let tail_data_len = tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
|
||||
let current_frag_data_start = data_len;
|
||||
data_len += tail_data_len;
|
||||
data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
|
||||
if data_len > data_buf.len() {
|
||||
unlikely_branch();
|
||||
key.return_receive_cipher(c);
|
||||
return Err(Error::DataBufferTooSmall);
|
||||
}
|
||||
c.crypt(&tail[HEADER_SIZE..tail_data_len], &mut data_buf[current_frag_data_start..data_len]);
|
||||
c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]);
|
||||
|
||||
let tag = c.finish();
|
||||
key.return_receive_cipher(c);
|
||||
|
@ -500,7 +484,7 @@ impl<H: Host> ReceiveContext<H> {
|
|||
}
|
||||
|
||||
if packet_type == PACKET_TYPE_DATA {
|
||||
return Ok(ReceiveResult::OkData(&data_buf[..data_len]));
|
||||
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
|
||||
} else {
|
||||
return Ok(ReceiveResult::Ok);
|
||||
}
|
||||
|
@ -540,8 +524,6 @@ impl<H: Host> ReceiveContext<H> {
|
|||
return Err(Error::UnknownProtocolVersion);
|
||||
}
|
||||
|
||||
let local_s_keypair_p384 = host.get_local_s_keypair_p384();
|
||||
|
||||
match packet_type {
|
||||
PACKET_TYPE_KEY_OFFER => {
|
||||
// alice (remote) -> bob (local)
|
||||
|
@ -565,7 +547,7 @@ impl<H: Host> ReceiveContext<H> {
|
|||
}
|
||||
|
||||
let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
|
||||
.and_then(|pk| local_s_keypair_p384.agree(&pk).map(move |s| (pk, s)))
|
||||
.and_then(|pk| host.get_local_s_keypair_p384().agree(&pk).map(move |s| (pk, s)))
|
||||
.ok_or(Error::FailedAuthentication)?;
|
||||
|
||||
let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes()));
|
||||
|
@ -588,7 +570,7 @@ impl<H: Host> ReceiveContext<H> {
|
|||
}
|
||||
|
||||
let alice_s_public_p384 = H::extract_p384_static(alice_s_public).ok_or(Error::InvalidPacket)?;
|
||||
let ss = local_s_keypair_p384.agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?;
|
||||
let ss = host.get_local_s_keypair_p384().agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?;
|
||||
|
||||
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
|
||||
|
||||
|
@ -668,7 +650,7 @@ impl<H: Host> ReceiveContext<H> {
|
|||
}
|
||||
(MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len()
|
||||
};
|
||||
let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_OFFER, session.id.into(), reply_counter);
|
||||
let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter);
|
||||
reply_buf[..HEADER_SIZE].copy_from_slice(&header);
|
||||
|
||||
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true);
|
||||
|
@ -715,7 +697,7 @@ impl<H: Host> ReceiveContext<H> {
|
|||
let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
|
||||
.and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s)))
|
||||
.ok_or(Error::FailedAuthentication)?;
|
||||
let se0 = local_s_keypair_p384.agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?;
|
||||
let se0 = host.get_local_s_keypair_p384().agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?;
|
||||
|
||||
let key = Secret(hmac_sha512(
|
||||
session.psk.as_bytes(),
|
||||
|
@ -956,15 +938,16 @@ impl EphemeralOffer {
|
|||
|
||||
#[inline(always)]
|
||||
fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) -> [u8; 12] {
|
||||
let fragment_count = (packet_len / mtu) + (((packet_len % mtu) != 0) as usize);
|
||||
let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize;
|
||||
debug_assert!(mtu >= MIN_MTU);
|
||||
debug_assert!(packet_len >= HEADER_SIZE);
|
||||
debug_assert!(fragment_count <= MAX_FRAGMENTS);
|
||||
debug_assert!(fragment_count > 0);
|
||||
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
|
||||
debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits
|
||||
|
||||
// Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter
|
||||
let mut header = (fragment_count.wrapping_shl(4) | (packet_type as usize)) as u128;
|
||||
let mut header = ((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u128;
|
||||
header |= recipient_session_id.wrapping_shl(16) as u128;
|
||||
header |= (counter.to_u32() as u128).wrapping_shl(64);
|
||||
header.to_le_bytes()[..HEADER_SIZE].try_into().unwrap()
|
||||
|
@ -980,7 +963,8 @@ fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(mut send: SendFunctio
|
|||
if fragment_end < packet_len {
|
||||
fragment_start = fragment_end - HEADER_SIZE;
|
||||
fragment_end = (fragment_start + mtu).min(packet_len);
|
||||
header[1] += 0x04; // increment fragment number, at bit 2 in byte 1 since type/fragment u16 is little-endian
|
||||
debug_assert!(header[1].wrapping_shr(2) < 63);
|
||||
header[1] += 0x04; // increment fragment number in least significant 6 bits of byte 1
|
||||
packet[fragment_start..(fragment_start + HEADER_SIZE)].copy_from_slice(header);
|
||||
} else {
|
||||
debug_assert_eq!(fragment_end, packet_len);
|
||||
|
@ -1104,171 +1088,154 @@ fn get_aes_gcm_nonce(packet: &[u8]) -> [u8; 16] {
|
|||
tmp
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {}
|
||||
|
||||
/*
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::rc::Rc;
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::LinkedList;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn alice_bob() {
|
||||
let psk: Secret<64> = Secret::default();
|
||||
let mut a_buffer = [0_u8; 1500];
|
||||
let mut b_buffer = [0_u8; 1500];
|
||||
let alice_static_keypair = P384KeyPair::generate();
|
||||
let bob_static_keypair = P384KeyPair::generate();
|
||||
let outgoing_obfuscator_to_alice = Obfuscator::new(alice_static_keypair.public_key_bytes());
|
||||
let outgoing_obfuscator_to_bob = Obfuscator::new(bob_static_keypair.public_key_bytes());
|
||||
struct TestHost {
|
||||
local_s: P384KeyPair,
|
||||
local_s_hash: [u8; 48],
|
||||
psk: Secret<64>,
|
||||
session: Mutex<Option<Arc<Session<Box<TestHost>>>>>,
|
||||
session_id_counter: Mutex<u64>,
|
||||
pub queue: Mutex<LinkedList<Vec<u8>>>,
|
||||
pub this_name: &'static str,
|
||||
pub other_name: &'static str,
|
||||
}
|
||||
|
||||
let mut from_alice: Vec<Vec<u8>> = Vec::new();
|
||||
let mut from_bob: Vec<Vec<u8>> = Vec::new();
|
||||
|
||||
// Session TO Bob, on Alice's side.
|
||||
let (alice, packet) = Session::new(
|
||||
&mut a_buffer,
|
||||
SessionId::new_random(),
|
||||
alice_static_keypair.public_key_bytes(),
|
||||
&alice_static_keypair,
|
||||
bob_static_keypair.public_key_bytes(),
|
||||
bob_static_keypair.public_key(),
|
||||
&psk,
|
||||
0,
|
||||
0,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
let alice = Rc::new(alice);
|
||||
from_alice.push(packet.to_vec());
|
||||
|
||||
// Session FROM Alice, on Bob's side.
|
||||
let mut bob: Option<Rc<Session<u32>>> = None;
|
||||
|
||||
for _ in 0..256 {
|
||||
while !from_alice.is_empty() || !from_bob.is_empty() {
|
||||
if let Some(packet) = from_alice.pop() {
|
||||
let r = Session::receive(
|
||||
packet.as_slice(),
|
||||
&mut b_buffer,
|
||||
&bob_static_keypair,
|
||||
&outgoing_obfuscator_to_bob,
|
||||
|p: &[u8; P384_PUBLIC_KEY_SIZE]| P384PublicKey::from_bytes(p),
|
||||
|sid| {
|
||||
println!("[noise] [bob] session ID: {}", u64::from(sid));
|
||||
if let Some(bob) = bob.as_ref() {
|
||||
if sid == bob.id {
|
||||
Some(bob.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
},
|
||||
|_: &[u8; P384_PUBLIC_KEY_SIZE]| {
|
||||
if bob.is_none() {
|
||||
Some((SessionId::new_random(), psk.clone(), 0))
|
||||
} else {
|
||||
panic!("[noise] [bob] Bob received a second new session request from Alice");
|
||||
}
|
||||
},
|
||||
0,
|
||||
true,
|
||||
);
|
||||
if let Ok(r) = r {
|
||||
match r {
|
||||
ReceiveResult::OkData(data, counter) => {
|
||||
println!("[noise] [bob] DATA len=={} counter=={}", data.len(), counter);
|
||||
}
|
||||
ReceiveResult::OkSendReply(p) => {
|
||||
println!("[noise] [bob] OK (reply: {} bytes)", p.len());
|
||||
from_bob.push(p.to_vec());
|
||||
}
|
||||
ReceiveResult::OkNewSession(ns, p) => {
|
||||
if bob.is_some() {
|
||||
panic!("[noise] [bob] attempt to create new session on Bob's side when he already has one");
|
||||
}
|
||||
let id: u64 = ns.id.into();
|
||||
let _ = bob.replace(Rc::new(ns));
|
||||
from_bob.push(p.to_vec());
|
||||
println!("[noise] [bob] NEW SESSION {}", id);
|
||||
}
|
||||
ReceiveResult::Ok => {
|
||||
println!("[noise] [bob] OK");
|
||||
}
|
||||
ReceiveResult::Duplicate => {
|
||||
println!("[noise] [bob] duplicate packet");
|
||||
}
|
||||
ReceiveResult::Ignored => {
|
||||
println!("[noise] [bob] ignored packet");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("ERROR (bob): {}", r.err().unwrap().to_string());
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(packet) = from_bob.pop() {
|
||||
let r = Session::receive(
|
||||
packet.as_slice(),
|
||||
&mut b_buffer,
|
||||
&alice_static_keypair,
|
||||
&outgoing_obfuscator_to_alice,
|
||||
|p: &[u8; P384_PUBLIC_KEY_SIZE]| P384PublicKey::from_bytes(p),
|
||||
|sid| {
|
||||
println!("[noise] [alice] session ID: {}", u64::from(sid));
|
||||
if sid == alice.id {
|
||||
Some(alice.clone())
|
||||
} else {
|
||||
panic!("[noise] [alice] received from Bob addressed to unknown session ID, not Alice");
|
||||
}
|
||||
},
|
||||
|_: &[u8; P384_PUBLIC_KEY_SIZE]| {
|
||||
panic!("[noise] [alice] Alice received an unexpected new session request from Bob");
|
||||
},
|
||||
0,
|
||||
true,
|
||||
);
|
||||
if let Ok(r) = r {
|
||||
match r {
|
||||
ReceiveResult::OkData(data, counter) => {
|
||||
println!("[noise] [alice] DATA len=={} counter=={}", data.len(), counter);
|
||||
}
|
||||
ReceiveResult::OkSendReply(p) => {
|
||||
println!("[noise] [alice] OK (reply: {} bytes)", p.len());
|
||||
from_alice.push(p.to_vec());
|
||||
}
|
||||
ReceiveResult::OkNewSession(_, _) => {
|
||||
panic!("[noise] [alice] attempt to create new session on Alice's side; Bob should not initiate");
|
||||
}
|
||||
ReceiveResult::Ok => {
|
||||
println!("[noise] [alice] OK");
|
||||
}
|
||||
ReceiveResult::Duplicate => {
|
||||
println!("[noise] [alice] duplicate packet");
|
||||
}
|
||||
ReceiveResult::Ignored => {
|
||||
println!("[noise] [alice] ignored packet");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("ERROR (alice): {}", r.err().unwrap().to_string());
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
impl TestHost {
|
||||
fn new(psk: Secret<64>, this_name: &'static str, other_name: &'static str) -> Self {
|
||||
let local_s = P384KeyPair::generate();
|
||||
let local_s_hash = SHA384::hash(local_s.public_key_bytes());
|
||||
Self {
|
||||
local_s,
|
||||
local_s_hash,
|
||||
psk,
|
||||
session: Mutex::new(None),
|
||||
session_id_counter: Mutex::new(random::next_u64_secure().wrapping_shr(16) | 1),
|
||||
queue: Mutex::new(LinkedList::new()),
|
||||
this_name,
|
||||
other_name,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (random::next_u32_secure() & 1) == 0 {
|
||||
from_alice.push(alice.send(&mut a_buffer, &[0_u8; 16]).unwrap().to_vec());
|
||||
} else if bob.is_some() {
|
||||
from_bob.push(bob.as_ref().unwrap().send(&mut b_buffer, &[0_u8; 16]).unwrap().to_vec());
|
||||
impl Host for Box<TestHost> {
|
||||
type AssociatedObject = u32;
|
||||
type SessionRef = Arc<Session<Box<TestHost>>>;
|
||||
type IncomingPacketBuffer = Vec<u8>;
|
||||
|
||||
fn get_local_s_public(&self) -> &[u8] {
|
||||
self.local_s.public_key_bytes()
|
||||
}
|
||||
|
||||
fn get_local_s_public_hash(&self) -> &[u8; 48] {
|
||||
&self.local_s_hash
|
||||
}
|
||||
|
||||
fn get_local_s_keypair_p384(&self) -> &P384KeyPair {
|
||||
&self.local_s
|
||||
}
|
||||
|
||||
fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey> {
|
||||
P384PublicKey::from_bytes(static_public)
|
||||
}
|
||||
|
||||
fn session_lookup(&self, local_session_id: SessionId) -> Option<Self::SessionRef> {
|
||||
self.session.lock().as_ref().and_then(|s| {
|
||||
if s.id == local_session_id {
|
||||
Some(s.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn accept_new_session(&self, _: &[u8], _: &[u8]) -> Option<(SessionId, Secret<64>, Self::AssociatedObject)> {
|
||||
loop {
|
||||
let mut new_id = self.session_id_counter.lock();
|
||||
*new_id += 1;
|
||||
return Some((SessionId::new_from_u64(*new_id).unwrap(), self.psk.clone(), 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
#[test]
|
||||
fn establish_session() {
|
||||
let mut psk: Secret<64> = Secret::default();
|
||||
random::fill_bytes_secure(&mut psk.0);
|
||||
let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob"));
|
||||
let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice"));
|
||||
let rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new());
|
||||
let mut data_buf = [0_u8; 4096];
|
||||
|
||||
//println!("zssp: size of session (bytes): {}", std::mem::size_of::<Session<Box<TestHost>>>());
|
||||
|
||||
let _ = alice_host.session.lock().insert(Arc::new(
|
||||
Session::new(
|
||||
&alice_host,
|
||||
|data| bob_host.queue.lock().push_front(data.to_vec()),
|
||||
SessionId::new_random(),
|
||||
bob_host.local_s.public_key_bytes(),
|
||||
&[],
|
||||
&psk,
|
||||
1,
|
||||
1280,
|
||||
1,
|
||||
true,
|
||||
)
|
||||
.unwrap(),
|
||||
));
|
||||
|
||||
let mut ts = 0;
|
||||
for _ in 0..256 {
|
||||
for host in [&alice_host, &bob_host] {
|
||||
let send_to_other = |data: &mut [u8]| {
|
||||
if std::ptr::eq(host, &alice_host) {
|
||||
bob_host.queue.lock().push_front(data.to_vec());
|
||||
} else {
|
||||
alice_host.queue.lock().push_front(data.to_vec());
|
||||
}
|
||||
};
|
||||
|
||||
loop {
|
||||
if let Some(qi) = host.queue.lock().pop_back() {
|
||||
let qi_len = qi.len();
|
||||
ts += 1;
|
||||
let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, true, ts);
|
||||
if r.is_ok() {
|
||||
let r = r.unwrap();
|
||||
match r {
|
||||
ReceiveResult::Ok => {
|
||||
println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len);
|
||||
}
|
||||
ReceiveResult::OkData(data) => {
|
||||
println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len());
|
||||
}
|
||||
ReceiveResult::OkNewSession(new_session) => {
|
||||
println!("zssp: {} => {} ({}): OkNewSession ({})", host.other_name, host.this_name, qi_len, u64::from(new_session.id));
|
||||
let mut hs = host.session.lock();
|
||||
assert!(hs.is_none());
|
||||
let _ = hs.insert(Arc::new(new_session));
|
||||
}
|
||||
ReceiveResult::Ignored => {
|
||||
println!("zssp: {} => {} ({}): Ignored", host.other_name, host.this_name, qi_len);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("zssp: {} => {}: error: {}", host.other_name, host.this_name, r.err().unwrap().to_string());
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -11,16 +11,16 @@ use crate::arrayvec::ArrayVec;
|
|||
pub struct GatherArray<T, const C: usize> {
|
||||
a: [MaybeUninit<T>; C],
|
||||
have_bits: u64,
|
||||
have_count: u32,
|
||||
goal: u32,
|
||||
have_count: u8,
|
||||
goal: u8,
|
||||
}
|
||||
|
||||
impl<T, const C: usize> GatherArray<T, C> {
|
||||
/// Create a new gather array, which must be initialized prior to use.
|
||||
#[inline(always)]
|
||||
pub fn new(goal: u32) -> Self {
|
||||
pub fn new(goal: u8) -> Self {
|
||||
assert!(C <= 64);
|
||||
assert!(goal <= (C as u32));
|
||||
assert!(goal <= (C as u8));
|
||||
assert_eq!(size_of::<[T; C]>(), size_of::<[MaybeUninit<T>; C]>());
|
||||
Self {
|
||||
a: unsafe { MaybeUninit::uninit().assume_init() },
|
||||
|
@ -32,10 +32,10 @@ impl<T, const C: usize> GatherArray<T, C> {
|
|||
|
||||
/// Add an item to the array if we don't have this index anymore, returning complete array if all parts are here.
|
||||
#[inline(always)]
|
||||
pub fn add(&mut self, index: u32, value: T) -> Option<ArrayVec<T, C>> {
|
||||
pub fn add(&mut self, index: u8, value: T) -> Option<ArrayVec<T, C>> {
|
||||
if index < self.goal {
|
||||
let mut have = self.have_bits;
|
||||
let got = 1u64.wrapping_shl(index);
|
||||
let got = 1u64.wrapping_shl(index as u32);
|
||||
if (have & got) == 0 {
|
||||
have |= got;
|
||||
self.have_bits = have;
|
||||
|
@ -64,7 +64,7 @@ impl<T, const C: usize> Drop for GatherArray<T, C> {
|
|||
fn drop(&mut self) {
|
||||
let have = self.have_bits;
|
||||
for i in 0..self.goal {
|
||||
if (have & 1u64.wrapping_shl(i)) != 0 {
|
||||
if (have & 1u64.wrapping_shl(i as u32)) != 0 {
|
||||
unsafe { self.a.get_unchecked_mut(i as usize).assume_init_drop() };
|
||||
}
|
||||
}
|
||||
|
@ -78,8 +78,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn gather_array() {
|
||||
for goal in 2..64 {
|
||||
let mut m = GatherArray::<u32, 64>::new(goal);
|
||||
for goal in 2u8..64u8 {
|
||||
let mut m = GatherArray::<u8, 64>::new(goal);
|
||||
for x in 0..(goal - 1) {
|
||||
assert!(m.add(x, x).is_none());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
pub mod arrayvec;
|
||||
pub mod gatherarray;
|
||||
pub mod hex;
|
||||
pub mod memory;
|
||||
pub mod ringbuffermap;
|
||||
pub mod varint;
|
||||
|
|
143
utils/src/memory.rs
Normal file
143
utils/src/memory.rs
Normal file
|
@ -0,0 +1,143 @@
|
|||
// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md.
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
|
||||
#[allow(unused)]
|
||||
mod fast_int_memory_access {
|
||||
#[inline(always)]
|
||||
pub fn u64_from_le_bytes(b: &[u8]) -> u64 {
|
||||
assert!(b.len() >= 8);
|
||||
unsafe { *b.as_ptr().cast() }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u32_from_le_bytes(b: &[u8]) -> u32 {
|
||||
assert!(b.len() >= 4);
|
||||
unsafe { *b.as_ptr().cast() }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u16_from_le_bytes(b: &[u8]) -> u16 {
|
||||
assert!(b.len() >= 2);
|
||||
unsafe { *b.as_ptr().cast() }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i64_from_le_bytes(b: &[u8]) -> i64 {
|
||||
assert!(b.len() >= 8);
|
||||
unsafe { *b.as_ptr().cast() }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i32_from_le_bytes(b: &[u8]) -> i32 {
|
||||
assert!(b.len() >= 4);
|
||||
unsafe { *b.as_ptr().cast() }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i16_from_le_bytes(b: &[u8]) -> i16 {
|
||||
assert!(b.len() >= 2);
|
||||
unsafe { *b.as_ptr().cast() }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
|
||||
assert!(b.len() >= 8);
|
||||
unsafe { *b.as_ptr().cast::<u64>() }.swap_bytes()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u32_from_be_bytes(b: &[u8]) -> u32 {
|
||||
assert!(b.len() >= 4);
|
||||
unsafe { *b.as_ptr().cast::<u32>() }.swap_bytes()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u16_from_be_bytes(b: &[u8]) -> u16 {
|
||||
assert!(b.len() >= 2);
|
||||
unsafe { *b.as_ptr().cast::<u16>() }.swap_bytes()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i64_from_be_bytes(b: &[u8]) -> i64 {
|
||||
assert!(b.len() >= 8);
|
||||
unsafe { *b.as_ptr().cast::<i64>() }.swap_bytes()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i32_from_be_bytes(b: &[u8]) -> i32 {
|
||||
assert!(b.len() >= 4);
|
||||
unsafe { *b.as_ptr().cast::<i32>() }.swap_bytes()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i16_from_be_bytes(b: &[u8]) -> i16 {
|
||||
assert!(b.len() >= 2);
|
||||
unsafe { *b.as_ptr().cast::<i16>() }.swap_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
|
||||
#[allow(unused)]
|
||||
mod fast_int_memory_access {
|
||||
#[inline(always)]
|
||||
pub fn u64_from_le_bytes(b: &[u8]) -> u64 {
|
||||
u64::from_le_bytes(b[..8].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u32_from_le_bytes(b: &[u8]) -> u32 {
|
||||
u32::from_le_bytes(b[..4].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u16_from_le_bytes(b: &[u8]) -> u16 {
|
||||
u16::from_le_bytes(b[..2].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i64_from_le_bytes(b: &[u8]) -> i64 {
|
||||
i64::from_le_bytes(b[..8].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i32_from_le_bytes(b: &[u8]) -> i32 {
|
||||
i32::from_le_bytes(b[..4].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i16_from_le_bytes(b: &[u8]) -> i16 {
|
||||
i16::from_le_bytes(b[..2].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
|
||||
u64::from_be_bytes(b[..8].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u32_from_be_bytes(b: &[u8]) -> u32 {
|
||||
u32::from_be_bytes(b[..4].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u16_from_be_bytes(b: &[u8]) -> u16 {
|
||||
u16::from_be_bytes(b[..2].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i64_from_be_bytes(b: &[u8]) -> i64 {
|
||||
i64::from_be_bytes(b[..8].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i32_from_be_bytes(b: &[u8]) -> i32 {
|
||||
i32::from_be_bytes(b[..4].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i16_from_be_bytes(b: &[u8]) -> i16 {
|
||||
i16::from_be_bytes(b[..2].try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
pub use fast_int_memory_access::*;
|
|
@ -106,13 +106,15 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
|
|||
#[inline]
|
||||
pub fn new(salt: u32) -> Self {
|
||||
Self {
|
||||
entries: std::array::from_fn(|_| Entry::<K, V> {
|
||||
key: MaybeUninit::uninit(),
|
||||
value: MaybeUninit::uninit(),
|
||||
bucket: -1,
|
||||
next: -1,
|
||||
prev: -1,
|
||||
}),
|
||||
entries: {
|
||||
let mut entries: [Entry<K, V>; C] = unsafe { MaybeUninit::uninit().assume_init() };
|
||||
for e in entries.iter_mut() {
|
||||
e.bucket = -1;
|
||||
e.next = -1;
|
||||
e.prev = -1;
|
||||
}
|
||||
entries
|
||||
},
|
||||
buckets: [-1; B],
|
||||
entry_ptr: 0,
|
||||
salt,
|
||||
|
|
Loading…
Add table
Reference in a new issue