mirror of
https://github.com/zerotier/ZeroTierOne.git
synced 2025-06-08 05:23:44 +02:00
Session works again, and some optimization.
This commit is contained in:
parent
3770fcdc83
commit
06573c1ea8
5 changed files with 374 additions and 261 deletions
|
@ -12,6 +12,7 @@ use crate::random;
|
||||||
use crate::secret::Secret;
|
use crate::secret::Secret;
|
||||||
|
|
||||||
use zerotier_utils::gatherarray::GatherArray;
|
use zerotier_utils::gatherarray::GatherArray;
|
||||||
|
use zerotier_utils::memory;
|
||||||
use zerotier_utils::ringbuffermap::RingBufferMap;
|
use zerotier_utils::ringbuffermap::RingBufferMap;
|
||||||
use zerotier_utils::varint;
|
use zerotier_utils::varint;
|
||||||
|
|
||||||
|
@ -41,6 +42,7 @@ const OFFER_RATE_LIMIT_MS: i64 = 2000;
|
||||||
/// Version 0: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication)
|
/// Version 0: NIST P-384 forward secrecy and authentication with optional Kyber1024 forward secrecy (but not authentication)
|
||||||
const SESSION_PROTOCOL_VERSION: u8 = 0x00;
|
const SESSION_PROTOCOL_VERSION: u8 = 0x00;
|
||||||
|
|
||||||
|
// Packet types can range from 0 to 15 (4 bits) -- 0-3 are defined and 4-15 are reserved for future use
|
||||||
const PACKET_TYPE_DATA: u8 = 0;
|
const PACKET_TYPE_DATA: u8 = 0;
|
||||||
const PACKET_TYPE_NOP: u8 = 1;
|
const PACKET_TYPE_NOP: u8 = 1;
|
||||||
const PACKET_TYPE_KEY_OFFER: u8 = 2; // "alice"
|
const PACKET_TYPE_KEY_OFFER: u8 = 2; // "alice"
|
||||||
|
@ -56,6 +58,10 @@ const AES_GCM_TAG_SIZE: usize = 16;
|
||||||
const HMAC_SIZE: usize = 48;
|
const HMAC_SIZE: usize = 48;
|
||||||
const SESSION_ID_SIZE: usize = 6;
|
const SESSION_ID_SIZE: usize = 6;
|
||||||
|
|
||||||
|
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
|
||||||
|
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
|
||||||
|
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';
|
||||||
|
|
||||||
/// Aribitrary starting value for key derivation chain.
|
/// Aribitrary starting value for key derivation chain.
|
||||||
///
|
///
|
||||||
/// It doesn't matter very much what this is, but it's good for it to be unique.
|
/// It doesn't matter very much what this is, but it's good for it to be unique.
|
||||||
|
@ -66,10 +72,6 @@ const KEY_DERIVATION_CHAIN_STARTING_SALT: [u8; 64] = [
|
||||||
0xb4, 0x32, 0x85, 0xaf, 0x7f, 0x0d, 0xa9, 0x6c, 0x01, 0xfb, 0x72, 0x46, 0xc0, 0x09, 0x58, 0xb8, 0xe0, 0xa8, 0xcf, 0xb1, 0x58, 0x04, 0x6e, 0x32, 0xba, 0xa8, 0xb8, 0xf9, 0x0a, 0xa4, 0xbf, 0x36,
|
0xb4, 0x32, 0x85, 0xaf, 0x7f, 0x0d, 0xa9, 0x6c, 0x01, 0xfb, 0x72, 0x46, 0xc0, 0x09, 0x58, 0xb8, 0xe0, 0xa8, 0xcf, 0xb1, 0x58, 0x04, 0x6e, 0x32, 0xba, 0xa8, 0xb8, 0xf9, 0x0a, 0xa4, 0xbf, 0x36,
|
||||||
];
|
];
|
||||||
|
|
||||||
const KBKDF_KEY_USAGE_LABEL_HMAC: u8 = b'M';
|
|
||||||
const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A';
|
|
||||||
const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B';
|
|
||||||
|
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
/// The packet was addressed to an unrecognized local session
|
/// The packet was addressed to an unrecognized local session
|
||||||
UnknownLocalSessionId(SessionId),
|
UnknownLocalSessionId(SessionId),
|
||||||
|
@ -147,16 +149,17 @@ pub enum ReceiveResult<'a, H: Host> {
|
||||||
/// Packet is valid, no action needs to be taken.
|
/// Packet is valid, no action needs to be taken.
|
||||||
Ok,
|
Ok,
|
||||||
|
|
||||||
/// Packet is valid and contained a data payload.
|
/// Packet is valid and a data payload was decoded and authenticated.
|
||||||
OkData(&'a [u8]),
|
///
|
||||||
|
/// The returned reference is to the filled parts of the data buffer supplied to receive.
|
||||||
|
OkData(&'a mut [u8]),
|
||||||
|
|
||||||
/// Packet is valid and a new session was created, also includes a reply to be sent back.
|
/// Packet is valid and a new session was created.
|
||||||
|
///
|
||||||
|
/// The session will have already been gated by the accept_new_session() method in the Host trait.
|
||||||
OkNewSession(Session<H>),
|
OkNewSession(Session<H>),
|
||||||
|
|
||||||
/// Packet appears valid but was ignored as a duplicate.
|
/// Packet apperas valid but was ignored e.g. as a duplicate.
|
||||||
Duplicate,
|
|
||||||
|
|
||||||
/// Packet apperas valid but was ignored for another reason.
|
|
||||||
Ignored,
|
Ignored,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,20 +170,15 @@ pub struct SessionId(NonZeroU64);
|
||||||
impl SessionId {
|
impl SessionId {
|
||||||
pub const MAX_BIT_MASK: u64 = 0xffffffffffff;
|
pub const MAX_BIT_MASK: u64 = 0xffffffffffff;
|
||||||
|
|
||||||
pub fn new_from_bytes(b: &[u8]) -> Option<SessionId> {
|
#[inline(always)]
|
||||||
if b.len() >= 6 {
|
pub fn new_from_u64(i: u64) -> Option<SessionId> {
|
||||||
let value = (u32::from_le_bytes(b[..4].try_into().unwrap()) as u64) | (u16::from_le_bytes(b[4..6].try_into().unwrap()) as u64).wrapping_shl(32);
|
debug_assert!(i <= Self::MAX_BIT_MASK);
|
||||||
if value > 0 && value <= Self::MAX_BIT_MASK {
|
NonZeroU64::new(i).map(|i| Self(i))
|
||||||
return Some(Self(NonZeroU64::new(value).unwrap()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return None;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_from_reader<R: Read>(r: &mut R) -> std::io::Result<Option<SessionId>> {
|
pub fn new_from_reader<R: Read>(r: &mut R) -> std::io::Result<Option<SessionId>> {
|
||||||
let mut tmp = [0_u8; SESSION_ID_SIZE];
|
let mut tmp = [0_u8; 8];
|
||||||
r.read_exact(&mut tmp)?;
|
r.read_exact(&mut tmp[..SESSION_ID_SIZE]).map(|_| NonZeroU64::new(u64::from_le_bytes(tmp)).map(|i| Self(i)))
|
||||||
Ok(Self::new_from_bytes(&tmp))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_random() -> Self {
|
pub fn new_random() -> Self {
|
||||||
|
@ -193,19 +191,6 @@ impl SessionId {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<u64> for SessionId {
|
|
||||||
type Error = self::Error;
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
fn try_from(value: u64) -> Result<Self, Self::Error> {
|
|
||||||
if value > 0 && value <= Self::MAX_BIT_MASK {
|
|
||||||
Ok(Self(NonZeroU64::new(value).unwrap()))
|
|
||||||
} else {
|
|
||||||
Err(Error::InvalidParameter)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<SessionId> for u64 {
|
impl From<SessionId> for u64 {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn from(sid: SessionId) -> Self {
|
fn from(sid: SessionId) -> Self {
|
||||||
|
@ -213,21 +198,31 @@ impl From<SessionId> for u64 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Trait to implement to integrate the session into an application.
|
||||||
pub trait Host: Sized {
|
pub trait Host: Sized {
|
||||||
type AssociatedObject: Sized;
|
/// Arbitrary object that can be associated with sessions.
|
||||||
|
type AssociatedObject;
|
||||||
|
|
||||||
|
/// Arbitrary object that dereferences to the session, such as Arc<Session<Self>>.
|
||||||
type SessionRef: Deref<Target = Session<Self>>;
|
type SessionRef: Deref<Target = Session<Self>>;
|
||||||
|
|
||||||
|
/// A buffer containing data read from the network that can be cached.
|
||||||
|
///
|
||||||
|
/// This can be e.g. a pooled buffer that automatically returns itself to the pool when dropped.
|
||||||
type IncomingPacketBuffer: AsRef<[u8]>;
|
type IncomingPacketBuffer: AsRef<[u8]>;
|
||||||
|
|
||||||
/// Get a reference to this host's static public key blob.
|
/// Get a reference to this host's static public key blob.
|
||||||
|
///
|
||||||
|
/// This must contain a NIST P-384 public key but can contain other information.
|
||||||
fn get_local_s_public(&self) -> &[u8];
|
fn get_local_s_public(&self) -> &[u8];
|
||||||
|
|
||||||
/// Get SHA384(this host's static public key blob)
|
/// Get SHA384(this host's static public key blob), included here so we don't have to calculate it each time.
|
||||||
fn get_local_s_public_hash(&self) -> &[u8; 48];
|
fn get_local_s_public_hash(&self) -> &[u8; 48];
|
||||||
|
|
||||||
/// Get a reference to this hosts' static public key's NIST P-384 secret key pair
|
/// Get a reference to this hosts' static public key's NIST P-384 secret key pair
|
||||||
fn get_local_s_keypair_p384(&self) -> &P384KeyPair;
|
fn get_local_s_keypair_p384(&self) -> &P384KeyPair;
|
||||||
|
|
||||||
/// Extract a NIST P-384 ECC public key from a static public key blob.
|
/// Extract the NIST P-384 ECC public key component from a static public key blob or return None on failure.
|
||||||
fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey>;
|
fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey>;
|
||||||
|
|
||||||
/// Look up a local session by local ID.
|
/// Look up a local session by local ID.
|
||||||
|
@ -246,23 +241,23 @@ pub struct Session<H: Host> {
|
||||||
pub associated_object: H::AssociatedObject,
|
pub associated_object: H::AssociatedObject,
|
||||||
|
|
||||||
send_counter: Counter,
|
send_counter: Counter,
|
||||||
psk: Secret<64>,
|
psk: Secret<64>, // Arbitrary PSK provided by external code
|
||||||
ss: Secret<48>,
|
ss: Secret<48>, // NIST P-384 raw ECDH key agreement with peer
|
||||||
state: RwLock<MutableState>,
|
state: RwLock<MutableState>, // Mutable parts of state (other than defrag buffers)
|
||||||
remote_s_public_hash: [u8; 48],
|
remote_s_public_hash: [u8; 48], // SHA384(remote static public key blob)
|
||||||
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE],
|
remote_s_public_p384: [u8; P384_PUBLIC_KEY_SIZE], // Remote NIST P-384 static public key
|
||||||
defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, MAX_FRAGMENTS>, 64, 32>>,
|
defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, MAX_FRAGMENTS>, 16, 4>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MutableState {
|
struct MutableState {
|
||||||
remote_session_id: Option<SessionId>,
|
remote_session_id: Option<SessionId>,
|
||||||
keys: [Option<SessionKey>; 2], // current, next
|
keys: [Option<SessionKey>; 2], // current, next (promoted to current on successful decrypt)
|
||||||
offer: Option<EphemeralOffer>,
|
offer: Option<EphemeralOffer>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// State information to associate with receiving contexts such as sockets or remote paths/endpoints.
|
/// State information to associate with receiving contexts such as sockets or remote paths/endpoints.
|
||||||
pub struct ReceiveContext<H: Host> {
|
pub struct ReceiveContext<H: Host> {
|
||||||
initial_offer_defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, KEY_EXCHANGE_MAX_FRAGMENTS>, 1024, 256>>,
|
initial_offer_defrag: Mutex<RingBufferMap<u32, GatherArray<H::IncomingPacketBuffer, KEY_EXCHANGE_MAX_FRAGMENTS>, 1024, 128>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<H: Host> Session<H> {
|
impl<H: Host> Session<H> {
|
||||||
|
@ -362,26 +357,25 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
current_time: i64,
|
current_time: i64,
|
||||||
) -> Result<ReceiveResult<'a, H>, Error> {
|
) -> Result<ReceiveResult<'a, H>, Error> {
|
||||||
let incoming_packet = incoming_packet_buf.as_ref();
|
let incoming_packet = incoming_packet_buf.as_ref();
|
||||||
|
|
||||||
if incoming_packet.len() < MIN_PACKET_SIZE {
|
if incoming_packet.len() < MIN_PACKET_SIZE {
|
||||||
unlikely_branch();
|
unlikely_branch();
|
||||||
return Err(Error::InvalidPacket);
|
return Err(Error::InvalidPacket);
|
||||||
}
|
}
|
||||||
|
|
||||||
let type_and_frag_info = u16::from_le_bytes(incoming_packet[0..2].try_into().unwrap());
|
let header_0_8 = memory::u64_from_le_bytes(incoming_packet); // type, frag info, session ID
|
||||||
let local_session_id = SessionId::new_from_bytes(&incoming_packet[2..8]);
|
let counter = memory::u32_from_le_bytes(&incoming_packet[8..]);
|
||||||
let counter = u32::from_le_bytes(incoming_packet[8..12].try_into().unwrap());
|
let local_session_id = SessionId::new_from_u64(header_0_8.wrapping_shr(16));
|
||||||
let packet_type = (type_and_frag_info as u8) & 15;
|
let packet_type = (header_0_8 as u8) & 15;
|
||||||
let fragment_count = type_and_frag_info.wrapping_shr(4) & 63;
|
let fragment_count = ((header_0_8.wrapping_shr(4) as u8) & 63).wrapping_add(1);
|
||||||
let fragment_no = type_and_frag_info.wrapping_shr(10); // & 63 not needed
|
let fragment_no = (header_0_8.wrapping_shr(10) as u8) & 63;
|
||||||
|
|
||||||
if fragment_count > 1 {
|
if fragment_count > 1 {
|
||||||
if let Some(local_session_id) = local_session_id {
|
if let Some(local_session_id) = local_session_id {
|
||||||
if fragment_count < (MAX_FRAGMENTS as u16) && fragment_no < fragment_count {
|
if fragment_count <= (MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
|
||||||
if let Some(session) = host.session_lookup(local_session_id) {
|
if let Some(session) = host.session_lookup(local_session_id) {
|
||||||
let mut defrag = session.defrag.lock();
|
let mut defrag = session.defrag.lock();
|
||||||
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32));
|
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
|
||||||
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) {
|
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
|
||||||
drop(defrag); // release lock
|
drop(defrag); // release lock
|
||||||
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time);
|
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, Some(session), mtu, jedi, current_time);
|
||||||
}
|
}
|
||||||
|
@ -394,10 +388,10 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
return Err(Error::InvalidPacket);
|
return Err(Error::InvalidPacket);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if fragment_count < (KEY_EXCHANGE_MAX_FRAGMENTS as u16) && fragment_no < fragment_count {
|
if fragment_count <= (KEY_EXCHANGE_MAX_FRAGMENTS as u8) && fragment_no < fragment_count {
|
||||||
let mut defrag = self.initial_offer_defrag.lock();
|
let mut defrag = self.initial_offer_defrag.lock();
|
||||||
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count as u32));
|
let fragment_gather_array = defrag.get_or_create_mut(&counter, || GatherArray::new(fragment_count));
|
||||||
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no as u32, incoming_packet_buf) {
|
if let Some(assembled_packet) = fragment_gather_array.add(fragment_no, incoming_packet_buf) {
|
||||||
drop(defrag); // release lock
|
drop(defrag); // release lock
|
||||||
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
|
return self.receive_complete(host, send, data_buf, assembled_packet.as_ref(), packet_type, None, mtu, jedi, current_time);
|
||||||
}
|
}
|
||||||
|
@ -444,22 +438,19 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
let state = session.state.read();
|
let state = session.state.read();
|
||||||
for ki in 0..2 {
|
for ki in 0..2 {
|
||||||
if let Some(key) = state.keys[ki].as_ref() {
|
if let Some(key) = state.keys[ki].as_ref() {
|
||||||
let head = fragments.first().unwrap().as_ref();
|
let tail = fragments.last().unwrap().as_ref();
|
||||||
debug_assert!(head.len() >= MIN_PACKET_SIZE);
|
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
|
||||||
|
unlikely_branch();
|
||||||
|
return Err(Error::InvalidPacket);
|
||||||
|
}
|
||||||
|
|
||||||
let mut c = key.get_receive_cipher();
|
let mut c = key.get_receive_cipher();
|
||||||
c.init(&get_aes_gcm_nonce(head));
|
c.init(&get_aes_gcm_nonce(fragments.first().unwrap().as_ref()));
|
||||||
|
|
||||||
let mut data_len = head.len() - HEADER_SIZE;
|
let mut data_len = 0;
|
||||||
if data_len > data_buf.len() {
|
|
||||||
unlikely_branch();
|
|
||||||
key.return_receive_cipher(c);
|
|
||||||
return Err(Error::DataBufferTooSmall);
|
|
||||||
}
|
|
||||||
c.crypt(&head[HEADER_SIZE..], &mut data_buf[..data_len]);
|
|
||||||
|
|
||||||
for fi in 1..(fragments.len() - 1) {
|
for f in fragments[..(fragments.len() - 1)].iter() {
|
||||||
let f = fragments[fi].as_ref();
|
let f = f.as_ref();
|
||||||
debug_assert!(f.len() >= HEADER_SIZE);
|
debug_assert!(f.len() >= HEADER_SIZE);
|
||||||
let current_frag_data_start = data_len;
|
let current_frag_data_start = data_len;
|
||||||
data_len += f.len() - HEADER_SIZE;
|
data_len += f.len() - HEADER_SIZE;
|
||||||
|
@ -471,21 +462,14 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
|
c.crypt(&f[HEADER_SIZE..], &mut data_buf[current_frag_data_start..data_len]);
|
||||||
}
|
}
|
||||||
|
|
||||||
let tail = fragments.last().unwrap().as_ref();
|
|
||||||
if tail.len() < (HEADER_SIZE + AES_GCM_TAG_SIZE) {
|
|
||||||
unlikely_branch();
|
|
||||||
key.return_receive_cipher(c);
|
|
||||||
return Err(Error::InvalidPacket);
|
|
||||||
}
|
|
||||||
let tail_data_len = tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
|
|
||||||
let current_frag_data_start = data_len;
|
let current_frag_data_start = data_len;
|
||||||
data_len += tail_data_len;
|
data_len += tail.len() - (HEADER_SIZE + AES_GCM_TAG_SIZE);
|
||||||
if data_len > data_buf.len() {
|
if data_len > data_buf.len() {
|
||||||
unlikely_branch();
|
unlikely_branch();
|
||||||
key.return_receive_cipher(c);
|
key.return_receive_cipher(c);
|
||||||
return Err(Error::DataBufferTooSmall);
|
return Err(Error::DataBufferTooSmall);
|
||||||
}
|
}
|
||||||
c.crypt(&tail[HEADER_SIZE..tail_data_len], &mut data_buf[current_frag_data_start..data_len]);
|
c.crypt(&tail[HEADER_SIZE..(tail.len() - AES_GCM_TAG_SIZE)], &mut data_buf[current_frag_data_start..data_len]);
|
||||||
|
|
||||||
let tag = c.finish();
|
let tag = c.finish();
|
||||||
key.return_receive_cipher(c);
|
key.return_receive_cipher(c);
|
||||||
|
@ -500,7 +484,7 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if packet_type == PACKET_TYPE_DATA {
|
if packet_type == PACKET_TYPE_DATA {
|
||||||
return Ok(ReceiveResult::OkData(&data_buf[..data_len]));
|
return Ok(ReceiveResult::OkData(&mut data_buf[..data_len]));
|
||||||
} else {
|
} else {
|
||||||
return Ok(ReceiveResult::Ok);
|
return Ok(ReceiveResult::Ok);
|
||||||
}
|
}
|
||||||
|
@ -540,8 +524,6 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
return Err(Error::UnknownProtocolVersion);
|
return Err(Error::UnknownProtocolVersion);
|
||||||
}
|
}
|
||||||
|
|
||||||
let local_s_keypair_p384 = host.get_local_s_keypair_p384();
|
|
||||||
|
|
||||||
match packet_type {
|
match packet_type {
|
||||||
PACKET_TYPE_KEY_OFFER => {
|
PACKET_TYPE_KEY_OFFER => {
|
||||||
// alice (remote) -> bob (local)
|
// alice (remote) -> bob (local)
|
||||||
|
@ -565,7 +547,7 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
}
|
}
|
||||||
|
|
||||||
let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
|
let (alice_e0_public, e0s) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
|
||||||
.and_then(|pk| local_s_keypair_p384.agree(&pk).map(move |s| (pk, s)))
|
.and_then(|pk| host.get_local_s_keypair_p384().agree(&pk).map(move |s| (pk, s)))
|
||||||
.ok_or(Error::FailedAuthentication)?;
|
.ok_or(Error::FailedAuthentication)?;
|
||||||
|
|
||||||
let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes()));
|
let key = Secret(hmac_sha512(&hmac_sha512(&KEY_DERIVATION_CHAIN_STARTING_SALT, alice_e0_public.as_bytes()), e0s.as_bytes()));
|
||||||
|
@ -588,7 +570,7 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
}
|
}
|
||||||
|
|
||||||
let alice_s_public_p384 = H::extract_p384_static(alice_s_public).ok_or(Error::InvalidPacket)?;
|
let alice_s_public_p384 = H::extract_p384_static(alice_s_public).ok_or(Error::InvalidPacket)?;
|
||||||
let ss = local_s_keypair_p384.agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?;
|
let ss = host.get_local_s_keypair_p384().agree(&alice_s_public_p384).ok_or(Error::FailedAuthentication)?;
|
||||||
|
|
||||||
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
|
let key = Secret(hmac_sha512(key.as_bytes(), ss.as_bytes()));
|
||||||
|
|
||||||
|
@ -668,7 +650,7 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
}
|
}
|
||||||
(MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len()
|
(MIN_MTU * KEY_EXCHANGE_MAX_FRAGMENTS) - rp.len()
|
||||||
};
|
};
|
||||||
let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_OFFER, session.id.into(), reply_counter);
|
let mut header = send_with_fragmentation_init_header(reply_len, mtu, PACKET_TYPE_KEY_COUNTER_OFFER, alice_session_id.into(), reply_counter);
|
||||||
reply_buf[..HEADER_SIZE].copy_from_slice(&header);
|
reply_buf[..HEADER_SIZE].copy_from_slice(&header);
|
||||||
|
|
||||||
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true);
|
let mut c = AesGcm::new(kbkdf512(key.as_bytes(), KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE).first_n::<32>(), true);
|
||||||
|
@ -715,7 +697,7 @@ impl<H: Host> ReceiveContext<H> {
|
||||||
let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
|
let (bob_e0_public, e0e0) = P384PublicKey::from_bytes(&incoming_packet[(HEADER_SIZE + 1)..(HEADER_SIZE + 1 + P384_PUBLIC_KEY_SIZE)])
|
||||||
.and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s)))
|
.and_then(|pk| offer.alice_e0_keypair.agree(&pk).map(move |s| (pk, s)))
|
||||||
.ok_or(Error::FailedAuthentication)?;
|
.ok_or(Error::FailedAuthentication)?;
|
||||||
let se0 = local_s_keypair_p384.agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?;
|
let se0 = host.get_local_s_keypair_p384().agree(&bob_e0_public).ok_or(Error::FailedAuthentication)?;
|
||||||
|
|
||||||
let key = Secret(hmac_sha512(
|
let key = Secret(hmac_sha512(
|
||||||
session.psk.as_bytes(),
|
session.psk.as_bytes(),
|
||||||
|
@ -956,15 +938,16 @@ impl EphemeralOffer {
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) -> [u8; 12] {
|
fn send_with_fragmentation_init_header(packet_len: usize, mtu: usize, packet_type: u8, recipient_session_id: u64, counter: CounterValue) -> [u8; 12] {
|
||||||
let fragment_count = (packet_len / mtu) + (((packet_len % mtu) != 0) as usize);
|
let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize;
|
||||||
debug_assert!(mtu >= MIN_MTU);
|
debug_assert!(mtu >= MIN_MTU);
|
||||||
debug_assert!(packet_len >= HEADER_SIZE);
|
debug_assert!(packet_len >= HEADER_SIZE);
|
||||||
debug_assert!(fragment_count <= MAX_FRAGMENTS);
|
debug_assert!(fragment_count <= MAX_FRAGMENTS);
|
||||||
|
debug_assert!(fragment_count > 0);
|
||||||
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
|
debug_assert!(packet_type <= 0x0f); // packet type is 4 bits
|
||||||
debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits
|
debug_assert!(recipient_session_id <= 0xffffffffffff); // session ID is 48 bits
|
||||||
|
|
||||||
// Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter
|
// Header bytes: TTRRRRRRCCCC where T == type/fragment, R == recipient session ID, C == counter
|
||||||
let mut header = (fragment_count.wrapping_shl(4) | (packet_type as usize)) as u128;
|
let mut header = ((fragment_count - 1).wrapping_shl(4) | (packet_type as usize)) as u128;
|
||||||
header |= recipient_session_id.wrapping_shl(16) as u128;
|
header |= recipient_session_id.wrapping_shl(16) as u128;
|
||||||
header |= (counter.to_u32() as u128).wrapping_shl(64);
|
header |= (counter.to_u32() as u128).wrapping_shl(64);
|
||||||
header.to_le_bytes()[..HEADER_SIZE].try_into().unwrap()
|
header.to_le_bytes()[..HEADER_SIZE].try_into().unwrap()
|
||||||
|
@ -980,7 +963,8 @@ fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(mut send: SendFunctio
|
||||||
if fragment_end < packet_len {
|
if fragment_end < packet_len {
|
||||||
fragment_start = fragment_end - HEADER_SIZE;
|
fragment_start = fragment_end - HEADER_SIZE;
|
||||||
fragment_end = (fragment_start + mtu).min(packet_len);
|
fragment_end = (fragment_start + mtu).min(packet_len);
|
||||||
header[1] += 0x04; // increment fragment number, at bit 2 in byte 1 since type/fragment u16 is little-endian
|
debug_assert!(header[1].wrapping_shr(2) < 63);
|
||||||
|
header[1] += 0x04; // increment fragment number in least significant 6 bits of byte 1
|
||||||
packet[fragment_start..(fragment_start + HEADER_SIZE)].copy_from_slice(header);
|
packet[fragment_start..(fragment_start + HEADER_SIZE)].copy_from_slice(header);
|
||||||
} else {
|
} else {
|
||||||
debug_assert_eq!(fragment_end, packet_len);
|
debug_assert_eq!(fragment_end, packet_len);
|
||||||
|
@ -1104,171 +1088,154 @@ fn get_aes_gcm_nonce(packet: &[u8]) -> [u8; 16] {
|
||||||
tmp
|
tmp
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {}
|
|
||||||
|
|
||||||
/*
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::rc::Rc;
|
use parking_lot::Mutex;
|
||||||
|
use std::collections::LinkedList;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
struct TestHost {
|
||||||
|
local_s: P384KeyPair,
|
||||||
|
local_s_hash: [u8; 48],
|
||||||
|
psk: Secret<64>,
|
||||||
|
session: Mutex<Option<Arc<Session<Box<TestHost>>>>>,
|
||||||
|
session_id_counter: Mutex<u64>,
|
||||||
|
pub queue: Mutex<LinkedList<Vec<u8>>>,
|
||||||
|
pub this_name: &'static str,
|
||||||
|
pub other_name: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestHost {
|
||||||
|
fn new(psk: Secret<64>, this_name: &'static str, other_name: &'static str) -> Self {
|
||||||
|
let local_s = P384KeyPair::generate();
|
||||||
|
let local_s_hash = SHA384::hash(local_s.public_key_bytes());
|
||||||
|
Self {
|
||||||
|
local_s,
|
||||||
|
local_s_hash,
|
||||||
|
psk,
|
||||||
|
session: Mutex::new(None),
|
||||||
|
session_id_counter: Mutex::new(random::next_u64_secure().wrapping_shr(16) | 1),
|
||||||
|
queue: Mutex::new(LinkedList::new()),
|
||||||
|
this_name,
|
||||||
|
other_name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Host for Box<TestHost> {
|
||||||
|
type AssociatedObject = u32;
|
||||||
|
type SessionRef = Arc<Session<Box<TestHost>>>;
|
||||||
|
type IncomingPacketBuffer = Vec<u8>;
|
||||||
|
|
||||||
|
fn get_local_s_public(&self) -> &[u8] {
|
||||||
|
self.local_s.public_key_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_local_s_public_hash(&self) -> &[u8; 48] {
|
||||||
|
&self.local_s_hash
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_local_s_keypair_p384(&self) -> &P384KeyPair {
|
||||||
|
&self.local_s
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_p384_static(static_public: &[u8]) -> Option<P384PublicKey> {
|
||||||
|
P384PublicKey::from_bytes(static_public)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn session_lookup(&self, local_session_id: SessionId) -> Option<Self::SessionRef> {
|
||||||
|
self.session.lock().as_ref().and_then(|s| {
|
||||||
|
if s.id == local_session_id {
|
||||||
|
Some(s.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn accept_new_session(&self, _: &[u8], _: &[u8]) -> Option<(SessionId, Secret<64>, Self::AssociatedObject)> {
|
||||||
|
loop {
|
||||||
|
let mut new_id = self.session_id_counter.lock();
|
||||||
|
*new_id += 1;
|
||||||
|
return Some((SessionId::new_from_u64(*new_id).unwrap(), self.psk.clone(), 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused_variables)]
|
||||||
#[test]
|
#[test]
|
||||||
fn alice_bob() {
|
fn establish_session() {
|
||||||
let psk: Secret<64> = Secret::default();
|
let mut psk: Secret<64> = Secret::default();
|
||||||
let mut a_buffer = [0_u8; 1500];
|
random::fill_bytes_secure(&mut psk.0);
|
||||||
let mut b_buffer = [0_u8; 1500];
|
let alice_host = Box::new(TestHost::new(psk.clone(), "alice", "bob"));
|
||||||
let alice_static_keypair = P384KeyPair::generate();
|
let bob_host = Box::new(TestHost::new(psk.clone(), "bob", "alice"));
|
||||||
let bob_static_keypair = P384KeyPair::generate();
|
let rc: Box<ReceiveContext<Box<TestHost>>> = Box::new(ReceiveContext::new());
|
||||||
let outgoing_obfuscator_to_alice = Obfuscator::new(alice_static_keypair.public_key_bytes());
|
let mut data_buf = [0_u8; 4096];
|
||||||
let outgoing_obfuscator_to_bob = Obfuscator::new(bob_static_keypair.public_key_bytes());
|
|
||||||
|
|
||||||
let mut from_alice: Vec<Vec<u8>> = Vec::new();
|
//println!("zssp: size of session (bytes): {}", std::mem::size_of::<Session<Box<TestHost>>>());
|
||||||
let mut from_bob: Vec<Vec<u8>> = Vec::new();
|
|
||||||
|
|
||||||
// Session TO Bob, on Alice's side.
|
let _ = alice_host.session.lock().insert(Arc::new(
|
||||||
let (alice, packet) = Session::new(
|
Session::new(
|
||||||
&mut a_buffer,
|
&alice_host,
|
||||||
|
|data| bob_host.queue.lock().push_front(data.to_vec()),
|
||||||
SessionId::new_random(),
|
SessionId::new_random(),
|
||||||
alice_static_keypair.public_key_bytes(),
|
bob_host.local_s.public_key_bytes(),
|
||||||
&alice_static_keypair,
|
&[],
|
||||||
bob_static_keypair.public_key_bytes(),
|
|
||||||
bob_static_keypair.public_key(),
|
|
||||||
&psk,
|
&psk,
|
||||||
0,
|
1,
|
||||||
0,
|
1280,
|
||||||
|
1,
|
||||||
true,
|
true,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap(),
|
||||||
let alice = Rc::new(alice);
|
));
|
||||||
from_alice.push(packet.to_vec());
|
|
||||||
|
|
||||||
// Session FROM Alice, on Bob's side.
|
|
||||||
let mut bob: Option<Rc<Session<u32>>> = None;
|
|
||||||
|
|
||||||
|
let mut ts = 0;
|
||||||
for _ in 0..256 {
|
for _ in 0..256 {
|
||||||
while !from_alice.is_empty() || !from_bob.is_empty() {
|
for host in [&alice_host, &bob_host] {
|
||||||
if let Some(packet) = from_alice.pop() {
|
let send_to_other = |data: &mut [u8]| {
|
||||||
let r = Session::receive(
|
if std::ptr::eq(host, &alice_host) {
|
||||||
packet.as_slice(),
|
bob_host.queue.lock().push_front(data.to_vec());
|
||||||
&mut b_buffer,
|
|
||||||
&bob_static_keypair,
|
|
||||||
&outgoing_obfuscator_to_bob,
|
|
||||||
|p: &[u8; P384_PUBLIC_KEY_SIZE]| P384PublicKey::from_bytes(p),
|
|
||||||
|sid| {
|
|
||||||
println!("[noise] [bob] session ID: {}", u64::from(sid));
|
|
||||||
if let Some(bob) = bob.as_ref() {
|
|
||||||
if sid == bob.id {
|
|
||||||
Some(bob.clone())
|
|
||||||
} else {
|
} else {
|
||||||
None
|
alice_host.queue.lock().push_front(data.to_vec());
|
||||||
}
|
}
|
||||||
} else {
|
};
|
||||||
None
|
|
||||||
}
|
loop {
|
||||||
},
|
if let Some(qi) = host.queue.lock().pop_back() {
|
||||||
|_: &[u8; P384_PUBLIC_KEY_SIZE]| {
|
let qi_len = qi.len();
|
||||||
if bob.is_none() {
|
ts += 1;
|
||||||
Some((SessionId::new_random(), psk.clone(), 0))
|
let r = rc.receive(host, send_to_other, &mut data_buf, qi, 1280, true, ts);
|
||||||
} else {
|
if r.is_ok() {
|
||||||
panic!("[noise] [bob] Bob received a second new session request from Alice");
|
let r = r.unwrap();
|
||||||
}
|
|
||||||
},
|
|
||||||
0,
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
if let Ok(r) = r {
|
|
||||||
match r {
|
match r {
|
||||||
ReceiveResult::OkData(data, counter) => {
|
|
||||||
println!("[noise] [bob] DATA len=={} counter=={}", data.len(), counter);
|
|
||||||
}
|
|
||||||
ReceiveResult::OkSendReply(p) => {
|
|
||||||
println!("[noise] [bob] OK (reply: {} bytes)", p.len());
|
|
||||||
from_bob.push(p.to_vec());
|
|
||||||
}
|
|
||||||
ReceiveResult::OkNewSession(ns, p) => {
|
|
||||||
if bob.is_some() {
|
|
||||||
panic!("[noise] [bob] attempt to create new session on Bob's side when he already has one");
|
|
||||||
}
|
|
||||||
let id: u64 = ns.id.into();
|
|
||||||
let _ = bob.replace(Rc::new(ns));
|
|
||||||
from_bob.push(p.to_vec());
|
|
||||||
println!("[noise] [bob] NEW SESSION {}", id);
|
|
||||||
}
|
|
||||||
ReceiveResult::Ok => {
|
ReceiveResult::Ok => {
|
||||||
println!("[noise] [bob] OK");
|
println!("zssp: {} => {} ({}): Ok", host.other_name, host.this_name, qi_len);
|
||||||
}
|
}
|
||||||
ReceiveResult::Duplicate => {
|
ReceiveResult::OkData(data) => {
|
||||||
println!("[noise] [bob] duplicate packet");
|
println!("zssp: {} => {} ({}): OkData length=={}", host.other_name, host.this_name, qi_len, data.len());
|
||||||
|
}
|
||||||
|
ReceiveResult::OkNewSession(new_session) => {
|
||||||
|
println!("zssp: {} => {} ({}): OkNewSession ({})", host.other_name, host.this_name, qi_len, u64::from(new_session.id));
|
||||||
|
let mut hs = host.session.lock();
|
||||||
|
assert!(hs.is_none());
|
||||||
|
let _ = hs.insert(Arc::new(new_session));
|
||||||
}
|
}
|
||||||
ReceiveResult::Ignored => {
|
ReceiveResult::Ignored => {
|
||||||
println!("[noise] [bob] ignored packet");
|
println!("zssp: {} => {} ({}): Ignored", host.other_name, host.this_name, qi_len);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
println!("ERROR (bob): {}", r.err().unwrap().to_string());
|
println!("zssp: {} => {}: error: {}", host.other_name, host.this_name, r.err().unwrap().to_string());
|
||||||
panic!();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(packet) = from_bob.pop() {
|
|
||||||
let r = Session::receive(
|
|
||||||
packet.as_slice(),
|
|
||||||
&mut b_buffer,
|
|
||||||
&alice_static_keypair,
|
|
||||||
&outgoing_obfuscator_to_alice,
|
|
||||||
|p: &[u8; P384_PUBLIC_KEY_SIZE]| P384PublicKey::from_bytes(p),
|
|
||||||
|sid| {
|
|
||||||
println!("[noise] [alice] session ID: {}", u64::from(sid));
|
|
||||||
if sid == alice.id {
|
|
||||||
Some(alice.clone())
|
|
||||||
} else {
|
|
||||||
panic!("[noise] [alice] received from Bob addressed to unknown session ID, not Alice");
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|_: &[u8; P384_PUBLIC_KEY_SIZE]| {
|
|
||||||
panic!("[noise] [alice] Alice received an unexpected new session request from Bob");
|
|
||||||
},
|
|
||||||
0,
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
if let Ok(r) = r {
|
|
||||||
match r {
|
|
||||||
ReceiveResult::OkData(data, counter) => {
|
|
||||||
println!("[noise] [alice] DATA len=={} counter=={}", data.len(), counter);
|
|
||||||
}
|
|
||||||
ReceiveResult::OkSendReply(p) => {
|
|
||||||
println!("[noise] [alice] OK (reply: {} bytes)", p.len());
|
|
||||||
from_alice.push(p.to_vec());
|
|
||||||
}
|
|
||||||
ReceiveResult::OkNewSession(_, _) => {
|
|
||||||
panic!("[noise] [alice] attempt to create new session on Alice's side; Bob should not initiate");
|
|
||||||
}
|
|
||||||
ReceiveResult::Ok => {
|
|
||||||
println!("[noise] [alice] OK");
|
|
||||||
}
|
|
||||||
ReceiveResult::Duplicate => {
|
|
||||||
println!("[noise] [alice] duplicate packet");
|
|
||||||
}
|
|
||||||
ReceiveResult::Ignored => {
|
|
||||||
println!("[noise] [alice] ignored packet");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
println!("ERROR (alice): {}", r.err().unwrap().to_string());
|
break;
|
||||||
panic!();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (random::next_u32_secure() & 1) == 0 {
|
|
||||||
from_alice.push(alice.send(&mut a_buffer, &[0_u8; 16]).unwrap().to_vec());
|
|
||||||
} else if bob.is_some() {
|
|
||||||
from_bob.push(bob.as_ref().unwrap().send(&mut b_buffer, &[0_u8; 16]).unwrap().to_vec());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
|
@ -11,16 +11,16 @@ use crate::arrayvec::ArrayVec;
|
||||||
pub struct GatherArray<T, const C: usize> {
|
pub struct GatherArray<T, const C: usize> {
|
||||||
a: [MaybeUninit<T>; C],
|
a: [MaybeUninit<T>; C],
|
||||||
have_bits: u64,
|
have_bits: u64,
|
||||||
have_count: u32,
|
have_count: u8,
|
||||||
goal: u32,
|
goal: u8,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, const C: usize> GatherArray<T, C> {
|
impl<T, const C: usize> GatherArray<T, C> {
|
||||||
/// Create a new gather array, which must be initialized prior to use.
|
/// Create a new gather array, which must be initialized prior to use.
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub fn new(goal: u32) -> Self {
|
pub fn new(goal: u8) -> Self {
|
||||||
assert!(C <= 64);
|
assert!(C <= 64);
|
||||||
assert!(goal <= (C as u32));
|
assert!(goal <= (C as u8));
|
||||||
assert_eq!(size_of::<[T; C]>(), size_of::<[MaybeUninit<T>; C]>());
|
assert_eq!(size_of::<[T; C]>(), size_of::<[MaybeUninit<T>; C]>());
|
||||||
Self {
|
Self {
|
||||||
a: unsafe { MaybeUninit::uninit().assume_init() },
|
a: unsafe { MaybeUninit::uninit().assume_init() },
|
||||||
|
@ -32,10 +32,10 @@ impl<T, const C: usize> GatherArray<T, C> {
|
||||||
|
|
||||||
/// Add an item to the array if we don't have this index anymore, returning complete array if all parts are here.
|
/// Add an item to the array if we don't have this index anymore, returning complete array if all parts are here.
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub fn add(&mut self, index: u32, value: T) -> Option<ArrayVec<T, C>> {
|
pub fn add(&mut self, index: u8, value: T) -> Option<ArrayVec<T, C>> {
|
||||||
if index < self.goal {
|
if index < self.goal {
|
||||||
let mut have = self.have_bits;
|
let mut have = self.have_bits;
|
||||||
let got = 1u64.wrapping_shl(index);
|
let got = 1u64.wrapping_shl(index as u32);
|
||||||
if (have & got) == 0 {
|
if (have & got) == 0 {
|
||||||
have |= got;
|
have |= got;
|
||||||
self.have_bits = have;
|
self.have_bits = have;
|
||||||
|
@ -64,7 +64,7 @@ impl<T, const C: usize> Drop for GatherArray<T, C> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
let have = self.have_bits;
|
let have = self.have_bits;
|
||||||
for i in 0..self.goal {
|
for i in 0..self.goal {
|
||||||
if (have & 1u64.wrapping_shl(i)) != 0 {
|
if (have & 1u64.wrapping_shl(i as u32)) != 0 {
|
||||||
unsafe { self.a.get_unchecked_mut(i as usize).assume_init_drop() };
|
unsafe { self.a.get_unchecked_mut(i as usize).assume_init_drop() };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -78,8 +78,8 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gather_array() {
|
fn gather_array() {
|
||||||
for goal in 2..64 {
|
for goal in 2u8..64u8 {
|
||||||
let mut m = GatherArray::<u32, 64>::new(goal);
|
let mut m = GatherArray::<u8, 64>::new(goal);
|
||||||
for x in 0..(goal - 1) {
|
for x in 0..(goal - 1) {
|
||||||
assert!(m.add(x, x).is_none());
|
assert!(m.add(x, x).is_none());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod arrayvec;
|
pub mod arrayvec;
|
||||||
pub mod gatherarray;
|
pub mod gatherarray;
|
||||||
pub mod hex;
|
pub mod hex;
|
||||||
|
pub mod memory;
|
||||||
pub mod ringbuffermap;
|
pub mod ringbuffermap;
|
||||||
pub mod varint;
|
pub mod varint;
|
||||||
|
|
143
utils/src/memory.rs
Normal file
143
utils/src/memory.rs
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
// (c) 2020-2022 ZeroTier, Inc. -- currently propritery pending actual release and licensing. See LICENSE.md.
|
||||||
|
|
||||||
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
mod fast_int_memory_access {
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u64_from_le_bytes(b: &[u8]) -> u64 {
|
||||||
|
assert!(b.len() >= 8);
|
||||||
|
unsafe { *b.as_ptr().cast() }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u32_from_le_bytes(b: &[u8]) -> u32 {
|
||||||
|
assert!(b.len() >= 4);
|
||||||
|
unsafe { *b.as_ptr().cast() }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u16_from_le_bytes(b: &[u8]) -> u16 {
|
||||||
|
assert!(b.len() >= 2);
|
||||||
|
unsafe { *b.as_ptr().cast() }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i64_from_le_bytes(b: &[u8]) -> i64 {
|
||||||
|
assert!(b.len() >= 8);
|
||||||
|
unsafe { *b.as_ptr().cast() }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i32_from_le_bytes(b: &[u8]) -> i32 {
|
||||||
|
assert!(b.len() >= 4);
|
||||||
|
unsafe { *b.as_ptr().cast() }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i16_from_le_bytes(b: &[u8]) -> i16 {
|
||||||
|
assert!(b.len() >= 2);
|
||||||
|
unsafe { *b.as_ptr().cast() }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
|
||||||
|
assert!(b.len() >= 8);
|
||||||
|
unsafe { *b.as_ptr().cast::<u64>() }.swap_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u32_from_be_bytes(b: &[u8]) -> u32 {
|
||||||
|
assert!(b.len() >= 4);
|
||||||
|
unsafe { *b.as_ptr().cast::<u32>() }.swap_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u16_from_be_bytes(b: &[u8]) -> u16 {
|
||||||
|
assert!(b.len() >= 2);
|
||||||
|
unsafe { *b.as_ptr().cast::<u16>() }.swap_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i64_from_be_bytes(b: &[u8]) -> i64 {
|
||||||
|
assert!(b.len() >= 8);
|
||||||
|
unsafe { *b.as_ptr().cast::<i64>() }.swap_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i32_from_be_bytes(b: &[u8]) -> i32 {
|
||||||
|
assert!(b.len() >= 4);
|
||||||
|
unsafe { *b.as_ptr().cast::<i32>() }.swap_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i16_from_be_bytes(b: &[u8]) -> i16 {
|
||||||
|
assert!(b.len() >= 2);
|
||||||
|
unsafe { *b.as_ptr().cast::<i16>() }.swap_bytes()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
|
||||||
|
#[allow(unused)]
|
||||||
|
mod fast_int_memory_access {
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u64_from_le_bytes(b: &[u8]) -> u64 {
|
||||||
|
u64::from_le_bytes(b[..8].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u32_from_le_bytes(b: &[u8]) -> u32 {
|
||||||
|
u32::from_le_bytes(b[..4].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u16_from_le_bytes(b: &[u8]) -> u16 {
|
||||||
|
u16::from_le_bytes(b[..2].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i64_from_le_bytes(b: &[u8]) -> i64 {
|
||||||
|
i64::from_le_bytes(b[..8].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i32_from_le_bytes(b: &[u8]) -> i32 {
|
||||||
|
i32::from_le_bytes(b[..4].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i16_from_le_bytes(b: &[u8]) -> i16 {
|
||||||
|
i16::from_le_bytes(b[..2].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u64_from_be_bytes(b: &[u8]) -> u64 {
|
||||||
|
u64::from_be_bytes(b[..8].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u32_from_be_bytes(b: &[u8]) -> u32 {
|
||||||
|
u32::from_be_bytes(b[..4].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn u16_from_be_bytes(b: &[u8]) -> u16 {
|
||||||
|
u16::from_be_bytes(b[..2].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i64_from_be_bytes(b: &[u8]) -> i64 {
|
||||||
|
i64::from_be_bytes(b[..8].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i32_from_be_bytes(b: &[u8]) -> i32 {
|
||||||
|
i32::from_be_bytes(b[..4].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn i16_from_be_bytes(b: &[u8]) -> i16 {
|
||||||
|
i16::from_be_bytes(b[..2].try_into().unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub use fast_int_memory_access::*;
|
|
@ -106,13 +106,15 @@ impl<K: Eq + PartialEq + Hash + Clone, V, const C: usize, const B: usize> RingBu
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new(salt: u32) -> Self {
|
pub fn new(salt: u32) -> Self {
|
||||||
Self {
|
Self {
|
||||||
entries: std::array::from_fn(|_| Entry::<K, V> {
|
entries: {
|
||||||
key: MaybeUninit::uninit(),
|
let mut entries: [Entry<K, V>; C] = unsafe { MaybeUninit::uninit().assume_init() };
|
||||||
value: MaybeUninit::uninit(),
|
for e in entries.iter_mut() {
|
||||||
bucket: -1,
|
e.bucket = -1;
|
||||||
next: -1,
|
e.next = -1;
|
||||||
prev: -1,
|
e.prev = -1;
|
||||||
}),
|
}
|
||||||
|
entries
|
||||||
|
},
|
||||||
buckets: [-1; B],
|
buckets: [-1; B],
|
||||||
entry_ptr: 0,
|
entry_ptr: 0,
|
||||||
salt,
|
salt,
|
||||||
|
|
Loading…
Add table
Reference in a new issue