Remove potential lock bottleneck from Peer, refactor Pool.

Author: Adam Ierymenko
Date:   2021-08-05 15:49:31 -04:00
Commit: 8681f61de3 (parent 8376cfe15d)
GPG key ID: C8877CF2D7A5D7F3 (no known key found for this signature in database)

4 changed files with 199 additions and 228 deletions

src/util/pool.rs

@@ -4,17 +4,18 @@ use std::sync::{Arc, Weak};
 use parking_lot::Mutex;
 
-/// Trait for objects that can be used with Pool.
-pub trait Reusable: Default + Sized {
-    fn reset(&mut self);
+/// Trait for objects that create and reset poolable objects.
+pub trait PoolFactory<O> {
+    fn create(&self) -> O;
+    fn reset(&self, obj: &mut O);
 }
 
-struct PoolEntry<O: Reusable> {
+struct PoolEntry<O, F: PoolFactory<O>> {
     obj: O,
-    return_pool: Weak<PoolInner<O>>,
+    return_pool: Weak<PoolInner<O, F>>,
 }
 
-type PoolInner<O> = Mutex<Vec<*mut PoolEntry<O>>>;
+struct PoolInner<O, F: PoolFactory<O>>(F, Mutex<Vec<*mut PoolEntry<O, F>>>);
 
 /// Container for pooled objects that have been checked out of the pool.
 ///
@@ -26,9 +27,9 @@ type PoolInner<O> = Mutex<Vec<*mut PoolEntry<O>>>;
 /// Note that pooled objects are not clonable. If you want to share them use Rc<>
 /// or Arc<>.
 #[repr(transparent)]
-pub struct Pooled<O: Reusable>(*mut PoolEntry<O>);
+pub struct Pooled<O, F: PoolFactory<O>>(*mut PoolEntry<O, F>);
 
-impl<O: Reusable> Pooled<O> {
+impl<O, F: PoolFactory<O>> Pooled<O, F> {
     /// Get a raw pointer to the object wrapped by this pooled object container.
     /// The returned raw pointer MUST be restored into a Pooled instance with
     /// from_raw() or memory will leak.
@@ -55,7 +56,7 @@ impl<O: Reusable> Pooled<O> {
     }
 }
 
-impl<O: Reusable> Deref for Pooled<O> {
+impl<O, F: PoolFactory<O>> Deref for Pooled<O, F> {
     type Target = O;
 
     #[inline(always)]
@@ -65,7 +66,7 @@ impl<O: Reusable> Deref for Pooled<O> {
     }
 }
 
-impl<O: Reusable> AsRef<O> for Pooled<O> {
+impl<O, F: PoolFactory<O>> AsRef<O> for Pooled<O, F> {
     #[inline(always)]
     fn as_ref(&self) -> &O {
         debug_assert!(!self.0.is_null());
@@ -73,7 +74,7 @@ impl<O: Reusable> AsRef<O> for Pooled<O> {
     }
 }
 
-impl<O: Reusable> DerefMut for Pooled<O> {
+impl<O, F: PoolFactory<O>> DerefMut for Pooled<O, F> {
     #[inline(always)]
     fn deref_mut(&mut self) -> &mut Self::Target {
         debug_assert!(!self.0.is_null());
@@ -81,7 +82,7 @@ impl<O: Reusable> DerefMut for Pooled<O> {
     }
 }
 
-impl<O: Reusable> AsMut<O> for Pooled<O> {
+impl<O, F: PoolFactory<O>> AsMut<O> for Pooled<O, F> {
     #[inline(always)]
     fn as_mut(&mut self) -> &mut O {
         debug_assert!(!self.0.is_null());
@@ -89,34 +90,34 @@ impl<O: Reusable> AsMut<O> for Pooled<O> {
     }
 }
 
-impl<O: Reusable> Drop for Pooled<O> {
+impl<O, F: PoolFactory<O>> Drop for Pooled<O, F> {
     fn drop(&mut self) {
         unsafe {
             Weak::upgrade(&(*self.0).return_pool).map_or_else(|| {
                 drop(Box::from_raw(self.0))
             }, |p| {
-                (*self.0).obj.reset();
-                p.lock().push(self.0)
+                p.0.reset(&mut (*self.0).obj);
+                p.1.lock().push(self.0)
             })
         }
     }
 }
 
 /// An object pool for Reusable objects.
-/// The pool is safe in that checked out objects return automatically when their Pooled
-/// transparent container is dropped, or deallocate if the pool has been dropped.
-pub struct Pool<O: Reusable>(Arc<PoolInner<O>>);
+/// Checked out objects are held by a guard object that returns them when dropped if
+/// the pool still exists or drops them if the pool has itself been dropped.
+pub struct Pool<O, F: PoolFactory<O>>(Arc<PoolInner<O, F>>);
 
-impl<O: Reusable> Pool<O> {
-    pub fn new(initial_stack_capacity: usize) -> Self {
-        Self(Arc::new(Mutex::new(Vec::with_capacity(initial_stack_capacity))))
+impl<O, F: PoolFactory<O>> Pool<O, F> {
+    pub fn new(initial_stack_capacity: usize, factory: F) -> Self {
+        Self(Arc::new(PoolInner::<O, F>(factory, Mutex::new(Vec::with_capacity(initial_stack_capacity)))))
     }
 
     /// Get a pooled object, or allocate one if the pool is empty.
-    pub fn get(&self) -> Pooled<O> {
-        Pooled::<O>(self.0.lock().pop().map_or_else(|| {
-            Box::into_raw(Box::new(PoolEntry::<O> {
-                obj: O::default(),
+    pub fn get(&self) -> Pooled<O, F> {
+        Pooled::<O, F>(self.0.1.lock().pop().map_or_else(|| {
+            Box::into_raw(Box::new(PoolEntry::<O, F> {
+                obj: self.0.0.create(),
                 return_pool: Arc::downgrade(&self.0),
             }))
         }, |obj| {
@@ -128,7 +129,7 @@ impl<O: Reusable> Pool<O> {
     /// Get approximate memory use in bytes (does not include checked out objects).
     #[inline(always)]
     pub fn pool_memory_bytes(&self) -> usize {
-        self.0.lock().len() * (size_of::<PoolEntry<O>>() + size_of::<usize>())
+        self.0.1.lock().len() * (size_of::<PoolEntry<O, F>>() + size_of::<usize>())
     }
 
     /// Dispose of all pooled objects, freeing any memory they use.
@@ -136,7 +137,7 @@ impl<O: Reusable> Pool<O> {
     /// objects will still be returned on drop unless the pool itself is dropped. This can
     /// be done to free some memory if there has been a spike in memory use.
     pub fn purge(&self) {
-        let mut p = self.0.lock();
+        let mut p = self.0.1.lock();
         for obj in p.iter() {
             drop(unsafe { Box::from_raw(*obj) });
         }
@@ -144,15 +145,15 @@ impl<O: Reusable> Pool<O> {
     }
 }
 
-impl<O: Reusable> Drop for Pool<O> {
+impl<O, F: PoolFactory<O>> Drop for Pool<O, F> {
     fn drop(&mut self) {
         self.purge();
     }
 }
 
-unsafe impl<O: Reusable> Sync for Pool<O> {}
-unsafe impl<O: Reusable> Send for Pool<O> {}
+unsafe impl<O, F: PoolFactory<O>> Sync for Pool<O, F> {}
+unsafe impl<O, F: PoolFactory<O>> Send for Pool<O, F> {}
 
 #[cfg(test)]
 mod tests {
@@ -160,24 +161,24 @@ mod tests {
     use std::ops::DerefMut;
     use std::time::Duration;
-    use crate::util::pool::{Reusable, Pool};
+    use crate::util::pool::*;
     use std::sync::Arc;
 
-    struct ReusableTestObject(usize);
-
-    impl Default for ReusableTestObject {
-        fn default() -> Self {
-            Self(0)
-        }
-    }
-
-    impl Reusable for ReusableTestObject {
-        fn reset(&mut self) {}
+    struct TestPoolFactory;
+
+    impl PoolFactory<String> for TestPoolFactory {
+        fn create(&self) -> String {
+            String::new()
+        }
+
+        fn reset(&self, obj: &mut String) {
+            obj.clear();
+        }
     }
 
     #[test]
     fn threaded_pool_use() {
-        let p: Arc<Pool<ReusableTestObject>> = Arc::new(Pool::new(2));
+        let p: Arc<Pool<String, TestPoolFactory>> = Arc::new(Pool::new(2, TestPoolFactory{}));
         let ctr = Arc::new(AtomicUsize::new(0));
         for _ in 0..64 {
             let p2 = p.clone();
@@ -185,10 +186,10 @@ mod tests {
             let _ = std::thread::spawn(move || {
                 for _ in 0..16384 {
                     let mut o1 = p2.get();
-                    o1.deref_mut().0 += 1;
+                    o1.push('a');
                     let mut o2 = p2.get();
                     drop(o1);
-                    o2.deref_mut().0 += 1;
+                    o2.push('b');
                     ctr2.fetch_add(1, Ordering::Relaxed);
                 }
             });
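The factory-based design is the heart of this refactor: because the pool now owns a PoolFactory, pooled objects no longer need Default plus a Reusable impl, and construction can capture configuration or key material (as peer.rs does below with its per-peer AES keys). A minimal sketch against the new API, using a hypothetical factory that pre-sizes Vec<u8> objects and keeps their capacity across reuse:

    use crate::util::pool::{Pool, PoolFactory};

    // Hypothetical factory, for illustration only; not part of this commit.
    struct VecFactory { capacity: usize }

    impl PoolFactory<Vec<u8>> for VecFactory {
        fn create(&self) -> Vec<u8> {
            Vec::with_capacity(self.capacity)
        }

        fn reset(&self, obj: &mut Vec<u8>) {
            obj.clear(); // drop contents but keep the allocation for the next user
        }
    }

    fn example() {
        let pool: Pool<Vec<u8>, VecFactory> = Pool::new(16, VecFactory { capacity: 2048 });
        let mut buf = pool.get();       // pops a cached entry or calls create()
        buf.extend_from_slice(b"data"); // DerefMut exposes &mut Vec<u8>
    }                                   // dropping buf runs reset() and returns it to the pool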

src/vl1/buffer.rs

@@ -1,7 +1,6 @@
 use std::mem::size_of;
 use std::io::Write;
 
-use crate::util::pool::Reusable;
+use crate::util::pool::PoolFactory;
 
 const OVERFLOW_ERR_MSG: &'static str = "overflow";
@@ -23,13 +22,6 @@ impl<const L: usize> Default for Buffer<L> {
     }
 }
 
-impl<const L: usize> Reusable for Buffer<L> {
-    #[inline(always)]
-    fn reset(&mut self) {
-        self.clear();
-    }
-}
-
 impl<const L: usize> Buffer<L> {
     #[inline(always)]
     pub fn new() -> Self {
@@ -62,7 +54,7 @@ impl<const L: usize> Buffer<L> {
     /// Get all bytes after a given position.
     #[inline(always)]
-    pub fn as_bytes_after(&self, start: usize) -> std::io::Result<&[u8]> {
+    pub fn as_bytes_starting_at(&self, start: usize) -> std::io::Result<&[u8]> {
         if start <= self.0 {
             Ok(&self.1[start..])
         } else {
@@ -388,3 +380,17 @@ impl<const L: usize> AsMut<[u8]> for Buffer<L> {
         self.as_bytes_mut()
     }
 }
+
+pub struct PooledBufferFactory<const L: usize>;
+
+impl<const L: usize> PoolFactory<Buffer<L>> for PooledBufferFactory<L> {
+    #[inline(always)]
+    fn create(&self) -> Buffer<L> {
+        Buffer::new()
+    }
+
+    #[inline(always)]
+    fn reset(&self, obj: &mut Buffer<L>) {
+        obj.clear();
+    }
+}
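With the Reusable impl deleted, Buffer's pooling hooks live entirely in this zero-sized PooledBufferFactory. A sketch of wiring it up, with 2048 standing in for a real packet size constant:

    use crate::util::pool::Pool;
    use crate::vl1::buffer::{Buffer, PooledBufferFactory};

    fn example() {
        // create() hands out an empty Buffer<2048>; reset() clears buffers as they return.
        let pool: Pool<Buffer<2048>, PooledBufferFactory<2048>> = Pool::new(64, PooledBufferFactory);
        let mut buf = pool.get();
        let _ = buf.append_bytes(b"hello"); // Buffer's normal write API, via DerefMut
    }                                       // buf is cleared and pushed back onto the pool here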

src/vl1/node.rs

@@ -10,7 +10,7 @@ use crate::error::InvalidParameterError;
 use crate::util::gate::IntervalGate;
 use crate::util::pool::{Pool, Pooled};
 use crate::vl1::{Address, Endpoint, Identity, Locator};
-use crate::vl1::buffer::Buffer;
+use crate::vl1::buffer::{Buffer, PooledBufferFactory};
 use crate::vl1::constants::PACKET_SIZE_MAX;
 use crate::vl1::path::Path;
 use crate::vl1::peer::Peer;
@@ -18,9 +18,10 @@ use crate::vl1::protocol::*;
 use crate::vl1::whois::WhoisQueue;
 
 /// Standard packet buffer type including pool container.
-pub type PacketBuffer = Pooled<Buffer<{ PACKET_SIZE_MAX }>>;
+pub type PacketBuffer = Pooled<Buffer<{ PACKET_SIZE_MAX }>, PooledBufferFactory<{ PACKET_SIZE_MAX }>>;
 
 /// Callback interface and call context for calls to the node (for VL1).
+///
 /// Every non-trivial call takes a reference to this, which it passes all the way through
 /// the call stack. This can be used to call back into the caller to send packets, get or
 /// store data, report events, etc.
@@ -124,7 +125,7 @@ pub struct Node {
     paths: DashMap<Endpoint, Arc<Path>>,
     peers: DashMap<Address, Arc<Peer>>,
     whois: WhoisQueue,
-    buffer_pool: Pool<Buffer<{ PACKET_SIZE_MAX }>>,
+    buffer_pool: Pool<Buffer<{ PACKET_SIZE_MAX }>, PooledBufferFactory<{ PACKET_SIZE_MAX }>>,
     secure_prng: SecureRandom,
 }
@@ -161,7 +162,7 @@ impl Node {
             paths: DashMap::new(),
             peers: DashMap::new(),
             whois: WhoisQueue::new(),
-            buffer_pool: Pool::new(64),
+            buffer_pool: Pool::new(64, PooledBufferFactory),
             secure_prng: SecureRandom::get(),
         })
     }
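Because the alias bakes in both type parameters, code that passes PacketBuffer around is untouched by the extra factory parameter. A hypothetical helper (not part of this diff) showing how Node's pool would hand buffers out under the new signature:

    impl Node {
        // Hypothetical convenience method, for illustration only.
        pub(crate) fn get_packet_buffer(&self) -> PacketBuffer {
            // Pooled<Buffer<PACKET_SIZE_MAX>, PooledBufferFactory<PACKET_SIZE_MAX>> == PacketBuffer
            self.buffer_pool.get()
        }
    }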

src/vl1/peer.rs

@@ -1,4 +1,5 @@
 use std::sync::Arc;
+use std::sync::atomic::{AtomicI64, AtomicU64, AtomicU8, Ordering};
 
 use parking_lot::Mutex;
 use aes_gmac_siv::AesGmacSiv;
@@ -10,24 +11,40 @@ use crate::crypto::poly1305::Poly1305;
 use crate::crypto::random::next_u64_secure;
 use crate::crypto::salsa::Salsa;
 use crate::crypto::secret::Secret;
+use crate::util::pool::{Pool, PoolFactory};
 use crate::vl1::{Identity, Path};
 use crate::vl1::buffer::Buffer;
 use crate::vl1::constants::*;
 use crate::vl1::node::*;
 use crate::vl1::protocol::*;
 
-struct PeerSecrets {
-    // Time secret was created in ticks or -1 for static secrets.
+struct AesGmacSivPoolFactory(Secret<48>, Secret<48>);
+
+impl PoolFactory<AesGmacSiv> for AesGmacSivPoolFactory {
+    #[inline(always)]
+    fn create(&self) -> AesGmacSiv {
+        AesGmacSiv::new(self.0.as_ref(), self.1.as_ref())
+    }
+
+    #[inline(always)]
+    fn reset(&self, obj: &mut AesGmacSiv) {
+        obj.reset();
+    }
+}
+
+struct PeerSecret {
+    // Time secret was created in ticks for ephemeral secrets, or -1 for static secrets.
     create_time_ticks: i64,
 
-    // Number of time secret has been used to encrypt something during this session.
-    encrypt_count: u64,
+    // Number of times secret has been used to encrypt something during this session.
+    encrypt_count: AtomicU64,
 
     // Raw secret itself.
     secret: Secret<48>,
 
-    // Reusable AES-GMAC-SIV initialized with secret.
-    aes: AesGmacSiv,
+    // Reusable AES-GMAC-SIV ciphers initialized with secret.
+    // These can't be used concurrently so they're pooled to allow multithreaded use.
+    aes: Pool<AesGmacSiv, AesGmacSivPoolFactory>,
 }
 
 struct EphemeralKeyPair {
@@ -44,50 +61,6 @@ struct EphemeralKeyPair {
     p521: P521KeyPair,
 }
 
-struct TxState {
-    // Time we last sent something to this peer.
-    last_send_time_ticks: i64,
-
-    // Outgoing packet IV counter, starts at a random position.
-    packet_iv_counter: u64,
-
-    // Total bytes sent to this peer during this session.
-    total_bytes: u64,
-
-    // "Eternal" static secret created via identity agreement.
-    static_secret: PeerSecrets,
-
-    // The most recently negotiated ephemeral secret.
-    ephemeral_secret: Option<PeerSecrets>,
-
-    // The current ephemeral key pair we will share with HELLO.
-    ephemeral_pair: Option<EphemeralKeyPair>,
-
-    // Paths to this peer sorted in ascending order of path quality.
-    paths: Vec<Arc<Path>>,
-}
-
-struct RxState {
-    // Time we last received something (authenticated) from this peer.
-    last_receive_time_ticks: i64,
-
-    // Total bytes received from this peer during this session.
-    total_bytes: u64,
-
-    // "Eternal" static secret created via identity agreement.
-    static_secret: PeerSecrets,
-
-    // The most recently negotiated ephemeral secret.
-    ephemeral_secret: Option<PeerSecrets>,
-
-    // Remote version as major, minor, revision, build in most-to-least-significant 16-bit chunks.
-    // This is the user-facing software version and is zero if not yet known.
-    remote_version: u64,
-
-    // Remote protocol version or zero if not yet known.
-    remote_protocol_version: u8,
-}
-
 /// A remote peer known to this node.
 /// Sending-related and receiving-related fields are locked separately since concurrent
 /// send/receive is not uncommon.
@@ -96,7 +69,7 @@ pub struct Peer {
     identity: Identity,
 
     // Static shared secret computed from agreement with identity.
-    static_secret: Secret<48>,
+    static_secret: PeerSecret,
 
     // Derived static secret used to encrypt the dictionary part of HELLO.
     static_secret_hello_dictionary_encrypt: Secret<48>,
@@ -104,11 +77,27 @@ pub struct Peer {
     // Derived static secret used to add full HMAC-SHA384 to packets, currently just HELLO.
     static_secret_packet_hmac: Secret<48>,
 
-    // State used primarily when sending to this peer.
-    tx: Mutex<TxState>,
+    // Latest ephemeral secret acknowledged with OK(HELLO).
+    ephemeral_secret: Mutex<Option<Arc<PeerSecret>>>,
 
-    // State used primarily when receiving from this peer.
-    rx: Mutex<RxState>,
+    // Either None or the current ephemeral key pair whose public keys are on offer.
+    ephemeral_pair: Mutex<Option<EphemeralKeyPair>>,
+
+    // Statistics
+    last_send_time_ticks: AtomicI64,
+    last_receive_time_ticks: AtomicI64,
+    total_bytes_sent: AtomicU64,
+    total_bytes_received: AtomicU64,
+
+    // Counter for assigning packet IV's a.k.a. PacketIDs.
+    packet_iv_counter: AtomicU64,
+
+    // Remote peer version information.
+    remote_version: AtomicU64,
+    remote_protocol_version: AtomicU8,
+
+    // Paths sorted in ascending order of quality / preference.
+    paths: Mutex<Vec<Arc<Path>>>,
 }
 
 /// Derive per-packet key for Salsa20/12 encryption (and Poly1305 authentication).
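The lock-bottleneck fix is visible in the field list above: everything the hot send and receive paths touch per packet is now either an atomic or a pooled cipher, so the remaining mutexes (ephemeral key state, path list) are only taken on slow paths. A sketch of the access pattern this enables, with hypothetical helper names not in this commit:

    impl Peer {
        // Hypothetical helpers, for illustration: no mutex is held for any of this.
        fn record_send(&self, time_ticks: i64, bytes: u64) {
            self.last_send_time_ticks.store(time_ticks, Ordering::Relaxed);
            let _ = self.total_bytes_sent.fetch_add(bytes, Ordering::Relaxed);
        }

        fn next_packet_iv(&self) -> u64 {
            // A single atomic read-modify-write issues unique packet IVs concurrently.
            self.packet_iv_counter.fetch_add(1, Ordering::Relaxed)
        }
    }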
@@ -142,42 +131,31 @@ impl Peer {
     /// fatal error occurs performing key agreement between the two identities.
     pub(crate) fn new(this_node_identity: &Identity, id: Identity) -> Option<Peer> {
         this_node_identity.agree(&id).map(|static_secret| {
-            let aes_k0 = zt_kbkdf_hmac_sha384(&static_secret.0, KBKDF_KEY_USAGE_LABEL_AES_GMAC_SIV_K0, 0, 0);
-            let aes_k1 = zt_kbkdf_hmac_sha384(&static_secret.0, KBKDF_KEY_USAGE_LABEL_AES_GMAC_SIV_K1, 0, 0);
+            let aes_factory = AesGmacSivPoolFactory(
+                zt_kbkdf_hmac_sha384(&static_secret.0, KBKDF_KEY_USAGE_LABEL_AES_GMAC_SIV_K0, 0, 0),
+                zt_kbkdf_hmac_sha384(&static_secret.0, KBKDF_KEY_USAGE_LABEL_AES_GMAC_SIV_K1, 0, 0));
             let static_secret_hello_dictionary_encrypt = zt_kbkdf_hmac_sha384(&static_secret.0, KBKDF_KEY_USAGE_LABEL_HELLO_DICTIONARY_ENCRYPT, 0, 0);
             let static_secret_packet_hmac = zt_kbkdf_hmac_sha384(&static_secret.0, KBKDF_KEY_USAGE_LABEL_PACKET_HMAC, 0, 0);
             Peer {
                 identity: id,
-                static_secret: static_secret.clone(),
+                static_secret: PeerSecret {
+                    create_time_ticks: -1,
+                    encrypt_count: AtomicU64::new(0),
+                    secret: static_secret,
+                    aes: Pool::new(4, aes_factory),
+                },
                 static_secret_hello_dictionary_encrypt,
                 static_secret_packet_hmac,
-                tx: Mutex::new(TxState {
-                    last_send_time_ticks: 0,
-                    packet_iv_counter: next_u64_secure(),
-                    total_bytes: 0,
-                    static_secret: PeerSecrets {
-                        create_time_ticks: -1,
-                        encrypt_count: 0,
-                        secret: static_secret.clone(),
-                        aes: AesGmacSiv::new(&aes_k0.0, &aes_k1.0),
-                    },
-                    ephemeral_secret: None,
-                    paths: Vec::with_capacity(4),
-                    ephemeral_pair: None,
-                }),
-                rx: Mutex::new(RxState {
-                    last_receive_time_ticks: 0,
-                    total_bytes: 0,
-                    static_secret: PeerSecrets {
-                        create_time_ticks: -1,
-                        encrypt_count: 0,
-                        secret: static_secret,
-                        aes: AesGmacSiv::new(&aes_k0.0, &aes_k1.0),
-                    },
-                    ephemeral_secret: None,
-                    remote_version: 0,
-                    remote_protocol_version: 0,
-                }),
+                ephemeral_secret: Mutex::new(None),
+                ephemeral_pair: Mutex::new(None),
+                last_send_time_ticks: AtomicI64::new(0),
+                last_receive_time_ticks: AtomicI64::new(0),
+                total_bytes_sent: AtomicU64::new(0),
+                total_bytes_received: AtomicU64::new(0),
+                packet_iv_counter: AtomicU64::new(next_u64_secure()),
+                remote_version: AtomicU64::new(0),
+                remote_protocol_version: AtomicU8::new(0),
+                paths: Mutex::new(Vec::new()),
             }
         })
     }
@@ -185,30 +163,21 @@ impl Peer {
     /// Receive, decrypt, authenticate, and process an incoming packet from this peer.
     /// If the packet comes in multiple fragments, the fragments slice should contain all
     /// those fragments after the main packet header and first chunk.
-    #[inline(always)]
     pub(crate) fn receive<CI: VL1CallerInterface, PH: VL1PacketHandler>(&self, node: &Node, ci: &CI, ph: &PH, time_ticks: i64, source_path: &Arc<Path>, header: &PacketHeader, packet: &Buffer<{ PACKET_SIZE_MAX }>, fragments: &[Option<PacketBuffer>]) {
-        let packet_frag0_payload_bytes = packet.as_bytes_after(PACKET_VERB_INDEX).unwrap_or(&[]);
-        if !packet_frag0_payload_bytes.is_empty() {
+        let _ = packet.as_bytes_starting_at(PACKET_VERB_INDEX).map(|packet_frag0_payload_bytes| {
             let mut payload: Buffer<{ PACKET_SIZE_MAX }> = Buffer::new();
-            let mut rx = self.rx.lock();
-
-            // When handling incoming packets we try any current ephemeral secret first, and if that
-            // fails we fall back to the static secret. If decryption with an ephemeral secret succeeds
-            // the forward secrecy flag in the receive path is set.
-            let forward_secrecy = {
-                let mut secret = if rx.ephemeral_secret.is_some() { rx.ephemeral_secret.as_mut().unwrap() } else { &mut rx.static_secret };
-                loop {
-                    match header.cipher() {
+            let mut forward_secrecy = true;
+            let cipher = header.cipher();
+            let ephemeral_secret = self.ephemeral_secret.lock().clone();
+            for secret in [ephemeral_secret.as_ref().map_or(&self.static_secret, |s| s.as_ref()), &self.static_secret] {
+                match cipher {
                     CIPHER_NOCRYPT_POLY1305 => {
+                        // Only HELLO is allowed in the clear (but still authenticated).
                         if (packet_frag0_payload_bytes[0] & VERB_MASK) == VERB_VL1_HELLO {
                             let _ = payload.append_bytes(packet_frag0_payload_bytes);
                             for f in fragments.iter() {
-                                let _ = f.as_ref().map(|f| {
-                                    let _ = f.as_bytes_after(FRAGMENT_HEADER_SIZE).map(|f| {
-                                        let _ = payload.append_bytes(f);
-                                    });
-                                });
+                                let _ = f.as_ref().map(|f| f.as_bytes_starting_at(FRAGMENT_HEADER_SIZE).map(|f| payload.append_bytes(f)));
                             }
 
                             // FIPS note: for FIPS purposes the HMAC-SHA384 tag at the end of V2 HELLOs
@@ -218,17 +187,18 @@ impl Peer {
                             let mut poly1305_key = [0_u8; 32];
                             salsa.crypt_in_place(&mut poly1305_key);
                             let mut poly = Poly1305::new(&poly1305_key).unwrap();
-                            poly.update(packet_frag0_payload_bytes);
+                            poly.update(payload.as_bytes());
                             if poly.finish()[0..8].eq(&header.message_auth) {
                                 break;
                             }
+                        } else {
+                            // Only HELLO is permitted without payload encryption. Drop other packet types if sent this way.
+                            return;
                         }
                     }
 
                     CIPHER_SALSA2012_POLY1305 => {
+                        // FIPS note: support for this mode would have to be disabled in FIPS compliant
+                        // modes of operation.
                         let key = salsa_derive_per_packet_key(&secret.secret, header, payload.len());
                         let mut salsa = Salsa::new(&key.0[0..32], header.id_bytes(), true).unwrap();
                         let mut poly1305_key = [0_u8; 32];
@@ -238,12 +208,10 @@ impl Peer {
                         poly.update(packet_frag0_payload_bytes);
                         let _ = payload.append_and_init_bytes(packet_frag0_payload_bytes.len(), |b| salsa.crypt(packet_frag0_payload_bytes, b));
                         for f in fragments.iter() {
-                            let _ = f.as_ref().map(|f| {
-                                let _ = f.as_bytes_after(FRAGMENT_HEADER_SIZE).map(|f| {
-                                    poly.update(f);
-                                    let _ = payload.append_and_init_bytes(f.len(), |b| salsa.crypt(f, b));
-                                });
-                            });
+                            let _ = f.as_ref().map(|f| f.as_bytes_starting_at(FRAGMENT_HEADER_SIZE).map(|f| {
+                                poly.update(f);
+                                let _ = payload.append_and_init_bytes(f.len(), |b| salsa.crypt(f, b));
+                            }));
                         }
 
                         if poly.finish()[0..8].eq(&header.message_auth) {
@@ -252,20 +220,16 @@ impl Peer {
                     }
 
                     CIPHER_AES_GMAC_SIV => {
-                        secret.aes.reset();
-                        secret.aes.decrypt_init(&header.aes_gmac_siv_tag());
-                        secret.aes.decrypt_set_aad(&header.aad_bytes());
-                        let _ = payload.append_and_init_bytes(packet_frag0_payload_bytes.len(), |b| secret.aes.decrypt(packet_frag0_payload_bytes, b));
+                        let mut aes = secret.aes.get();
+                        aes.decrypt_init(&header.aes_gmac_siv_tag());
+                        aes.decrypt_set_aad(&header.aad_bytes());
+                        let _ = payload.append_and_init_bytes(packet_frag0_payload_bytes.len(), |b| aes.decrypt(packet_frag0_payload_bytes, b));
                         for f in fragments.iter() {
-                            let _ = f.as_ref().map(|f| {
-                                let _ = f.as_bytes_after(FRAGMENT_HEADER_SIZE).map(|f| {
-                                    let _ = payload.append_and_init_bytes(f.len(), |b| secret.aes.decrypt(f, b));
-                                });
-                            });
+                            let _ = f.as_ref().map(|f| f.as_bytes_starting_at(FRAGMENT_HEADER_SIZE).map(|f| payload.append_and_init_bytes(f.len(), |b| aes.decrypt(f, b))));
                         }
 
-                        if secret.aes.decrypt_finish() {
+                        if aes.decrypt_finish() {
                             break;
                         }
                     }
@@ -273,28 +237,27 @@ impl Peer {
                     _ => {}
                 }
 
-                if (secret as *const PeerSecrets) != (&rx.static_secret as *const PeerSecrets) {
-                    payload.clear();
-                    secret = &mut rx.static_secret;
-                } else {
-                    // Both ephemeral (if any) and static secret have failed, drop packet.
+                if (secret as *const PeerSecret) == (&self.static_secret as *const PeerSecret) {
+                    // If the static secret failed to authenticate it means we either didn't have an
+                    // ephemeral key or the ephemeral also failed (as it's tried first).
                     return;
+                } else {
+                    // If ephemeral failed, static secret will be tried. Set forward secrecy to false.
+                    forward_secrecy = false;
+                    payload.clear();
                 }
             }
-
-            (secret as *const PeerSecrets) != (&(rx.static_secret) as *const PeerSecrets)
-            };
+            drop(ephemeral_secret);
 
-            // If we make it here we've successfully decrypted and authenticated the packet.
-            rx.last_receive_time_ticks = time_ticks;
-            rx.total_bytes += payload.len() as u64;
-            // Unlock rx state mutex.
-            drop(rx);
+            // If decryption and authentication succeeded, the code above will break out of the
+            // for loop and end up here. Otherwise it returns from the whole function.
+            self.last_receive_time_ticks.store(time_ticks, Ordering::Relaxed);
+            let _ = self.total_bytes_received.fetch_add((payload.len() + PACKET_HEADER_SIZE) as u64, Ordering::Relaxed);
 
             let _ = payload.u8_at(0).map(|verb| {
                 // For performance reasons we let VL2 handle packets first. It returns false
-                // if it didn't pick up anything.
+                // if it didn't handle the packet, in which case it's handled at VL1.
                 if !ph.handle_packet(self, source_path, forward_secrecy, verb, &payload) {
                     match verb {
                         VERB_VL1_NOP => {}
@@ -311,13 +274,13 @@ impl Peer {
                     }
                 }
             });
-        }
+        });
     }
 
     /// Get the remote version of this peer: major, minor, revision, and build.
     /// Returns None if it's not yet known.
     pub fn version(&self) -> Option<[u16; 4]> {
-        let rv = self.rx.lock().remote_version;
+        let rv = self.remote_version.load(Ordering::Relaxed);
         if rv != 0 {
             Some([(rv >> 48) as u16, (rv >> 32) as u16, (rv >> 16) as u16, rv as u16])
         } else {
@@ -327,7 +290,7 @@ impl Peer {
     /// Get the remote protocol version of this peer or None if not yet known.
     pub fn protocol_version(&self) -> Option<u8> {
-        let pv = self.rx.lock().remote_protocol_version;
+        let pv = self.remote_protocol_version.load(Ordering::Relaxed);
         if pv != 0 {
             Some(pv)
         } else {
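For reference, the version() accessor above unpacks the 16-bits-per-field encoding stored in remote_version (major, minor, revision, build, most to least significant). A sketch of the packing direction, assumed as the inverse of those shifts and not part of this commit:

    // Assumed inverse of the shifts in version(); for illustration only.
    fn pack_version(major: u16, minor: u16, revision: u16, build: u16) -> u64 {
        ((major as u64) << 48) | ((minor as u64) << 32) | ((revision as u64) << 16) | (build as u64)
    }
    // pack_version(1, 8, 2, 0) stored in remote_version makes version() return Some([1, 8, 2, 0]).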