From f66a2a7ef90f3909f5ab06cb62ae3464d9ecc7f1 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Fri, 10 Mar 2023 17:03:22 -0500 Subject: [PATCH] Tetanus adam (#1906) * Move some stuff around in prep for a VL2 rework and identity rework. * Mix ephemeral keys into "h" * More topology stuff for VL2. * Simplify key queue, fix macOS issues with bindings, and no need to cache PSK forever. * Some more merge fixes. * A bunch of ZSSP cleanup and optimization. Runs a bit faster now. --- controller/src/controller.rs | 5 +- controller/src/model/mod.rs | 2 +- controller/src/model/network.rs | 3 +- controller/src/postgresdatabase.rs | 3 +- crypto/src/aes_fruity.rs | 52 +- crypto/src/aes_openssl.rs | 1 + crypto/src/cipher_ctx.rs | 1 + crypto/src/lib.rs | 5 +- network-hypervisor/src/vl2/iproute.rs | 17 + network-hypervisor/src/vl2/mod.rs | 6 +- network-hypervisor/src/vl2/topology.rs | 61 ++ network-hypervisor/src/vl2/v1/mod.rs | 3 + .../src/vl2/{ => v1}/networkconfig.rs | 22 +- .../src/vl2/{ => v1}/revocation.rs | 0 utils/src/blob.rs | 24 +- utils/src/flatsortedmap.rs | 86 +++ utils/src/lib.rs | 1 + zssp/src/main.rs | 8 +- zssp/src/sessionid.rs | 27 +- zssp/src/zssp.rs | 690 +++++++++--------- 20 files changed, 610 insertions(+), 407 deletions(-) create mode 100644 network-hypervisor/src/vl2/iproute.rs create mode 100644 network-hypervisor/src/vl2/topology.rs rename network-hypervisor/src/vl2/{ => v1}/networkconfig.rs (96%) rename network-hypervisor/src/vl2/{ => v1}/revocation.rs (100%) create mode 100644 utils/src/flatsortedmap.rs diff --git a/controller/src/controller.rs b/controller/src/controller.rs index 39556d07e..c6f032df5 100644 --- a/controller/src/controller.rs +++ b/controller/src/controller.rs @@ -12,8 +12,9 @@ use zerotier_network_hypervisor::protocol::{PacketBuffer, DEFAULT_MULTICAST_LIMI use zerotier_network_hypervisor::vl1::*; use zerotier_network_hypervisor::vl2; use zerotier_network_hypervisor::vl2::multicastauthority::MulticastAuthority; -use zerotier_network_hypervisor::vl2::networkconfig::*; -use zerotier_network_hypervisor::vl2::{NetworkId, Revocation}; +use zerotier_network_hypervisor::vl2::v1::networkconfig::*; +use zerotier_network_hypervisor::vl2::v1::Revocation; +use zerotier_network_hypervisor::vl2::NetworkId; use zerotier_utils::blob::Blob; use zerotier_utils::buffer::OutOfBoundsError; use zerotier_utils::error::InvalidParameterError; diff --git a/controller/src/model/mod.rs b/controller/src/model/mod.rs index ff519c285..93cf11884 100644 --- a/controller/src/model/mod.rs +++ b/controller/src/model/mod.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use zerotier_network_hypervisor::vl1::{Address, Endpoint}; -use zerotier_network_hypervisor::vl2::networkconfig::NetworkConfig; +use zerotier_network_hypervisor::vl2::v1::networkconfig::NetworkConfig; use zerotier_network_hypervisor::vl2::NetworkId; use zerotier_utils::blob::Blob; diff --git a/controller/src/model/network.rs b/controller/src/model/network.rs index f44b7f3d6..6a4e21a11 100644 --- a/controller/src/model/network.rs +++ b/controller/src/model/network.rs @@ -6,9 +6,8 @@ use std::hash::Hash; use serde::{Deserialize, Serialize}; use zerotier_network_hypervisor::vl1::InetAddress; -use zerotier_network_hypervisor::vl2::networkconfig::IpRoute; use zerotier_network_hypervisor::vl2::rule::Rule; -use zerotier_network_hypervisor::vl2::NetworkId; +use zerotier_network_hypervisor::vl2::{IpRoute, NetworkId}; use crate::database::Database; use crate::model::Member; diff --git a/controller/src/postgresdatabase.rs b/controller/src/postgresdatabase.rs index 379b44a6e..8c30a9315 100644 --- a/controller/src/postgresdatabase.rs +++ b/controller/src/postgresdatabase.rs @@ -13,9 +13,8 @@ use zerotier_crypto::secure_eq; use zerotier_crypto::typestate::Valid; use zerotier_network_hypervisor::vl1::{Address, Identity, InetAddress}; -use zerotier_network_hypervisor::vl2::networkconfig::IpRoute; use zerotier_network_hypervisor::vl2::rule::Rule; -use zerotier_network_hypervisor::vl2::NetworkId; +use zerotier_network_hypervisor::vl2::{IpRoute, NetworkId}; use zerotier_utils::futures_util::{Stream, StreamExt}; use zerotier_utils::tokio; diff --git a/crypto/src/aes_fruity.rs b/crypto/src/aes_fruity.rs index 01b505b49..5a0d5355c 100644 --- a/crypto/src/aes_fruity.rs +++ b/crypto/src/aes_fruity.rs @@ -3,7 +3,7 @@ // MacOS implementation of AES primitives since CommonCrypto seems to be faster than OpenSSL, especially on ARM64. use std::os::raw::{c_int, c_void}; use std::ptr::{null, null_mut}; -use std::sync::Mutex; +use std::sync::atomic::AtomicPtr; use crate::secret::Secret; use crate::secure_eq; @@ -172,14 +172,26 @@ impl AesGcm { } } -pub struct Aes(Mutex<*mut c_void>, Mutex<*mut c_void>); +pub struct Aes(AtomicPtr, AtomicPtr); impl Drop for Aes { #[inline(always)] fn drop(&mut self) { unsafe { - CCCryptorRelease(*self.0.lock().unwrap()); - CCCryptorRelease(*self.1.lock().unwrap()); + loop { + let p = self.0.load(std::sync::atomic::Ordering::Acquire); + if !p.is_null() { + CCCryptorRelease(p); + break; + } + } + loop { + let p = self.1.load(std::sync::atomic::Ordering::Acquire); + if !p.is_null() { + CCCryptorRelease(p); + break; + } + } } } } @@ -191,7 +203,7 @@ impl Aes { KEY_SIZE == 32 || KEY_SIZE == 24 || KEY_SIZE == 16, "AES supports 128, 192, or 256 bits keys" ); - let aes: Self = std::mem::zeroed(); + let (mut p0, mut p1) = (null_mut(), null_mut()); assert_eq!( CCCryptorCreateWithMode( kCCEncrypt, @@ -205,7 +217,7 @@ impl Aes { 0, 0, kCCOptionECBMode, - &mut *aes.0.lock().unwrap() + &mut p0, ), 0 ); @@ -222,11 +234,11 @@ impl Aes { 0, 0, kCCOptionECBMode, - &mut *aes.1.lock().unwrap() + &mut p1, ), 0 ); - aes + Self(AtomicPtr::new(p0), AtomicPtr::new(p1)) } } @@ -235,8 +247,16 @@ impl Aes { assert_eq!(data.len(), 16); unsafe { let mut data_out_written = 0; - let e = self.0.lock().unwrap(); - CCCryptorUpdate(*e, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + loop { + let p = self.0.load(std::sync::atomic::Ordering::Acquire); + if !p.is_null() { + CCCryptorUpdate(p, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + self.0.store(p, std::sync::atomic::Ordering::Release); + break; + } else { + std::thread::yield_now(); + } + } } } @@ -245,8 +265,16 @@ impl Aes { assert_eq!(data.len(), 16); unsafe { let mut data_out_written = 0; - let d = self.1.lock().unwrap(); - CCCryptorUpdate(*d, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + loop { + let p = self.1.load(std::sync::atomic::Ordering::Acquire); + if !p.is_null() { + CCCryptorUpdate(p, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + self.1.store(p, std::sync::atomic::Ordering::Release); + break; + } else { + std::thread::yield_now(); + } + } } } } diff --git a/crypto/src/aes_openssl.rs b/crypto/src/aes_openssl.rs index 6fb4393a9..f0b4fe552 100644 --- a/crypto/src/aes_openssl.rs +++ b/crypto/src/aes_openssl.rs @@ -125,6 +125,7 @@ impl Aes { let ptr = data.as_mut_ptr(); unsafe { self.0.update::(data, ptr).unwrap() } } + /// Do not ever encrypt the same plaintext twice. Make sure data is always different between calls. #[inline(always)] pub fn decrypt_block_in_place(&self, data: &mut [u8]) { diff --git a/crypto/src/cipher_ctx.rs b/crypto/src/cipher_ctx.rs index 94474cfaa..f5cba9441 100644 --- a/crypto/src/cipher_ctx.rs +++ b/crypto/src/cipher_ctx.rs @@ -109,6 +109,7 @@ impl CipherCtxRef { } /// Sets the authentication tag for verification during decryption. + #[allow(unused)] pub fn set_tag(&self, tag: &[u8]) -> Result<(), ErrorStack> { unsafe { cvt(ffi::EVP_CIPHER_CTX_ctrl( diff --git a/crypto/src/lib.rs b/crypto/src/lib.rs index 1101105b0..56d47b4be 100644 --- a/crypto/src/lib.rs +++ b/crypto/src/lib.rs @@ -14,10 +14,13 @@ pub mod salsa; pub mod typestate; pub mod x25519; +#[cfg(target_os = "macos")] pub mod aes_fruity; -pub mod aes_openssl; #[cfg(target_os = "macos")] pub use aes_fruity as aes; + +#[cfg(not(target_os = "macos"))] +pub mod aes_openssl; #[cfg(not(target_os = "macos"))] pub use aes_openssl as aes; diff --git a/network-hypervisor/src/vl2/iproute.rs b/network-hypervisor/src/vl2/iproute.rs new file mode 100644 index 000000000..529435485 --- /dev/null +++ b/network-hypervisor/src/vl2/iproute.rs @@ -0,0 +1,17 @@ +use crate::vl1::InetAddress; +use serde::{Deserialize, Serialize}; + +/// ZeroTier-managed L3 route on a virtual network. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct IpRoute { + pub target: InetAddress, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub via: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub flags: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub metric: Option, +} diff --git a/network-hypervisor/src/vl2/mod.rs b/network-hypervisor/src/vl2/mod.rs index e692a39fa..672ae41b6 100644 --- a/network-hypervisor/src/vl2/mod.rs +++ b/network-hypervisor/src/vl2/mod.rs @@ -1,16 +1,16 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md. +mod iproute; mod multicastgroup; mod networkid; -mod revocation; mod switch; +mod topology; pub mod multicastauthority; -pub mod networkconfig; pub mod rule; pub mod v1; +pub use iproute::IpRoute; pub use multicastgroup::MulticastGroup; pub use networkid::NetworkId; -pub use revocation::Revocation; pub use switch::{Switch, SwitchInterface}; diff --git a/network-hypervisor/src/vl2/topology.rs b/network-hypervisor/src/vl2/topology.rs new file mode 100644 index 000000000..3defd1233 --- /dev/null +++ b/network-hypervisor/src/vl2/topology.rs @@ -0,0 +1,61 @@ +use std::borrow::Cow; + +use zerotier_utils::blob::Blob; +use zerotier_utils::flatsortedmap::FlatSortedMap; + +use serde::{Deserialize, Serialize}; + +use crate::vl1::identity::IDENTITY_FINGERPRINT_SIZE; +use crate::vl1::inetaddress::InetAddress; +use crate::vl2::rule::Rule; + +#[derive(Serialize, Deserialize, Eq, PartialEq, Clone)] +pub struct Member<'a> { + #[serde(skip_serializing_if = "u64_zero")] + #[serde(default)] + pub flags: u64, + + #[serde(skip_serializing_if = "cow_str_is_empty")] + #[serde(default)] + pub name: Cow<'a, str>, +} + +#[derive(Serialize, Deserialize, Eq, PartialEq, Clone)] +pub struct Topology<'a> { + pub timestamp: i64, + + #[serde(skip_serializing_if = "cow_str_is_empty")] + #[serde(default)] + pub name: Cow<'a, str>, + + #[serde(skip_serializing_if = "slice_is_empty")] + #[serde(default)] + pub rules: Cow<'a, [Rule]>, + + #[serde(skip_serializing_if = "FlatSortedMap::is_empty")] + #[serde(default)] + pub dns_resolvers: FlatSortedMap<'a, Cow<'a, str>, InetAddress>, + + #[serde(skip_serializing_if = "FlatSortedMap::is_empty")] + #[serde(default)] + pub dns_names: FlatSortedMap<'a, Cow<'a, str>, InetAddress>, + + #[serde(skip_serializing_if = "FlatSortedMap::is_empty")] + #[serde(default)] + pub members: FlatSortedMap<'a, Blob, Member<'a>>, +} + +#[inline(always)] +fn u64_zero(i: &u64) -> bool { + *i == 0 +} + +#[inline(always)] +fn cow_str_is_empty<'a>(s: &Cow<'a, str>) -> bool { + s.is_empty() +} + +#[inline(always)] +fn slice_is_empty>(x: &S) -> bool { + x.as_ref().is_empty() +} diff --git a/network-hypervisor/src/vl2/v1/mod.rs b/network-hypervisor/src/vl2/v1/mod.rs index b18767fb0..f96ff95a8 100644 --- a/network-hypervisor/src/vl2/v1/mod.rs +++ b/network-hypervisor/src/vl2/v1/mod.rs @@ -1,5 +1,7 @@ mod certificateofmembership; mod certificateofownership; +pub mod networkconfig; +mod revocation; mod tag; #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -15,4 +17,5 @@ pub enum CredentialType { pub use certificateofmembership::CertificateOfMembership; pub use certificateofownership::{CertificateOfOwnership, Thing}; +pub use revocation::Revocation; pub use tag::Tag; diff --git a/network-hypervisor/src/vl2/networkconfig.rs b/network-hypervisor/src/vl2/v1/networkconfig.rs similarity index 96% rename from network-hypervisor/src/vl2/networkconfig.rs rename to network-hypervisor/src/vl2/v1/networkconfig.rs index 04c1d6435..0e114fc0a 100644 --- a/network-hypervisor/src/vl2/networkconfig.rs +++ b/network-hypervisor/src/vl2/v1/networkconfig.rs @@ -7,6 +7,7 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; use crate::vl1::{Address, Identity, InetAddress}; +use crate::vl2::iproute::IpRoute; use crate::vl2::rule::Rule; use crate::vl2::v1::{CertificateOfMembership, CertificateOfOwnership, Tag}; use crate::vl2::NetworkId; @@ -30,11 +31,6 @@ pub struct NetworkConfig { #[serde(default)] pub name: String, - /// A human-readable message for members of this network (V2 only) - #[serde(skip_serializing_if = "String::is_empty")] - #[serde(default)] - pub motd: String, - /// True if network has access control (the default) pub private: bool, @@ -94,7 +90,6 @@ impl NetworkConfig { network_id, issued_to, name: String::new(), - motd: String::new(), private: true, timestamp: 0, mtu: 0, @@ -436,21 +431,6 @@ pub struct V1Credentials { pub tags: HashMap, } -/// Statically pushed L3 IP routes included with a network configuration. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct IpRoute { - pub target: InetAddress, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(default)] - pub via: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(default)] - pub flags: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(default)] - pub metric: Option, -} - impl Marshalable for IpRoute { const MAX_MARSHAL_SIZE: usize = (InetAddress::MAX_MARSHAL_SIZE * 2) + 2 + 2; diff --git a/network-hypervisor/src/vl2/revocation.rs b/network-hypervisor/src/vl2/v1/revocation.rs similarity index 100% rename from network-hypervisor/src/vl2/revocation.rs rename to network-hypervisor/src/vl2/v1/revocation.rs diff --git a/utils/src/blob.rs b/utils/src/blob.rs index b02f7902b..660b35947 100644 --- a/utils/src/blob.rs +++ b/utils/src/blob.rs @@ -7,6 +7,7 @@ */ use std::fmt::Debug; +use std::hash::Hash; use serde::ser::SerializeTuple; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -72,6 +73,27 @@ impl ToString for Blob { } } +impl PartialOrd for Blob { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Ord for Blob { + #[inline(always)] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) + } +} + +impl Hash for Blob { + #[inline(always)] + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + impl Debug for Blob { #[inline] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -118,7 +140,7 @@ impl<'de, const L: usize> serde::de::Visitor<'de> for BlobVisitor { impl<'de, const L: usize> Deserialize<'de> for Blob { #[inline] - fn deserialize(deserializer: D) -> Result, D::Error> + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { diff --git a/utils/src/flatsortedmap.rs b/utils/src/flatsortedmap.rs new file mode 100644 index 000000000..d05e62b3e --- /dev/null +++ b/utils/src/flatsortedmap.rs @@ -0,0 +1,86 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + * + * (c) ZeroTier, Inc. + * https://www.zerotier.com/ + */ + +use std::borrow::Cow; +use std::iter::{FromIterator, Iterator}; + +use serde::{Deserialize, Serialize}; + +/// A simple flat sorted map backed by a vector and binary search. +/// +/// This doesn't support gradual adding of keys or removal of keys, but only construction +/// from an iterator of keys and values. It also implements Serialize and Deserialize and +/// is mainly intended for memory and space efficient serializable lookup tables. +/// +/// If the iterator supplies more than one key with different values, which of these is +/// included is undefined. +#[derive(Serialize, Deserialize, PartialEq, Eq, Clone)] +#[repr(transparent)] +pub struct FlatSortedMap<'a, K: Eq + Ord + Clone, V: Clone>(Cow<'a, [(K, V)]>); + +impl<'a, K: Eq + Ord + Clone, V: Clone> FromIterator<(K, V)> for FlatSortedMap<'a, K, V> { + #[inline] + fn from_iter>(iter: T) -> Self { + let mut tmp = Vec::from_iter(iter); + tmp.sort_unstable_by(|a, b| a.0.cmp(&b.0)); + tmp.dedup_by(|a, b| a.0.eq(&b.0)); + Self(Cow::Owned(tmp)) + } +} + +impl<'a, K: Eq + Ord + Clone, V: Clone> Default for FlatSortedMap<'a, K, V> { + #[inline(always)] + fn default() -> Self { + Self(Cow::Owned(Vec::new())) + } +} + +impl<'a, K: Eq + Ord + Clone, V: Clone> FlatSortedMap<'a, K, V> { + #[inline] + pub fn get(&self, k: &K) -> Option<&V> { + if let Ok(idx) = self.0.binary_search_by(|a| a.0.cmp(k)) { + Some(unsafe { &self.0.get_unchecked(idx).1 }) + } else { + None + } + } + + #[inline] + pub fn contains(&self, k: &K) -> bool { + self.0.binary_search_by(|a| a.0.cmp(k)).is_ok() + } + + /// Returns true if this map is valid, meaning that it contains only one of each key and is sorted. + #[inline] + pub fn is_valid(&self) -> bool { + let l = self.0.len(); + if l > 1 { + for i in 1..l { + if unsafe { !self.0.get_unchecked(i - 1).0.cmp(&self.0.get_unchecked(i).0).is_lt() } { + return false; + } + } + } + return true; + } + + #[inline(always)] + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.0.len() + } + + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} diff --git a/utils/src/lib.rs b/utils/src/lib.rs index 8f211d91e..cb1e278e9 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -14,6 +14,7 @@ pub mod dictionary; pub mod error; #[allow(unused)] pub mod exitcode; +pub mod flatsortedmap; pub mod gate; pub mod gatherarray; pub mod hex; diff --git a/zssp/src/main.rs b/zssp/src/main.rs index 402505e06..8a15beee3 100644 --- a/zssp/src/main.rs +++ b/zssp/src/main.rs @@ -46,7 +46,7 @@ fn alice_main( alice_out: mpsc::SyncSender>, alice_in: mpsc::Receiver>, ) { - let context = zssp::Context::::new(16); + let context = zssp::Context::::new(16, TEST_MTU); let mut data_buf = [0u8; 65536]; let mut next_service = ms_monotonic() + 500; let mut last_ratchet_count = 0; @@ -88,7 +88,6 @@ fn alice_main( &0, &mut data_buf, pkt, - TEST_MTU, current_time, ) { Ok(zssp::ReceiveResult::Ok(_)) => { @@ -144,7 +143,6 @@ fn alice_main( |_, b| { let _ = alice_out.send(b.to_vec()); }, - TEST_MTU, current_time, ); } @@ -159,7 +157,7 @@ fn bob_main( bob_out: mpsc::SyncSender>, bob_in: mpsc::Receiver>, ) { - let context = zssp::Context::::new(16); + let context = zssp::Context::::new(16, TEST_MTU); let mut data_buf = [0u8; 65536]; let mut data_buf_2 = [0u8; TEST_MTU]; let mut last_ratchet_count = 0; @@ -186,7 +184,6 @@ fn bob_main( &0, &mut data_buf, pkt, - TEST_MTU, current_time, ) { Ok(zssp::ReceiveResult::Ok(_)) => { @@ -246,7 +243,6 @@ fn bob_main( |_, b| { let _ = bob_out.send(b.to_vec()); }, - TEST_MTU, current_time, ); } diff --git a/zssp/src/sessionid.rs b/zssp/src/sessionid.rs index be272034c..e1fb99b64 100644 --- a/zssp/src/sessionid.rs +++ b/zssp/src/sessionid.rs @@ -10,7 +10,6 @@ use std::fmt::Display; use std::num::NonZeroU64; use zerotier_crypto::random; -use zerotier_utils::memory::{array_range, as_byte_array}; /// 48-bit session ID (most significant 16 bits of u64 are unused) #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -25,6 +24,7 @@ impl SessionId { pub const MAX: u64 = 0xffffffffffff; /// Create a new session ID, panicing if 'i' is zero or exceeds MAX. + #[inline(always)] pub fn new(i: u64) -> SessionId { assert!(i <= Self::MAX); Self(NonZeroU64::new(i.to_le()).unwrap()) @@ -35,22 +35,23 @@ impl SessionId { Self(NonZeroU64::new(((random::xorshift64_random() % (Self::MAX - 1)) + 1).to_le()).unwrap()) } - pub(crate) fn new_from_bytes(b: &[u8; Self::SIZE]) -> Option { - let mut tmp = [0u8; 8]; + #[inline(always)] + pub fn to_bytes(&self) -> [u8; Self::SIZE] { + self.0.get().to_ne_bytes()[..Self::SIZE].try_into().unwrap() + } + + #[inline(always)] + pub fn new_from_bytes(b: &[u8]) -> Option { + let mut tmp = 0u64.to_ne_bytes(); tmp[..SESSION_ID_SIZE_BYTES].copy_from_slice(b); - Self::new_from_u64_le(u64::from_ne_bytes(tmp)) + NonZeroU64::new(u64::from_ne_bytes(tmp)).map(|i| Self(i)) } - /// Create from a u64 that is already in little-endian byte order. #[inline(always)] - pub(crate) fn new_from_u64_le(i: u64) -> Option { - NonZeroU64::new(i & Self::MAX.to_le()).map(|i| Self(i)) - } - - /// Get this session ID as a little-endian byte array. - #[inline(always)] - pub(crate) fn as_bytes(&self) -> &[u8; Self::SIZE] { - array_range::(as_byte_array(&self.0)) + pub fn new_from_array(b: &[u8; Self::SIZE]) -> Option { + let mut tmp = 0u64.to_ne_bytes(); + tmp[..SESSION_ID_SIZE_BYTES].copy_from_slice(b); + NonZeroU64::new(u64::from_ne_bytes(tmp)).map(|i| Self(i)) } } diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 1f7ba7802..ee7842e35 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -11,7 +11,7 @@ use std::collections::{HashMap, HashSet}; use std::num::NonZeroU64; -use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicI64, AtomicU64, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak}; use zerotier_crypto::aes::{Aes, AesGcm}; @@ -28,12 +28,16 @@ use crate::fragged::Fragged; use crate::proto::*; use crate::sessionid::SessionId; +/// Number of GCM ciphers to pool for send/receive concurrency. +const GCM_CIPHER_POOL_SIZE: usize = 4; + /// Session context for local application. /// /// Each application using ZSSP must create an instance of this to own sessions and /// defragment incoming packets that are not yet associated with a session. pub struct Context { max_incomplete_session_queue_size: usize, + default_physical_mtu: AtomicUsize, defrag: Mutex< HashMap< (Application::PhysicalPath, u64), @@ -52,7 +56,7 @@ struct SessionsById { active: HashMap>>, // Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout. - incoming: HashMap>, + incoming: HashMap>>, } /// Result generated by the context packet receive function, with possible payloads. @@ -80,7 +84,6 @@ pub struct Session { /// An arbitrary application defined object associated with each session pub application_data: Application::Data, - psk: Secret, send_counter: AtomicU64, receive_window: [AtomicU64; COUNTER_WINDOW_MAX_OOO], header_protection_cipher: Aes, @@ -90,13 +93,14 @@ pub struct Session { /// Most of the mutable parts of a session state. struct State { + physical_mtu: usize, remote_session_id: Option, keys: [Option; 2], current_key: usize, - current_offer: Offer, + outgoing_offer: Offer, } -struct BobIncomingIncompleteSessionState { +struct IncomingIncompleteSession { timestamp: i64, alice_session_id: SessionId, bob_session_id: SessionId, @@ -105,10 +109,12 @@ struct BobIncomingIncompleteSessionState { hk: Secret, header_protection_key: Secret, bob_noise_e_secret: P384KeyPair, + defrag: [Mutex>; MAX_NOISE_HANDSHAKE_FRAGMENTS], } -struct AliceOutgoingIncompleteSessionState { +struct OutgoingSessionOffer { last_retry_time: AtomicI64, + psk: Secret, noise_h: [u8; SHA384_HASH_SIZE], noise_es: Secret, alice_noise_e_secret: P384KeyPair, @@ -120,39 +126,37 @@ struct AliceOutgoingIncompleteSessionState { struct OutgoingSessionAck { last_retry_time: AtomicI64, ack: [u8; MAX_NOISE_HANDSHAKE_SIZE], - ack_size: usize, + ack_len: usize, } enum Offer { None, - NoiseXKInit(Box), + NoiseXKInit(Box), NoiseXKAck(Box), RekeyInit(P384KeyPair, i64), } -const AES_POOL_SIZE: usize = 4; struct SessionKey { - ratchet_key: Secret, // Key used in derivation of the next session key - //receive_key: Secret, // Receive side AES-GCM key - //send_key: Secret, // Send side AES-GCM key - receive_cipher_pool: [Mutex>; AES_POOL_SIZE], // Pool of reusable sending ciphers - send_cipher_pool: [Mutex>; AES_POOL_SIZE], // Pool of reusable receiving ciphers - rekey_at_time: i64, // Rekey at or after this time (ticks) - rekey_at_counter: u64, // Rekey at or after this counter - expire_at_counter: u64, // Hard error when this counter value is reached or exceeded - ratchet_count: u64, // Number of rekey events - bob: bool, // Was this side "Bob" in this exchange? - confirmed: bool, // Is this key confirmed by the other side? + ratchet_key: Secret, // Key used in derivation of the next session key + receive_cipher_pool: [Mutex>; GCM_CIPHER_POOL_SIZE], // Pool of reusable sending ciphers + send_cipher_pool: [Mutex>; GCM_CIPHER_POOL_SIZE], // Pool of reusable receiving ciphers + rekey_at_time: i64, // Rekey at or after this time (ticks) + rekey_at_counter: u64, // Rekey at or after this counter + expire_at_counter: u64, // Hard error when this counter value is reached or exceeded + ratchet_count: u64, // Number of rekey events + my_turn_to_rekey: bool, // Was this side "Bob" in this exchange? + confirmed: bool, // Is this key confirmed by the other side yet? } impl Context { /// Create a new session context. /// /// * `max_incomplete_session_queue_size` - Maximum number of incomplete sessions in negotiation phase - pub fn new(max_incomplete_session_queue_size: usize) -> Self { + pub fn new(max_incomplete_session_queue_size: usize, default_physical_mtu: usize) -> Self { zerotier_crypto::init(); Self { max_incomplete_session_queue_size, + default_physical_mtu: AtomicUsize::new(default_physical_mtu), defrag: Mutex::new(HashMap::new()), sessions: RwLock::new(SessionsById { active: HashMap::with_capacity(64), @@ -163,12 +167,13 @@ impl Context { /// Perform periodic background service and cleanup tasks. /// - /// This returns the number of milliseconds until it should be called again. + /// This returns the number of milliseconds until it should be called again. The caller should + /// try to satisfy this but small variations in timing of up to +/- a second or two are not + /// a problem. /// /// * `send` - Function to send packets to remote sessions - /// * `mtu` - Physical MTU /// * `current_time` - Current monotonic time in milliseconds - pub fn service>, &mut [u8])>(&self, mut send: SendFunction, mtu: usize, current_time: i64) -> i64 { + pub fn service>, &mut [u8])>(&self, mut send: SendFunction, current_time: i64) -> i64 { let mut dead_active = Vec::new(); let mut dead_pending = Vec::new(); let retry_cutoff = current_time - Application::RETRY_INTERVAL; @@ -180,7 +185,7 @@ impl Context { for (id, s) in sessions.active.iter() { if let Some(session) = s.upgrade() { let state = session.state.read().unwrap(); - if match &state.current_offer { + if match &state.outgoing_offer { Offer::None => true, Offer::NoiseXKInit(offer) => { // If there's an outstanding attempt to open a session, retransmit this periodically @@ -193,7 +198,7 @@ impl Context { let _ = send_with_fragmentation( |b| send(&session, b), &mut (offer.init_packet.clone()), - mtu, + state.physical_mtu, PACKET_TYPE_ALICE_NOISE_XK_INIT, None, 0, @@ -210,8 +215,8 @@ impl Context { ack.last_retry_time.store(current_time, Ordering::Relaxed); let _ = send_with_fragmentation( |b| send(&session, b), - &mut (ack.ack.clone())[..ack.ack_size], - mtu, + &mut (ack.ack.clone())[..ack.ack_len], + state.physical_mtu, PACKET_TYPE_ALICE_NOISE_XK_ACK, state.remote_session_id, 0, @@ -226,7 +231,8 @@ impl Context { // Check whether we need to rekey if there is no pending offer or if the last rekey // offer was before retry_cutoff (checked in the 'match' above). if let Some(key) = state.keys[state.current_key].as_ref() { - if key.bob && (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) + if key.my_turn_to_rekey + && (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) { drop(state); session.initiate_rekey(|b| send(&session, b), current_time); @@ -268,7 +274,7 @@ impl Context { /// /// * `app` - Application layer instance /// * `send` - User-supplied packet sending function - /// * `mtu` - Physical MTU for calls to send() + /// * `mtu` - Physical MTU for calls to send() for this session (can be changed later) /// * `remote_s_public_blob` - Remote side's opaque static public blob (which must contain remote_s_public_p384) /// * `remote_s_public_p384` - Remote side's static public NIST P-384 key /// * `psk` - Pre-shared key (use all zero if none) @@ -311,16 +317,17 @@ impl Context { let session = Arc::new(Session { id: local_session_id, application_data, - psk, send_counter: AtomicU64::new(3), // 1 and 2 are reserved for init and final ack receive_window: std::array::from_fn(|_| AtomicU64::new(0)), header_protection_cipher: Aes::new(&header_protection_key), state: RwLock::new(State { + physical_mtu: mtu, remote_session_id: None, keys: [None, None], current_key: 0, - current_offer: Offer::NoiseXKInit(Box::new(AliceOutgoingIncompleteSessionState { + outgoing_offer: Offer::NoiseXKInit(Box::new(OutgoingSessionOffer { last_retry_time: AtomicI64::new(current_time), + psk, noise_h: mix_hash(&mix_hash(&INITIAL_H, remote_s_public_blob), &alice_noise_e), noise_es: noise_es.clone(), alice_noise_e_secret, @@ -339,7 +346,7 @@ impl Context { { let mut state = session.state.write().unwrap(); - let offer = if let Offer::NoiseXKInit(offer) = &mut state.current_offer { + let offer = if let Offer::NoiseXKInit(offer) = &mut state.outgoing_offer { offer } else { panic!(); // should be impossible as this is what we initialized with @@ -351,7 +358,7 @@ impl Context { let init: &mut AliceNoiseXKInit = byte_array_as_proto_buffer_mut(init_packet).unwrap(); init.session_protocol_version = SESSION_PROTOCOL_VERSION; init.alice_noise_e = alice_noise_e; - init.alice_session_id = *local_session_id.as_bytes(); + init.alice_session_id = local_session_id.to_bytes(); init.alice_hk_public = alice_hk_secret.public; init.header_protection_key = header_protection_key.0; } @@ -385,8 +392,7 @@ impl Context { /// /// The send function may be called one or more times to send packets. If the packet is associated /// wtth an active session this session is supplied, otherwise this parameter is None and the packet - /// should be a reply to the current incoming packet. The size of packets to be sent will not exceed - /// the supplied mtu. + /// should be a reply to the current incoming packet. /// /// The check_allow_incoming_session function is called when an initial Noise_XK init message is /// received. This is before anything is known about the caller. A return value of true proceeds @@ -411,8 +417,8 @@ impl Context { /// * `send` - Function to call to send packets /// * `data_buf` - Buffer to receive decrypted and authenticated object data (an error is returned if too small) /// * `incoming_packet_buf` - Buffer containing incoming wire packet (receive() takes ownership) - /// * `mtu` - Physical wire MTU for sending packets /// * `current_time` - Current monotonic time in milliseconds + #[inline] pub fn receive< 'b, SendFunction: FnMut(Option<&Arc>>, &mut [u8]), @@ -426,115 +432,83 @@ impl Context { mut send: SendFunction, source: &Application::PhysicalPath, data_buf: &'b mut [u8], - mut incoming_packet_buf: Application::IncomingPacketBuffer, - mtu: usize, + mut incoming_physical_packet_buf: Application::IncomingPacketBuffer, current_time: i64, ) -> Result, Error> { - let incoming_packet: &mut [u8] = incoming_packet_buf.as_mut(); - if incoming_packet.len() < MIN_PACKET_SIZE { + let incoming_physical_packet: &mut [u8] = incoming_physical_packet_buf.as_mut(); + if incoming_physical_packet.len() < MIN_PACKET_SIZE { return Err(Error::InvalidPacket); } - let mut incoming = None; - if let Some(local_session_id) = SessionId::new_from_u64_le(u64::from_le_bytes(incoming_packet[0..8].try_into().unwrap())) { - if let Some(session) = self.sessions.read().unwrap().active.get(&local_session_id).and_then(|s| s.upgrade()) { + if let Some(local_session_id) = SessionId::new_from_bytes(&incoming_physical_packet[0..SessionId::SIZE]) { + let sessions = self.sessions.read().unwrap(); + if let Some(session) = sessions.active.get(&local_session_id).and_then(|s| s.upgrade()) { + drop(sessions); debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id)); session .header_protection_cipher - .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet); + .decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); if session.check_receive_window(incoming_counter) { - if fragment_count > 1 { - let mut fragged = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO].lock().unwrap(); - if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) { - drop(fragged); - return self.process_complete_incoming_packet( - app, - &mut send, - &mut check_allow_incoming_session, - &mut check_accept_session, - data_buf, - incoming_counter, - assembled_packet.as_ref(), - packet_type, - Some(session), - None, - key_index, - mtu, - current_time, - ); + let (assembled_packet, incoming_packet_buf_arr); + let incoming_packet = if fragment_count > 1 { + assembled_packet = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO] + .lock() + .unwrap() + .assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + if let Some(assembled_packet) = assembled_packet.as_ref() { + assembled_packet.as_ref() } else { - drop(fragged); return Ok(ReceiveResult::Ok(Some(session))); } } else { - return self.process_complete_incoming_packet( - app, - &mut send, - &mut check_allow_incoming_session, - &mut check_accept_session, - data_buf, - incoming_counter, - &[incoming_packet_buf], - packet_type, - Some(session), - None, - key_index, - mtu, - current_time, - ); - } + incoming_packet_buf_arr = [incoming_physical_packet_buf]; + &incoming_packet_buf_arr + }; + + return self.process_complete_incoming_packet( + app, + &mut send, + &mut check_allow_incoming_session, + &mut check_accept_session, + data_buf, + incoming_counter, + incoming_packet, + packet_type, + Some(session), + None, + key_index, + current_time, + ); } else { return Err(Error::OutOfSequence); } - } else { - if let Some(i) = self.sessions.read().unwrap().incoming.get(&local_session_id).cloned() { - Aes::new(&i.header_protection_key) - .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - incoming = Some(i); - } else { - return Err(Error::UnknownLocalSessionId); - } - } - } + } else if let Some(incoming) = sessions.incoming.get(&local_session_id).cloned() { + drop(sessions); + debug_assert!(!self.sessions.read().unwrap().active.contains_key(&local_session_id)); - // If we make it here the packet is not associated with a session or is associated with an - // incoming session (Noise_XK mid-negotiation). + Aes::new(&incoming.header_protection_key) + .decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); - let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet); - if fragment_count > 1 { - let f = { - let mut defrag = self.defrag.lock().unwrap(); - let f = defrag - .entry((source.clone(), incoming_counter)) - .or_insert_with(|| Arc::new((Mutex::new(Fragged::new()), current_time))) - .clone(); - - // Anti-DOS overflow purge of the incoming defragmentation queue for packets not associated with known sessions. - if defrag.len() >= self.max_incomplete_session_queue_size { - // First, drop all entries that are timed out or whose physical source duplicates another entry. - let mut sources = HashSet::with_capacity(defrag.len()); - let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; - defrag.retain(|k, fragged| (fragged.1 > negotiation_timeout_cutoff && sources.insert(k.0.clone())) || Arc::ptr_eq(fragged, &f)); - - // Then, if we are still at or over the limit, drop 10% of remaining entries at random. - if defrag.len() >= self.max_incomplete_session_queue_size { - let mut rn = random::next_u32_secure(); - defrag.retain(|_, fragged| { - rn = prng32(rn); - rn > (u32::MAX / 10) || Arc::ptr_eq(fragged, &f) - }); + let (assembled_packet, incoming_packet_buf_arr); + let incoming_packet = if fragment_count > 1 { + assembled_packet = incoming.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO] + .lock() + .unwrap() + .assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + if let Some(assembled_packet) = assembled_packet.as_ref() { + assembled_packet.as_ref() + } else { + return Ok(ReceiveResult::Ok(None)); } - } + } else { + incoming_packet_buf_arr = [incoming_physical_packet_buf]; + &incoming_packet_buf_arr + }; - f - }; - let mut fragged = f.0.lock().unwrap(); - - if let Some(assembled_packet) = fragged.assemble(incoming_counter, incoming_packet_buf, fragment_no, fragment_count) { - self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter)); return self.process_complete_incoming_packet( app, &mut send, @@ -542,16 +516,63 @@ impl Context { &mut check_accept_session, data_buf, incoming_counter, - assembled_packet.as_ref(), + incoming_packet, packet_type, None, - incoming, + Some(incoming), key_index, - mtu, current_time, ); + } else { + return Err(Error::UnknownLocalSessionId); } } else { + let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); + + let (assembled_packet, incoming_packet_buf_arr); + let incoming_packet = if fragment_count > 1 { + assembled_packet = { + let mut defrag = self.defrag.lock().unwrap(); + let f = defrag + .entry((source.clone(), incoming_counter)) + .or_insert_with(|| Arc::new((Mutex::new(Fragged::new()), current_time))) + .clone(); + + // Anti-DOS overflow purge of the incoming defragmentation queue for packets not associated with known sessions. + if defrag.len() >= self.max_incomplete_session_queue_size { + // First, drop all entries that are timed out or whose physical source duplicates another entry. + let mut sources = HashSet::with_capacity(defrag.len()); + let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; + defrag + .retain(|k, fragged| (fragged.1 > negotiation_timeout_cutoff && sources.insert(k.0.clone())) || Arc::ptr_eq(fragged, &f)); + + // Then, if we are still at or over the limit, drop 10% of remaining entries at random. + if defrag.len() >= self.max_incomplete_session_queue_size { + let mut rn = random::next_u32_secure(); + defrag.retain(|_, fragged| { + rn = prng32(rn); + rn > (u32::MAX / 10) || Arc::ptr_eq(fragged, &f) + }); + } + } + + f + } + .0 + .lock() + .unwrap() + .assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + if let Some(assembled_packet) = assembled_packet.as_ref() { + self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter)); + assembled_packet.as_ref() + } else { + return Ok(ReceiveResult::Ok(None)); + } + } else { + incoming_packet_buf_arr = [incoming_physical_packet_buf]; + &incoming_packet_buf_arr + }; + return self.process_complete_incoming_packet( app, &mut send, @@ -559,17 +580,14 @@ impl Context { &mut check_accept_session, data_buf, incoming_counter, - &[incoming_packet_buf], + incoming_packet, packet_type, None, - incoming, + None, key_index, - mtu, current_time, ); } - - return Ok(ReceiveResult::Ok(None)); } fn process_complete_incoming_packet< @@ -588,9 +606,8 @@ impl Context { fragments: &[Application::IncomingPacketBuffer], packet_type: u8, session: Option>>, - incoming: Option>, + incoming: Option>>, key_index: usize, - mtu: usize, current_time: i64, ) -> Result, Error> { debug_assert!(fragments.len() >= 1); @@ -653,9 +670,9 @@ impl Context { // If we got a valid data packet from Bob, this means we can cancel any offers // that are still oustanding for initialization. - match &state.current_offer { + match &state.outgoing_offer { Offer::NoiseXKInit(_) | Offer::NoiseXKAck(_) => { - state.current_offer = Offer::None; + state.outgoing_offer = Offer::None; } _ => {} } @@ -732,7 +749,7 @@ impl Context { } let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?; - let alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?; + let alice_session_id = SessionId::new_from_array(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?; let header_protection_key = Secret(pkt.header_protection_key); // Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create @@ -749,6 +766,7 @@ impl Context { let mut sessions = self.sessions.write().unwrap(); + // Pick an unused session ID on this side. let mut bob_session_id; loop { bob_session_id = SessionId::random(); @@ -762,7 +780,7 @@ impl Context { let ack: &mut BobNoiseXKAck = byte_array_as_proto_buffer_mut(&mut ack_packet)?; ack.session_protocol_version = SESSION_PROTOCOL_VERSION; ack.bob_noise_e = bob_noise_e; - ack.bob_session_id = *bob_session_id.as_bytes(); + ack.bob_session_id = bob_session_id.to_bytes(); ack.bob_hk_ciphertext = bob_hk_ciphertext; // Encrypt main section of reply and attach tag. @@ -794,7 +812,7 @@ impl Context { // Reserve session ID on this side and record incomplete session state. sessions.incoming.insert( bob_session_id, - Arc::new(BobIncomingIncompleteSessionState { + Arc::new(IncomingIncompleteSession { timestamp: current_time, alice_session_id, bob_session_id, @@ -803,6 +821,7 @@ impl Context { hk, bob_noise_e_secret, header_protection_key: Secret(pkt.header_protection_key), + defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), }), ); debug_assert!(!sessions.active.contains_key(&bob_session_id)); @@ -813,7 +832,7 @@ impl Context { send_with_fragmentation( |b| send(None, b), &mut ack_packet, - mtu, + self.default_physical_mtu.load(Ordering::Relaxed), PACKET_TYPE_BOB_NOISE_XK_ACK, Some(alice_session_id), 0, @@ -848,7 +867,7 @@ impl Context { return Err(Error::OutOfSequence); } - if let Offer::NoiseXKInit(outgoing_offer) = &state.current_offer { + if let Offer::NoiseXKInit(outgoing_offer) = &state.outgoing_offer { let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?; // Derive noise_es_ee from Bob's ephemeral public key. @@ -876,7 +895,7 @@ impl Context { let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?; - if let Some(bob_session_id) = SessionId::new_from_bytes(&pkt.bob_session_id) { + if let Some(bob_session_id) = SessionId::new_from_array(&pkt.bob_session_id) { // Complete Noise_XKpsk3 by mixing in noise_se followed by the PSK. The PSK as far as // the Noise pattern is concerned is the result of mixing the externally supplied PSK // with the Kyber1024 shared secret (hk). Kyber is treated as part of the PSK because @@ -890,67 +909,75 @@ impl Context { if session.update_receive_window(incoming_counter) { let noise_es_ee_se_hk_psk = hmac_sha512_secret::( hmac_sha512_secret::(noise_es_ee.as_bytes(), noise_se.as_bytes()).as_bytes(), - hmac_sha512_secret::(session.psk.as_bytes(), hk.as_bytes()).as_bytes(), + hmac_sha512_secret::(outgoing_offer.psk.as_bytes(), hk.as_bytes()).as_bytes(), ); let reply_message_nonce = create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_ACK, 2); // Create reply informing Bob of our static identity now that we've verified Bob and set // up forward secrecy. Also return Bob's opaque note. - let mut reply_buffer = [0u8; MAX_NOISE_HANDSHAKE_SIZE]; - reply_buffer[HEADER_SIZE] = SESSION_PROTOCOL_VERSION; - let mut reply_len = HEADER_SIZE + 1; + let mut ack = [0u8; MAX_NOISE_HANDSHAKE_SIZE]; + ack[HEADER_SIZE] = SESSION_PROTOCOL_VERSION; + let mut ack_len = HEADER_SIZE + 1; let alice_s_public_blob = app.get_local_s_public_blob(); assert!(alice_s_public_blob.len() <= (u16::MAX as usize)); - reply_len = append_to_slice(&mut reply_buffer, reply_len, &(alice_s_public_blob.len() as u16).to_le_bytes())?; - let mut enc_start = reply_len; - reply_len = append_to_slice(&mut reply_buffer, reply_len, alice_s_public_blob)?; + ack_len = append_to_slice(&mut ack, ack_len, &(alice_s_public_blob.len() as u16).to_le_bytes())?; + let mut enc_start = ack_len; + ack_len = append_to_slice(&mut ack, ack_len, alice_s_public_blob)?; let mut gcm = AesGcm::new(&kbkdf::( hmac_sha512_secret::(noise_es_ee.as_bytes(), hk.as_bytes()).as_bytes(), )); gcm.reset_init_gcm(&reply_message_nonce); gcm.aad(&noise_h_next); - gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]); - reply_len = append_to_slice(&mut reply_buffer, reply_len, &gcm.finish_encrypt())?; + gcm.crypt_in_place(&mut ack[enc_start..ack_len]); + ack_len = append_to_slice(&mut ack, ack_len, &gcm.finish_encrypt())?; let metadata = outgoing_offer.metadata.as_ref().map_or(&[][..0], |md| md.as_slice()); assert!(metadata.len() <= (u16::MAX as usize)); - reply_len = append_to_slice(&mut reply_buffer, reply_len, &(metadata.len() as u16).to_le_bytes())?; + ack_len = append_to_slice(&mut ack, ack_len, &(metadata.len() as u16).to_le_bytes())?; - let noise_h_next = mix_hash(&mix_hash(&noise_h_next, &reply_buffer[HEADER_SIZE..reply_len]), session.psk.as_bytes()); + let noise_h_next = mix_hash(&mix_hash(&noise_h_next, &ack[HEADER_SIZE..ack_len]), outgoing_offer.psk.as_bytes()); - enc_start = reply_len; - reply_len = append_to_slice(&mut reply_buffer, reply_len, metadata)?; + enc_start = ack_len; + ack_len = append_to_slice(&mut ack, ack_len, metadata)?; let mut gcm = AesGcm::new(&kbkdf::( noise_es_ee_se_hk_psk.as_bytes(), )); gcm.reset_init_gcm(&reply_message_nonce); gcm.aad(&noise_h_next); - gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]); - reply_len = append_to_slice(&mut reply_buffer, reply_len, &gcm.finish_encrypt())?; + gcm.crypt_in_place(&mut ack[enc_start..ack_len]); + ack_len = append_to_slice(&mut ack, ack_len, &gcm.finish_encrypt())?; + + let mtu = state.physical_mtu; drop(state); { let mut state = session.state.write().unwrap(); let _ = state.remote_session_id.insert(bob_session_id); - let _ = - state.keys[0].insert(SessionKey::new::(noise_es_ee_se_hk_psk, 1, current_time, 2, false, false)); + let _ = state.keys[0].insert(SessionKey::new::( + noise_es_ee_se_hk_psk, + 1, + current_time, + 2, + false, + false, + )); debug_assert!(state.keys[1].is_none()); state.current_key = 0; - state.current_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck { + state.outgoing_offer = Offer::NoiseXKAck(Box::new(OutgoingSessionAck { last_retry_time: AtomicI64::new(current_time), - ack: reply_buffer, - ack_size: reply_len, + ack, + ack_len, })); } send_with_fragmentation( |b| send(Some(&session), b), - &mut reply_buffer[..reply_len], + &mut ack[..ack_len], mtu, PACKET_TYPE_ALICE_NOISE_XK_ACK, Some(bob_session_id), @@ -1053,18 +1080,18 @@ impl Context { let session = Arc::new(Session { id: incoming.bob_session_id, application_data, - psk, send_counter: AtomicU64::new(2), // 1 was already used during negotiation receive_window: std::array::from_fn(|_| AtomicU64::new(incoming_counter)), header_protection_cipher: Aes::new(&incoming.header_protection_key), state: RwLock::new(State { + physical_mtu: self.default_physical_mtu.load(Ordering::Relaxed), remote_session_id: Some(incoming.alice_session_id), keys: [ Some(SessionKey::new::(noise_es_ee_se_hk_psk, 1, current_time, 2, true, true)), None, ], current_key: 0, - current_offer: Offer::None, + outgoing_offer: Offer::None, }), defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), }); @@ -1101,81 +1128,74 @@ impl Context { if let Some(session) = session { let state = session.state.read().unwrap(); - if let Some(remote_session_id) = state.remote_session_id { - if let Some(key) = state.keys[key_index].as_ref() { - // Only the current "Alice" accepts rekeys initiated by the current "Bob." These roles - // flip with each rekey event. - if !key.bob { - let mut c = key.get_receive_cipher(incoming_counter); - c.reset_init_gcm(&incoming_message_nonce); - c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]); - let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]); - drop(c); + if let (Some(remote_session_id), Some(key)) = (state.remote_session_id, state.keys[key_index].as_ref()) { + if !key.my_turn_to_rekey && { + let mut c = key.get_receive_cipher(incoming_counter); + c.reset_init_gcm(&incoming_message_nonce); + c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]); + c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]) + } { + let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); + if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) { + let bob_e_secret = P384KeyPair::generate(); + let next_session_key = hmac_sha512_secret( + key.ratchet_key.as_bytes(), + bob_e_secret.agree(&alice_e).ok_or(Error::FailedAuthentication)?.as_bytes(), + ); - if aead_authentication_ok { - let pkt: &RekeyInit = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); - if let Some(alice_e) = P384PublicKey::from_bytes(&pkt.alice_e) { - let bob_e_secret = P384KeyPair::generate(); - let next_session_key = hmac_sha512_secret( - key.ratchet_key.as_bytes(), - bob_e_secret.agree(&alice_e).ok_or(Error::FailedAuthentication)?.as_bytes(), - ); + // Packet fully authenticated + if session.update_receive_window(incoming_counter) { + let mut reply_buf = [0u8; RekeyAck::SIZE]; + let reply: &mut RekeyAck = byte_array_as_proto_buffer_mut(&mut reply_buf).unwrap(); + reply.session_protocol_version = SESSION_PROTOCOL_VERSION; + reply.bob_e = *bob_e_secret.public_key_bytes(); + reply.next_key_fingerprint = SHA384::hash(next_session_key.as_bytes()); - // Packet fully authenticated - if session.update_receive_window(incoming_counter) { - let mut reply_buf = [0u8; RekeyAck::SIZE]; - let reply: &mut RekeyAck = byte_array_as_proto_buffer_mut(&mut reply_buf).unwrap(); - reply.session_protocol_version = SESSION_PROTOCOL_VERSION; - reply.bob_e = *bob_e_secret.public_key_bytes(); - reply.next_key_fingerprint = SHA384::hash(next_session_key.as_bytes()); + let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + set_packet_header( + &mut reply_buf, + 1, + 0, + PACKET_TYPE_REKEY_ACK, + u64::from(remote_session_id), + state.current_key, + counter, + ); - let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); - set_packet_header( - &mut reply_buf, - 1, - 0, - PACKET_TYPE_REKEY_ACK, - u64::from(remote_session_id), - state.current_key, - counter, - ); + let mut c = key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_ACK, counter)); + c.crypt_in_place(&mut reply_buf[RekeyAck::ENC_START..RekeyAck::AUTH_START]); + reply_buf[RekeyAck::AUTH_START..].copy_from_slice(&c.finish_encrypt()); + drop(c); - let mut c = key.get_send_cipher(counter)?; - c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_REKEY_ACK, counter)); - c.crypt_in_place(&mut reply_buf[RekeyAck::ENC_START..RekeyAck::AUTH_START]); - reply_buf[RekeyAck::AUTH_START..].copy_from_slice(&c.finish_encrypt()); - drop(c); + session + .header_protection_cipher + .encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(Some(&session), &mut reply_buf); - session - .header_protection_cipher - .encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - send(Some(&session), &mut reply_buf); + // The new "Bob" doesn't know yet if Alice has received the new key, so the + // new key is recorded as the "alt" (key_index ^ 1) but the current key is + // not advanced yet. This happens automatically the first time we receive a + // valid packet with the new key. + let next_ratchet_count = key.ratchet_count + 1; + drop(state); + let mut state = session.state.write().unwrap(); + let _ = state.keys[key_index ^ 1].replace(SessionKey::new::( + next_session_key, + next_ratchet_count, + current_time, + counter, + false, + false, + )); - // The new "Bob" doesn't know yet if Alice has received the new key, so the - // new key is recorded as the "alt" (key_index ^ 1) but the current key is - // not advanced yet. This happens automatically the first time we receive a - // valid packet with the new key. - let next_ratchet_count = key.ratchet_count + 1; - drop(state); - let mut state = session.state.write().unwrap(); - let _ = state.keys[key_index ^ 1].replace(SessionKey::new::( - next_session_key, - next_ratchet_count, - current_time, - counter, - false, - false, - )); - - drop(state); - return Ok(ReceiveResult::Ok(Some(session))); - } else { - return Err(Error::OutOfSequence); - } - } + drop(state); + return Ok(ReceiveResult::Ok(Some(session))); + } else { + return Err(Error::OutOfSequence); } - return Err(Error::FailedAuthentication); } + return Err(Error::FailedAuthentication); } } return Err(Error::OutOfSequence); @@ -1194,59 +1214,52 @@ impl Context { if let Some(session) = session { let state = session.state.read().unwrap(); - if let Offer::RekeyInit(alice_e_secret, _) = &state.current_offer { - if let Some(key) = state.keys[key_index].as_ref() { - // Only the current "Bob" initiates rekeys and expects this ACK. - if key.bob { - let mut c = key.get_receive_cipher(incoming_counter); - c.reset_init_gcm(&incoming_message_nonce); - c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]); - let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]); - drop(c); + if let (Offer::RekeyInit(alice_e_secret, _), Some(key)) = (&state.outgoing_offer, state.keys[key_index].as_ref()) { + if key.my_turn_to_rekey && { + let mut c = key.get_receive_cipher(incoming_counter); + c.reset_init_gcm(&incoming_message_nonce); + c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]); + c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]) + } { + let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); + if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) { + let next_session_key = hmac_sha512_secret( + key.ratchet_key.as_bytes(), + alice_e_secret.agree(&bob_e).ok_or(Error::FailedAuthentication)?.as_bytes(), + ); - if aead_authentication_ok { - // Packet fully authenticated + if secure_eq(&pkt.next_key_fingerprint, &SHA384::hash(next_session_key.as_bytes())) { + if session.update_receive_window(incoming_counter) { + // The new "Alice" knows Bob has the key since this is an ACK, so she can go + // ahead and set current_key to the new key. Then when she sends something + // to Bob the other side will automatically advance to the new key as well. + let next_ratchet_count = key.ratchet_count + 1; + drop(state); + let next_key_index = key_index ^ 1; + let mut state = session.state.write().unwrap(); + let _ = state.keys[next_key_index].replace(SessionKey::new::( + next_session_key, + next_ratchet_count, + current_time, + session.send_counter.load(Ordering::Relaxed), + true, + true, + )); + state.current_key = next_key_index; // this is an ACK so it's confirmed + state.outgoing_offer = Offer::None; - let pkt: &RekeyAck = byte_array_as_proto_buffer(&pkt_assembled).unwrap(); - if let Some(bob_e) = P384PublicKey::from_bytes(&pkt.bob_e) { - let next_session_key = hmac_sha512_secret( - key.ratchet_key.as_bytes(), - alice_e_secret.agree(&bob_e).ok_or(Error::FailedAuthentication)?.as_bytes(), - ); - - if secure_eq(&pkt.next_key_fingerprint, &SHA384::hash(next_session_key.as_bytes())) { - if session.update_receive_window(incoming_counter) { - // The new "Alice" knows Bob has the key since this is an ACK, so she can go - // ahead and set current_key to the new key. Then when she sends something - // to Bob the other side will automatically advance to the new key as well. - let next_ratchet_count = key.ratchet_count + 1; - drop(state); - let next_key_index = key_index ^ 1; - let mut state = session.state.write().unwrap(); - let _ = state.keys[next_key_index].replace(SessionKey::new::( - next_session_key, - next_ratchet_count, - current_time, - session.send_counter.load(Ordering::Relaxed), - true, - true, - )); - state.current_key = next_key_index; // this is an ACK so it's confirmed - state.current_offer = Offer::None; - - drop(state); - return Ok(ReceiveResult::Ok(Some(session))); - } else { - return Err(Error::OutOfSequence); - } - } + drop(state); + return Ok(ReceiveResult::Ok(Some(session))); + } else { + return Err(Error::OutOfSequence); } } - return Err(Error::FailedAuthentication); } } + return Err(Error::FailedAuthentication); + } else { + return Err(Error::OutOfSequence); } - return Err(Error::OutOfSequence); } else { return Err(Error::UnknownLocalSessionId); } @@ -1270,51 +1283,49 @@ impl Session { pub fn send(&self, mut send: SendFunction, mtu_sized_buffer: &mut [u8], mut data: &[u8]) -> Result<(), Error> { debug_assert!(mtu_sized_buffer.len() >= MIN_TRANSPORT_MTU); let state = self.state.read().unwrap(); - if let Some(remote_session_id) = state.remote_session_id { - if let Some(session_key) = state.keys[state.current_key].as_ref() { - let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) { + let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); - let mut c = session_key.get_send_cipher(counter)?; - c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter)); + let mut c = session_key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter)); - let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize; - let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE; - let last_fragment_no = fragment_count - 1; + let fragment_count = (((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize; + let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE; + let last_fragment_no = fragment_count - 1; - for fragment_no in 0..fragment_count { - let chunk_size = fragment_max_chunk_size.min(data.len()); - let mut fragment_size = chunk_size + HEADER_SIZE; + for fragment_no in 0..fragment_count { + let chunk_size = fragment_max_chunk_size.min(data.len()); + let mut fragment_size = chunk_size + HEADER_SIZE; - set_packet_header( - mtu_sized_buffer, - fragment_count, - fragment_no, - PACKET_TYPE_DATA, - u64::from(remote_session_id), - state.current_key, - counter, - ); + set_packet_header( + mtu_sized_buffer, + fragment_count, + fragment_no, + PACKET_TYPE_DATA, + u64::from(remote_session_id), + state.current_key, + counter, + ); - c.crypt(&data[..chunk_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]); - data = &data[chunk_size..]; + c.crypt(&data[..chunk_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]); + data = &data[chunk_size..]; - if fragment_no == last_fragment_no { - debug_assert!(data.is_empty()); - let tagged_fragment_size = fragment_size + AES_GCM_TAG_SIZE; - mtu_sized_buffer[fragment_size..tagged_fragment_size].copy_from_slice(&c.finish_encrypt()); - fragment_size = tagged_fragment_size; - } - - self.header_protection_cipher - .encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - send(&mut mtu_sized_buffer[..fragment_size]); + if fragment_no == last_fragment_no { + debug_assert!(data.is_empty()); + let tagged_fragment_size = fragment_size + AES_GCM_TAG_SIZE; + mtu_sized_buffer[fragment_size..tagged_fragment_size].copy_from_slice(&c.finish_encrypt()); + fragment_size = tagged_fragment_size; } - debug_assert!(data.is_empty()); - drop(c); - - return Ok(()); + self.header_protection_cipher + .encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(&mut mtu_sized_buffer[..fragment_size]); } + debug_assert!(data.is_empty()); + + drop(c); + + return Ok(()); } return Err(Error::SessionNotEstablished); } @@ -1322,23 +1333,26 @@ impl Session { /// Send a NOP to the other side (e.g. for keep alive). pub fn send_nop(&self, mut send: SendFunction) -> Result<(), Error> { let state = self.state.read().unwrap(); - if let Some(remote_session_id) = state.remote_session_id { - if let Some(session_key) = state.keys[state.current_key].as_ref() { - let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); - let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; - let mut c = session_key.get_send_cipher(counter)?; - c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter)); - nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt()); - drop(c); - set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter); - self.header_protection_cipher - .encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); - send(&mut nop); - } + if let (Some(remote_session_id), Some(session_key)) = (state.remote_session_id, state.keys[state.current_key].as_ref()) { + let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get(); + let mut nop = [0u8; HEADER_SIZE + AES_GCM_TAG_SIZE]; + let mut c = session_key.get_send_cipher(counter)?; + c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_NOP, counter)); + nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt()); + drop(c); + set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter); + self.header_protection_cipher + .encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); + send(&mut nop); } return Err(Error::SessionNotEstablished); } + /// Set the current physical MTU that this session should use to send packets. + pub fn set_physical_mtu(&self, mtu: usize) { + self.state.write().unwrap().physical_mtu = mtu; + } + /// Check whether this session is established. pub fn established(&self) -> bool { let state = self.state.read().unwrap(); @@ -1396,7 +1410,7 @@ impl Session { send(&mut rekey_buf); drop(state); - self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, current_time); + self.state.write().unwrap().outgoing_offer = Offer::RekeyInit(rekey_e, current_time); } } } @@ -1556,8 +1570,6 @@ impl SessionKey { let send_cipher_pool = std::array::from_fn(|_| Mutex::new(AesGcm::new(&send_key))); Self { ratchet_key: kbkdf::(key.as_bytes()), - //receive_key, - //send_key, receive_cipher_pool, send_cipher_pool, rekey_at_time: current_time @@ -1568,31 +1580,23 @@ impl SessionKey { rekey_at_counter: current_counter.checked_add(Application::REKEY_AFTER_USES).unwrap(), expire_at_counter: current_counter.checked_add(Application::EXPIRE_AFTER_USES).unwrap(), ratchet_count, - bob, + my_turn_to_rekey: bob, confirmed, } } + #[inline(always)] fn get_send_cipher<'a>(&'a self, counter: u64) -> Result>, Error> { if counter < self.expire_at_counter { - for i in 0..AES_POOL_SIZE { - if let Ok(p) = self.send_cipher_pool[(counter as usize).wrapping_add(i) % AES_POOL_SIZE].try_lock() { - return Ok(p); - } - } - Ok(self.send_cipher_pool[(counter as usize) % AES_POOL_SIZE].lock().unwrap()) + Ok(self.send_cipher_pool[(counter as usize) % GCM_CIPHER_POOL_SIZE].lock().unwrap()) } else { Err(Error::MaxKeyLifetimeExceeded) } } + #[inline(always)] fn get_receive_cipher<'a>(&'a self, counter: u64) -> MutexGuard<'a, AesGcm> { - for i in 0..AES_POOL_SIZE { - if let Ok(p) = self.receive_cipher_pool[(counter as usize).wrapping_add(i) % AES_POOL_SIZE].try_lock() { - return p; - } - } - self.receive_cipher_pool[(counter as usize) % AES_POOL_SIZE].lock().unwrap() + self.receive_cipher_pool[(counter as usize) % GCM_CIPHER_POOL_SIZE].lock().unwrap() } }