diff --git a/crypto/src/aes_gmac_siv_openssl.rs b/crypto/src/aes_gmac_siv_openssl.rs index f2c81fabf..a5399aa5e 100644 --- a/crypto/src/aes_gmac_siv_openssl.rs +++ b/crypto/src/aes_gmac_siv_openssl.rs @@ -9,69 +9,95 @@ use crate::{cipher_ctx::CipherCtx, ZEROES}; pub struct AesGmacSiv { tag: [u8; 16], tmp: [u8; 16], - k0: Vec, - k1: Vec, - ctr: Option, - gmac: Option, + ecb_enc: CipherCtx, + ecb_dec: CipherCtx, + ctr: CipherCtx, + gmac: CipherCtx, } impl AesGmacSiv { /// Create a new keyed instance of AES-GMAC-SIV /// The key may be of size 16, 24, or 32 bytes (128, 192, or 256 bits). Any other size will panic. pub fn new(k0: &[u8], k1: &[u8]) -> Self { - if k0.len() != 32 && k0.len() != 24 && k0.len() != 16 { - panic!("AES supports 128, 192, or 256 bits keys"); + let gmac = CipherCtx::new().unwrap(); + unsafe { + let t = match k0.len() { + 16 => ffi::EVP_aes_128_gcm(), + 24 => ffi::EVP_aes_192_gcm(), + 32 => ffi::EVP_aes_256_gcm(), + _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), + }; + gmac.cipher_init::(t, k0.as_ptr(), ptr::null_mut()).unwrap(); } - if k1.len() != k0.len() { - panic!("k0 and k1 must be of the same size"); + let ctr = CipherCtx::new().unwrap(); + unsafe { + let t = match k1.len() { + 16 => ffi::EVP_aes_128_ctr(), + 24 => ffi::EVP_aes_192_ctr(), + 32 => ffi::EVP_aes_256_ctr(), + _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), + }; + ctr.cipher_init::(t, k1.as_ptr(), ptr::null_mut()).unwrap(); } + let ecb_enc = CipherCtx::new().unwrap(); + unsafe { + let t = match k1.len() { + 16 => ffi::EVP_aes_128_ecb(), + 24 => ffi::EVP_aes_192_ecb(), + 32 => ffi::EVP_aes_256_ecb(), + _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), + }; + ecb_enc.cipher_init::(t, k1.as_ptr(), ptr::null_mut()).unwrap(); + ffi::EVP_CIPHER_CTX_set_padding(ecb_enc.as_ptr(), 0); + } + let ecb_dec = CipherCtx::new().unwrap(); + unsafe { + let t = match k1.len() { + 16 => ffi::EVP_aes_128_ecb(), + 24 => ffi::EVP_aes_192_ecb(), + 32 => ffi::EVP_aes_256_ecb(), + _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), + }; + ecb_dec.cipher_init::(t, k1.as_ptr(), ptr::null_mut()).unwrap(); + ffi::EVP_CIPHER_CTX_set_padding(ecb_dec.as_ptr(), 0); + } + AesGmacSiv { tag: [0_u8; 16], tmp: [0_u8; 16], - k0: k0.to_vec(), - k1: k1.to_vec(), - ctr: None, - gmac: None, + ecb_dec, + ecb_enc, + ctr, + gmac, } } /// Reset to prepare for another encrypt or decrypt operation. #[inline(always)] - pub fn reset(&mut self) { - let _ = self.ctr.take(); - let _ = self.gmac.take(); - } + pub fn reset(&mut self) {} /// Initialize for encryption. #[inline(always)] pub fn encrypt_init(&mut self, iv: &[u8]) { self.tag[0..8].copy_from_slice(iv); self.tag[8..12].fill(0); - - let ctx = CipherCtx::new().unwrap(); unsafe { - let t = match self.k0.len() { - 16 => ffi::EVP_aes_128_gcm(), - 24 => ffi::EVP_aes_192_gcm(), - 32 => ffi::EVP_aes_256_gcm(), - _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), - }; - ctx.cipher_init::(t, self.k0.as_mut_ptr(), self.tag[0..12].as_ptr()).unwrap(); + self.gmac + .cipher_init::(ptr::null_mut(), ptr::null_mut(), self.tag[0..12].as_ptr()) + .unwrap(); } - let _ = self.gmac.replace(ctx); } /// Set additional authenticated data (data to be authenticated but not encrypted). /// This can currently only be called once. Multiple calls will result in corrupt data. 
#[inline(always)] pub fn encrypt_set_aad(&mut self, data: &[u8]) { - let gmac = self.gmac.as_mut().unwrap(); unsafe { - gmac.update::(data, ptr::null_mut()).unwrap(); + self.gmac.update::(data, ptr::null_mut()).unwrap(); let mut pad = data.len() & 0xf; if pad != 0 { pad = 16 - pad; - gmac.update::(&ZEROES[0..pad], ptr::null_mut()).unwrap(); + self.gmac.update::(&ZEROES[0..pad], ptr::null_mut()).unwrap(); } } } @@ -81,17 +107,16 @@ impl AesGmacSiv { #[inline(always)] pub fn encrypt_first_pass(&mut self, plaintext: &[u8]) { unsafe { - self.gmac.as_mut().unwrap().update::(plaintext, ptr::null_mut()).unwrap(); + self.gmac.update::(plaintext, ptr::null_mut()).unwrap(); } } /// Finish first pass and begin second pass. #[inline(always)] pub fn encrypt_first_pass_finish(&mut self) { - let gmac = self.gmac.as_mut().unwrap(); unsafe { - gmac.finalize::(self.tmp.as_mut_ptr()).unwrap(); - gmac.tag(&mut self.tmp).unwrap(); + self.gmac.finalize::(ptr::null_mut()).unwrap(); + self.gmac.tag(&mut self.tmp).unwrap(); } self.tag[8] = self.tmp[0] ^ self.tmp[8]; @@ -103,36 +128,19 @@ impl AesGmacSiv { self.tag[14] = self.tmp[6] ^ self.tmp[14]; self.tag[15] = self.tmp[7] ^ self.tmp[15]; - let mut tag_tmp = [0_u8; 32]; + let mut tag_tmp = [0_u8; 16]; - let ctx = CipherCtx::new().unwrap(); unsafe { - let t = match self.k1.len() { - 16 => ffi::EVP_aes_128_ecb(), - 24 => ffi::EVP_aes_192_ecb(), - 32 => ffi::EVP_aes_256_ecb(), - _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), - }; - ctx.cipher_init::(t, self.k1.as_mut_ptr(), ptr::null_mut()).unwrap(); - ffi::EVP_CIPHER_CTX_set_padding(ctx.as_ptr(), 0); - ctx.update::(&self.tag, tag_tmp.as_mut_ptr()).unwrap(); + self.ecb_enc.update::(&self.tag, tag_tmp.as_mut_ptr()).unwrap(); } - self.tag.copy_from_slice(&tag_tmp[0..16]); - self.tmp.copy_from_slice(&tag_tmp[0..16]); + self.tag.copy_from_slice(&tag_tmp); + self.tmp.copy_from_slice(&tag_tmp); self.tmp[12] &= 0x7f; - let ctx = CipherCtx::new().unwrap(); unsafe { - let t = match self.k1.len() { - 16 => ffi::EVP_aes_128_ctr(), - 24 => ffi::EVP_aes_192_ctr(), - 32 => ffi::EVP_aes_256_ctr(), - _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), - }; - ctx.cipher_init::(t, self.k1.as_mut_ptr(), self.tmp.as_ptr()).unwrap(); + self.ctr.cipher_init::(ptr::null_mut(), ptr::null_mut(), self.tmp.as_ptr()).unwrap(); } - let _ = self.ctr.replace(ctx); } /// Feed plaintext for second pass and write ciphertext to supplied buffer. 
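// Illustrative encrypt-side flow for the refactored AesGmacSiv above. This is a sketch, not part
// of the diff: it assumes the method set shown in this file (reset, encrypt_init, encrypt_set_aad,
// encrypt_first_pass, encrypt_first_pass_finish, encrypt_second_pass_in_place) plus some accessor
// for the final 16-byte tag that is not visible in these hunks. Since the keyed cipher contexts
// are now built once in new(), reset() is a no-op and each message only re-initializes IVs.
fn seal_in_place(c: &mut AesGmacSiv, iv: &[u8; 8], aad: &[u8], msg: &mut [u8]) {
    c.reset();
    c.encrypt_init(iv);                  // 12-byte GMAC nonce is the 8-byte IV padded with zeroes
    c.encrypt_set_aad(aad);              // may only be called once per message
    c.encrypt_first_pass(msg);           // GMAC pass over the plaintext
    c.encrypt_first_pass_finish();       // finalize GMAC, derive the SIV tag via AES-ECB, re-key the CTR IV
    c.encrypt_second_pass_in_place(msg); // CTR pass encrypts the buffer in place
    // ...then append the final 16-byte tag to the wire format.
}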
@@ -140,7 +148,7 @@ impl AesGmacSiv { #[inline(always)] pub fn encrypt_second_pass(&mut self, plaintext: &[u8], ciphertext: &mut [u8]) { unsafe { - self.ctr.as_mut().unwrap().update::(plaintext, ciphertext.as_mut_ptr()).unwrap(); + self.ctr.update::(plaintext, ciphertext.as_mut_ptr()).unwrap(); } } @@ -150,7 +158,7 @@ impl AesGmacSiv { pub fn encrypt_second_pass_in_place(&mut self, plaintext_to_ciphertext: &mut [u8]) { unsafe { let out = plaintext_to_ciphertext.as_mut_ptr(); - self.ctr.as_mut().unwrap().update::(plaintext_to_ciphertext, out).unwrap(); + self.ctr.update::(plaintext_to_ciphertext, out).unwrap(); } } @@ -168,46 +176,23 @@ impl AesGmacSiv { self.tmp.copy_from_slice(tag); self.tmp[12] &= 0x7f; - let ctx = CipherCtx::new().unwrap(); unsafe { - let t = match self.k1.len() { - 16 => ffi::EVP_aes_128_ctr(), - 24 => ffi::EVP_aes_192_ctr(), - 32 => ffi::EVP_aes_256_ctr(), - _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), - }; - ctx.cipher_init::(t, self.k1.as_mut_ptr(), self.tmp.as_ptr()).unwrap(); + self.ctr + .cipher_init::(ptr::null_mut(), ptr::null_mut(), self.tmp.as_ptr()) + .unwrap(); } - let _ = self.ctr.replace(ctx); - let mut tag_tmp = [0_u8; 32]; + let mut tag_tmp = [0_u8; 16]; - let ctx = CipherCtx::new().unwrap(); unsafe { - let t = match self.k1.len() { - 16 => ffi::EVP_aes_128_ecb(), - 24 => ffi::EVP_aes_192_ecb(), - 32 => ffi::EVP_aes_256_ecb(), - _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), - }; - ctx.cipher_init::(t, self.k1.as_mut_ptr(), ptr::null_mut()).unwrap(); - ffi::EVP_CIPHER_CTX_set_padding(ctx.as_ptr(), 0); - ctx.update::(&self.tag, tag_tmp.as_mut_ptr()).unwrap(); + self.ecb_dec.update::(&tag, tag_tmp.as_mut_ptr()).unwrap(); } - self.tag.copy_from_slice(&tag_tmp[0..16]); + self.tag.copy_from_slice(&tag_tmp); tag_tmp[8..12].fill(0); - let ctx = CipherCtx::new().unwrap(); unsafe { - let t = match self.k0.len() { - 16 => ffi::EVP_aes_128_gcm(), - 24 => ffi::EVP_aes_192_gcm(), - 32 => ffi::EVP_aes_256_gcm(), - _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), - }; - ctx.cipher_init::(t, self.k0.as_mut_ptr(), self.tag[0..12].as_ptr()).unwrap(); + self.gmac.cipher_init::(ptr::null_mut(), ptr::null_mut(), tag_tmp.as_ptr()).unwrap(); } - let _ = self.gmac.replace(ctx); } /// Set additional authenticated data to be checked. @@ -221,8 +206,8 @@ impl AesGmacSiv { #[inline(always)] pub fn decrypt(&mut self, ciphertext: &[u8], plaintext: &mut [u8]) { unsafe { - self.ctr.as_mut().unwrap().update::(ciphertext, plaintext.as_mut_ptr()).unwrap(); - self.gmac.as_mut().unwrap().update::(plaintext, ptr::null_mut()).unwrap(); + self.ctr.update::(ciphertext, plaintext.as_mut_ptr()).unwrap(); + self.gmac.update::(plaintext, ptr::null_mut()).unwrap(); } } @@ -240,10 +225,9 @@ impl AesGmacSiv { /// If this returns false the message should be dropped. 
#[inline(always)] pub fn decrypt_finish(&mut self) -> Option<&[u8; 16]> { - let gmac = self.gmac.as_mut().unwrap(); unsafe { - gmac.finalize::(self.tmp.as_mut_ptr()).unwrap(); - gmac.tag(&mut self.tmp).unwrap(); + self.gmac.finalize::(self.tmp.as_mut_ptr()).unwrap(); + self.gmac.tag(&mut self.tmp).unwrap(); } if (self.tag[8] == self.tmp[0] ^ self.tmp[8]) && (self.tag[9] == self.tmp[1] ^ self.tmp[9]) diff --git a/crypto/src/hash.rs b/crypto/src/hash.rs index e71472b9c..c9a7a2327 100644 --- a/crypto/src/hash.rs +++ b/crypto/src/hash.rs @@ -269,19 +269,19 @@ pub fn hmac_sha512(key: &[u8], msg: &[u8]) -> [u8; HMAC_SHA512_SIZE] { hm.finish() } -#[inline(always)] -pub fn hmac_sha512_into(key: &[u8], msg: &[u8], md: &mut [u8]) { +pub fn hmac_sha512_secret256(key: &[u8], msg: &[u8]) -> Secret<32> { let mut hm = HMACSHA512::new(key); hm.update(msg); - hm.finish_into(md); + let mut md = [0u8; HMAC_SHA512_SIZE]; + hm.finish_into(&mut md); + // With such a simple procedure hopefully the compiler implements the following line as a move from md and not a copy + // If not we ought to change this code so we don't leak a secret value on the stack + unsafe { Secret::from_bytes(&md[0..32]) } } - -pub fn hmac_sha512_secret(key: &[u8], msg: &[u8]) -> Secret { - debug_assert!(C <= HMAC_SHA512_SIZE); +pub fn hmac_sha512_secret(key: &[u8], msg: &[u8]) -> Secret { let mut hm = HMACSHA512::new(key); hm.update(msg); - let buff = hm.finish(); - unsafe { Secret::from_bytes(&buff[..C]) } + Secret::move_bytes(hm.finish()) } #[inline(always)] @@ -290,10 +290,3 @@ pub fn hmac_sha384(key: &[u8], msg: &[u8]) -> [u8; HMAC_SHA384_SIZE] { hm.update(msg); hm.finish() } - -#[inline(always)] -pub fn hmac_sha384_into(key: &[u8], msg: &[u8], md: &mut [u8]) { - let mut hm = HMACSHA384::new(key); - hm.update(msg); - hm.finish_into(md); -} diff --git a/utils/src/arc_pool.rs b/utils/src/arc_pool.rs new file mode 100644 index 000000000..4375bf0db --- /dev/null +++ b/utils/src/arc_pool.rs @@ -0,0 +1,769 @@ +use std::fmt::{Debug, Display}; +use std::marker::PhantomData; +use std::mem::{self, ManuallyDrop, MaybeUninit}; +use std::num::NonZeroU64; +use std::ops::Deref; +use std::ptr::{self, NonNull}; +use std::sync::{ + atomic::{AtomicPtr, AtomicU32, Ordering}, + Mutex, RwLock, RwLockReadGuard, +}; + +const DEFAULT_L: usize = 64; + +union SlotState { + empty_next: *mut Slot, + full_obj: ManuallyDrop, +} +struct Slot { + obj: SlotState, + free_lock: RwLock<()>, + ref_count: AtomicU32, + uid: u64, +} + +struct PoolMem { + mem: [MaybeUninit>; L], + pre: *mut PoolMem, +} + +/// A generic, *thread-safe*, fixed-sized memory allocator for instances of `T`. +/// New instances of `T` are packed together into arrays of size `L`, and allocated in bulk as one memory arena from the global allocator. +/// Arenas from the global allocator are not deallocated until the pool is dropped, and are re-used as instances of `T` are allocated and freed. +/// +/// This specific datastructure also supports generational indexing, which means that an arbitrary number of non-owning references to allocated instances of `T` can be generated safely. These references can outlive the underlying `T` they reference, and will safely report upon dereference that the original underlying `T` is gone. +/// +/// Atomic reference counting is also implemented allowing for exceedingly complex models of shared ownership. Multiple copies of both strong and weak references to the underlying `T` can be generated that are all memory safe and borrow-checked. 
+/// +/// Allocating from a pool results in very little internal and external fragmentation in the global heap, thus saving significant amounts of memory from being used by one's program. Pools also allocate memory significantly faster on average than the global allocator. This specific pool implementation supports guaranteed constant time `alloc` and `free`. +pub struct Pool(Mutex<(*mut Slot, u64, *mut PoolMem, usize)>); +unsafe impl Send for Pool {} +unsafe impl Sync for Pool {} + +impl Pool { + pub const DEFAULT_L: usize = DEFAULT_L; + + /// Creates a new `Pool` with packing length `L`. Packing length determines the number of instances of `T` that will fit in a page before it becomes full. Once all pages in a `Pool` are full a new page is allocated from the LocalNode allocator. Larger values of `L` are generally faster, but the returns are diminishing and vary by platform. + /// + /// A `Pool` cannot be interacted with directly, it requires a `impl StaticPool for Pool` implementation. See the `static_pool!` macro for automatically generated trait implementation. + #[inline] + pub const fn new() -> Self { + Pool(Mutex::new((ptr::null_mut(), 1, ptr::null_mut(), usize::MAX))) + } + + #[inline(always)] + fn create_arr() -> [MaybeUninit>; L] { + unsafe { MaybeUninit::<[MaybeUninit>; L]>::uninit().assume_init() } + } + + /// Allocates uninitialized memory for an instance `T`. The returned pointer points to this memory. It is undefined what will be contained in this memory, it must be initiallized before being used. This pointer must be manually freed from the pool using `Pool::free_ptr` before being dropped, otherwise its memory will be leaked. If the pool is dropped before this pointer is freed, the destructor of `T` will not be run and this pointer will point to invalid memory. + unsafe fn alloc_ptr(&self, obj: T) -> NonNull> { + let mut mutex = self.0.lock().unwrap(); + let (mut first_free, uid, mut head_arena, mut head_size) = *mutex; + + let slot_ptr = if let Some(mut slot_ptr) = NonNull::new(first_free) { + let slot = slot_ptr.as_mut(); + let _announce_free = slot.free_lock.write().unwrap(); + debug_assert_eq!(slot.uid, 0); + first_free = slot.obj.empty_next; + slot.ref_count = AtomicU32::new(1); + slot.uid = uid; + slot.obj.full_obj = ManuallyDrop::new(obj); + slot_ptr + } else { + if head_size >= L { + let new = Box::leak(Box::new(PoolMem { pre: head_arena, mem: Self::create_arr() })); + head_arena = new; + head_size = 0; + } + let slot = Slot { + obj: SlotState { full_obj: ManuallyDrop::new(obj) }, + free_lock: RwLock::new(()), + ref_count: AtomicU32::new(1), + uid, + }; + let slot_ptr = &mut (*head_arena).mem[head_size]; + let slot_ptr = NonNull::new_unchecked(slot_ptr.write(slot)); + head_size += 1; + // We do not have to hold the free lock since we know this slot has never been touched before and nothing external references it + slot_ptr + }; + + *mutex = (first_free, uid.wrapping_add(1), head_arena, head_size); + slot_ptr + } + /// Frees memory allocated from the pool by `Pool::alloc_ptr`. This must be called only once on only pointers returned by `Pool::alloc_ptr` from the same pool. Once memory is freed the content of the memory is undefined, it should not be read or written. + /// + /// `drop` will be called on the `T` pointed to, be sure it has not been called already. + /// + /// The free lock must be held by the caller. 
+ unsafe fn free_ptr(&self, mut slot_ptr: NonNull>) { + let slot = slot_ptr.as_mut(); + slot.uid = 0; + ManuallyDrop::::drop(&mut slot.obj.full_obj); + //linked-list insert + let mut mutex = self.0.lock().unwrap(); + + slot.obj.empty_next = mutex.0; + mutex.0 = slot_ptr.as_ptr(); + } +} +impl Drop for Pool { + fn drop(&mut self) { + let mutex = self.0.lock().unwrap(); + let (_, _, mut head_arena, _) = *mutex; + unsafe { + while !head_arena.is_null() { + let mem = Box::from_raw(head_arena); + head_arena = mem.pre; + drop(mem); + } + } + drop(mutex); + } +} + +pub trait StaticPool { + /// Must return a pointer to an instance of a `Pool` with a static lifetime. That pointer must be cast to a `*const ()` to make the borrow-checker happy. + /// + /// **Safety**: The returned pointer must have originally been a `&'static Pool` reference. So it must have had a matching `T` and `L` and it must have the static lifetime. + /// + /// In order to borrow-split allocations from a `Pool`, we need to force the borrow-checker to not associate the lifetime of an instance of `T` with the lifetime of the pool. Otherwise the borrow-checker would require every allocated `T` to have the `'static` lifetime, to match the pool's lifetime. + /// The simplest way I have found to do this is to return the pointer to the static pool as an anonymous, lifetimeless `*const ()`. This introduces unnecessary safety concerns surrounding pointer casting unfortunately. If there is a better way to borrow-split from a pool I will gladly implement it. + unsafe fn get_static_pool() -> *const (); + + /// Allocates memory for an instance `T` and puts its pointer behind a memory-safe Arc. This `PoolArc` automatically frees itself on drop, and will cause the borrow checker to complain if you attempt to drop the pool before you drop this box. + /// + /// This `PoolArc` supports the ability to generate weak, non-owning references to the allocated `T`. + #[inline(always)] + fn alloc(obj: T) -> PoolArc + where + Self: Sized, + { + unsafe { + PoolArc { + ptr: (*Self::get_static_pool().cast::>()).alloc_ptr(obj), + _p: PhantomData, + } + } + } +} + +/// A rust-style RAII wrapper that drops and frees memory allocated from a pool automatically, the same as an `Arc`. This will run the destructor of `T` in place within the pool before freeing it, correctly maintaining the invariants that the borrow checker and rust compiler expect of generic types. +pub struct PoolArc, const L: usize = DEFAULT_L> { + ptr: NonNull>, + _p: PhantomData<*const OriginPool>, +} + +impl, const L: usize> PoolArc { + /// Obtain a non-owning reference to the `T` contained in this `PoolArc`. This reference has the special property that the underlying `T` can be dropped from the pool while neither making this reference invalid or unsafe nor leaking the memory of `T`. Instead attempts to `grab` the reference will safely return `None`. + /// + /// `T` is guaranteed to be dropped when all `PoolArc` are dropped, regardless of how many `PoolWeakRef` still exist. + #[inline] + pub fn downgrade(&self) -> PoolWeakRef { + unsafe { + // Since this is a Arc we know for certain the object has not been freed, so we don't have to hold the free lock + PoolWeakRef { + ptr: self.ptr, + uid: NonZeroU64::new_unchecked(self.ptr.as_ref().uid), + _p: PhantomData, + } + } + } + /// Returns a number that uniquely identifies this allocated `T` within this pool. No other instance of `T` may have this uid. 
+ pub fn uid(&self) -> NonZeroU64 { + unsafe { NonZeroU64::new_unchecked(self.ptr.as_ref().uid) } + } +} + +impl, const L: usize> Deref for PoolArc { + type Target = T; + #[inline] + fn deref(&self) -> &Self::Target { + unsafe { &self.ptr.as_ref().obj.full_obj } + } +} +impl, const L: usize> Clone for PoolArc { + fn clone(&self) -> Self { + unsafe { + self.ptr.as_ref().ref_count.fetch_add(1, Ordering::Relaxed); + } + Self { ptr: self.ptr, _p: PhantomData } + } +} +impl, const L: usize> Drop for PoolArc { + #[inline] + fn drop(&mut self) { + unsafe { + let slot = self.ptr.as_ref(); + if slot.ref_count.fetch_sub(1, Ordering::AcqRel) == 1 { + let _announce_free = slot.free_lock.write().unwrap(); + // We have to check twice in case a weakref was upgraded before the lock was acquired + if slot.ref_count.load(Ordering::Relaxed) == 0 { + (*OriginPool::get_static_pool().cast::>()).free_ptr(self.ptr); + } + } + } + } +} +unsafe impl, const L: usize> Send for PoolArc where T: Send {} +unsafe impl, const L: usize> Sync for PoolArc where T: Sync {} +impl, const L: usize> Debug for PoolArc +where + T: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("PoolArc").field(self.deref()).finish() + } +} +impl, const L: usize> Display for PoolArc +where + T: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.deref().fmt(f) + } +} + +/// A non-owning reference to a `T` allocated by a pool. This reference has the special property that the underlying `T` can be dropped from the pool while neither making this reference invalid nor leaking the memory of `T`. Instead attempts to `grab` this reference will safely return `None` if the underlying `T` has been freed by any thread. +/// +/// Due to their thread safety and low overhead a `PoolWeakRef` implements clone and copy. +/// +/// The lifetime of this reference is tied to the lifetime of the pool it came from, because if it were allowed to live longer than its origin pool, it would no longer be safe to dereference and would most likely segfault. Instead the borrow-checker will enforce that this reference has a shorter lifetime that its origin pool. +/// +/// For technical reasons a `RwLock>` will always be the fastest implementation of a `PoolWeakRefSwap`, which is why this library does not provide a `PoolWeakRefSwap` type. +pub struct PoolWeakRef, const L: usize = DEFAULT_L> { + /// A number that uniquely identifies this allocated `T` within this pool. No other instance of `T` may have this uid. This value is read-only. + pub uid: NonZeroU64, + ptr: NonNull>, + _p: PhantomData<*const OriginPool>, +} + +impl, const L: usize> PoolWeakRef { + /// Obtains a lock that allows the `T` contained in this `PoolWeakRef` to be dereferenced in a thread-safe manner. This lock does not prevent other threads from accessing `T` at the same time, so `T` ought to use interior mutability if it needs to be mutated in a thread-safe way. What this lock does guarantee is that `T` cannot be destructed and freed while it is being held. + /// + /// Do not attempt from within the same thread to drop the `PoolArc` that owns this `T` before dropping this lock, or else the thread will deadlock. Rust makes this quite hard to do accidentally but it's not strictly impossible. 
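// A short sketch of the weak-reference lifecycle described above, including the guard/drop
// ordering that avoids the deadlock mentioned in the previous comment. `MyPools` is assumed to be
// a static_pool!-generated type implementing StaticPool<String> (see the macro example further
// down); it is not part of this diff.
fn weak_ref_lifecycle() {
    let strong = MyPools::alloc(String::from("hello"));
    let weak = strong.downgrade();
    {
        let guard = weak.grab().expect("object is still alive");
        assert_eq!(guard.as_str(), "hello");
    } // drop the guard before dropping any PoolArc on this thread, otherwise we would deadlock
    let also_strong = weak.upgrade().expect("upgrade succeeds while a PoolArc exists");
    drop(strong);
    drop(also_strong);                // last owner gone: the String is dropped and its slot freed
    assert!(weak.grab().is_none());   // the weak ref safely reports that the object is gone
    assert!(weak.upgrade().is_none());
}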
+ #[inline] + pub fn grab<'b>(&self) -> Option> { + unsafe { + let slot = self.ptr.as_ref(); + let prevent_free_lock = slot.free_lock.read().unwrap(); + if slot.uid == self.uid.get() { + Some(PoolGuard(prevent_free_lock, &slot.obj.full_obj)) + } else { + None + } + } + } + /// Attempts to create an owning `PoolArc` from this `PoolWeakRef` of the underlying `T`. Will return `None` if the underlying `T` has already been dropped. + pub fn upgrade(&self) -> Option> { + unsafe { + let slot = self.ptr.as_ref(); + let _prevent_free_lock = slot.free_lock.read().unwrap(); + if slot.uid == self.uid.get() { + self.ptr.as_ref().ref_count.fetch_add(1, Ordering::Relaxed); + Some(PoolArc { ptr: self.ptr, _p: PhantomData }) + } else { + None + } + } + } +} +impl, const L: usize> Clone for PoolWeakRef { + fn clone(&self) -> Self { + Self { uid: self.uid, ptr: self.ptr, _p: PhantomData } + } +} +impl, const L: usize> Copy for PoolWeakRef {} +unsafe impl, const L: usize> Send for PoolWeakRef where T: Send {} +unsafe impl, const L: usize> Sync for PoolWeakRef where T: Sync {} +impl, const L: usize> Debug for PoolWeakRef +where + T: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let inner = self.grab(); + f.debug_tuple("PoolWeakRef").field(&inner).finish() + } +} +impl, const L: usize> Display for PoolWeakRef +where + T: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(inner) = self.grab() { + inner.fmt(f) + } else { + f.write_str("Empty") + } + } +} + +/// A multithreading lock guard that prevents another thread from freeing the underlying `T` while it is held. It does not prevent other threads from accessing the underlying `T`. +/// +/// If the same thread that holds this guard attempts to free `T` before dropping the guard, it will deadlock. +pub struct PoolGuard<'a, T>(RwLockReadGuard<'a, ()>, &'a T); +impl<'a, T> Deref for PoolGuard<'a, T> { + type Target = T; + #[inline] + fn deref(&self) -> &Self::Target { + &*self.1 + } +} +impl<'a, T> Debug for PoolGuard<'a, T> +where + T: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("PoolGuard").field(self.deref()).finish() + } +} +impl<'a, T> Display for PoolGuard<'a, T> +where + T: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.deref().fmt(f) + } +} + +/// Allows for the Atomic Swapping and Loading of a `PoolArc`, similar to how a `RwLock>` would function, but much faster and less verbose. +pub struct PoolArcSwap, const L: usize = DEFAULT_L> { + ptr: AtomicPtr>, + reads: AtomicU32, + _p: PhantomData<*const OriginPool>, +} +impl, const L: usize> PoolArcSwap { + /// Creates a new `PoolArcSwap`, consuming `arc` in the process. + pub fn new(mut arc: PoolArc) -> Self { + unsafe { + let ret = Self { + ptr: AtomicPtr::new(arc.ptr.as_mut()), + reads: AtomicU32::new(0), + _p: arc._p, + }; + // Suppress reference decrement on new + mem::forget(arc); + ret + } + } + /// Atomically swaps the currently stored `PoolArc` with a new one, returning the previous one. + pub fn swap(&self, arc: PoolArc) -> PoolArc { + unsafe { + let pre_ptr = self.ptr.swap(arc.ptr.as_ptr(), Ordering::Relaxed); + + while self.reads.load(Ordering::Acquire) > 0 { + std::hint::spin_loop() + } + + mem::forget(arc); + PoolArc { ptr: NonNull::new_unchecked(pre_ptr), _p: self._p } + } + } + + /// Atomically loads and clones the currently stored `PoolArc`, guaranteeing that the underlying `T` cannot be freed while the clone is held. 
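// Sketch of the swap/load semantics documented above, again assuming a hypothetical
// static_pool!-generated `MyPools: StaticPool<u64>` (not part of this diff).
fn swap_and_load() {
    let swap = PoolArcSwap::new(MyPools::alloc(1u64));
    let previous = swap.swap(MyPools::alloc(2u64)); // atomically replaces the stored arc
    assert_eq!(*previous, 1);                       // and hands back the old one
    assert_eq!(*swap.load(), 2);                    // load clones the current arc; the value
                                                    // cannot be freed while that clone is held
}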
+ pub fn load(&self) -> PoolArc { + unsafe { + self.reads.fetch_add(1, Ordering::Acquire); + let ptr = self.ptr.load(Ordering::Relaxed); + (*ptr).ref_count.fetch_add(1, Ordering::Relaxed); + self.reads.fetch_sub(1, Ordering::Release); + PoolArc { ptr: NonNull::new_unchecked(ptr), _p: self._p } + } + } +} +impl, const L: usize> Drop for PoolArcSwap { + #[inline] + fn drop(&mut self) { + unsafe { + let pre = self.ptr.load(Ordering::SeqCst); + PoolArc { _p: self._p, ptr: NonNull::new_unchecked(pre) }; + } + } +} +unsafe impl, const L: usize> Send for PoolArcSwap where T: Send {} +unsafe impl, const L: usize> Sync for PoolArcSwap where T: Sync {} +impl, const L: usize> Debug for PoolArcSwap +where + T: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("PoolArcSwap").field(&self.load()).finish() + } +} +impl, const L: usize> Display for PoolArcSwap +where + T: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + (&self.load()).fmt(f) + } +} + +/// Another implementation of a `PoolArcSwap` utalizing a RwLock instead of atomics. +/// This implementation has slower a `load` but a faster `swap` than the previous implementation of `PoolArcSwap`. +/// If you plan on swapping way more often than loading, this may be a better choice. +pub struct PoolArcSwapRw, const L: usize = DEFAULT_L> { + ptr: RwLock>>, + _p: PhantomData<*const OriginPool>, +} + +impl, const L: usize> PoolArcSwapRw { + /// Creates a new `PoolArcSwap`, consuming `arc` in the process. + pub fn new(arc: PoolArc) -> Self { + let ret = Self { ptr: RwLock::new(arc.ptr), _p: arc._p }; + mem::forget(arc); + ret + } + + /// Atomically swaps the currently stored `PoolArc` with a new one, returning the previous one. + pub fn swap(&self, arc: PoolArc) -> PoolArc { + let mut w = self.ptr.write().unwrap(); + let pre = PoolArc { ptr: *w, _p: self._p }; + *w = arc.ptr; + mem::forget(arc); + pre + } + + /// Atomically loads and clones the currently stored `PoolArc`, guaranteeing that the underlying `T` cannot be freed while the clone is held. + pub fn load(&self) -> PoolArc { + let r = self.ptr.read().unwrap(); + unsafe { + r.as_ref().ref_count.fetch_add(1, Ordering::Relaxed); + } + let pre = PoolArc { ptr: *r, _p: self._p }; + pre + } +} +impl, const L: usize> Drop for PoolArcSwapRw { + #[inline] + fn drop(&mut self) { + let w = self.ptr.write().unwrap(); + PoolArc { ptr: *w, _p: self._p }; + } +} +unsafe impl, const L: usize> Send for PoolArcSwapRw where T: Send {} +unsafe impl, const L: usize> Sync for PoolArcSwapRw where T: Sync {} +impl, const L: usize> Debug for PoolArcSwapRw +where + T: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("PoolArcSwapRw").field(&self.load()).finish() + } +} +impl, const L: usize> Display for PoolArcSwapRw +where + T: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + (&self.load()).fmt(f) + } +} + +/// Automatically generates valid implementations of `StaticPool` onto a chosen identifier, allowing this module to allocate instances of `T` with `alloc`. Users have to generate implementations clientside because rust does not allow for generic globals. +/// +/// The chosen identifier is declared to be a struct with no fields, and instead contains a static global `Pool` for every implementation of `StaticPool` requested. 
+/// +/// # Example +/// ``` +/// use zerotier_utils::arc_pool::{static_pool, StaticPool, Pool, PoolArc}; +/// +/// static_pool!(pub StaticPool MyPools { +/// Pool, Pool<&u32, 12> +/// }); +/// +/// struct Container { +/// item: PoolArc +/// } +/// +/// let object = 1u32; +/// let arc_object = MyPools::alloc(object); +/// let arc_ref = MyPools::alloc(&object); +/// let arc_container = Container {item: MyPools::alloc(object)}; +/// +/// assert_eq!(*arc_object, **arc_ref); +/// assert_eq!(*arc_object, *arc_container.item); +/// ``` +#[macro_export] +macro_rules! __static_pool__ { + ($m:ident $s:ident { $($($p:ident)::+<$t:ty$(, $l:tt)?>),+ $(,)?}) => { + struct $s {} + $( + impl $m<$t$(, $l)?> for $s { + #[inline(always)] + unsafe fn get_static_pool() -> *const () { + static POOL: $($p)::+<$t$(, $l)?> = $($p)::+::new(); + (&POOL as *const $($p)::+<$t$(, $l)?>).cast() + } + } + )* + }; + ($m:ident::$n:ident $s:ident { $($($p:ident)::+<$t:ty$(, $l:tt)?>),+ $(,)?}) => { + struct $s {} + $( + impl $m::$n<$t$(, $l)?> for $s { + #[inline(always)] + unsafe fn get_static_pool() -> *const () { + static POOL: $($p)::+<$t$(, $l)?> = $($p)::+::new(); + (&POOL as *const $($p)::+<$t$(, $l)?>).cast() + } + } + )* + }; + (pub $m:ident $s:ident { $($($p:ident)::+<$t:ty$(, $l:tt)?>),+ $(,)?}) => { + pub struct $s {} + $( + impl $m<$t$(, $l)?> for $s { + #[inline(always)] + unsafe fn get_static_pool() -> *const () { + static POOL: $($p)::+<$t$(, $l)?> = $($p)::+::new(); + (&POOL as *const $($p)::+<$t$(, $l)?>).cast() + } + } + )* + }; + (pub $m:ident::$n:ident $s:ident { $($($p:ident)::+<$t:ty$(, $l:tt)?>),+ $(,)?}) => { + pub struct $s {} + $( + impl $m::$n<$t$(, $l)?> for $s { + #[inline(always)] + unsafe fn get_static_pool() -> *const () { + static POOL: $($p)::+<$t$(, $l)?> = $($p)::+::new(); + (&POOL as *const $($p)::+<$t$(, $l)?>).cast() + } + } + )* + }; +} +pub use __static_pool__ as static_pool; + +#[cfg(test)] +mod tests { + use super::*; + use std::{ + sync::{atomic::AtomicU64, Arc}, + thread, + }; + + fn rand(r: &mut u32) -> u32 { + /* Algorithm "xor" from p. 
4 of Marsaglia, "Xorshift RNGs" */ + *r ^= *r << 13; + *r ^= *r >> 17; + *r ^= *r << 5; + *r + } + const fn prob(p: u64) -> u32 { + (p * (u32::MAX as u64) / 100) as u32 + } + fn rand_idx<'a, T>(v: &'a [T], r: &mut u32) -> Option<&'a T> { + if v.len() > 0 { + Some(&v[(rand(r) as usize) % v.len()]) + } else { + None + } + } + fn rand_i<'a, T>(v: &'a [T], r: &mut u32) -> Option { + if v.len() > 0 { + Some((rand(r) as usize) % v.len()) + } else { + None + } + } + + struct Item { + a: u32, + count: &'static AtomicU64, + b: u32, + } + impl Item { + fn new(r: u32, count: &'static AtomicU64) -> Item { + count.fetch_add(1, Ordering::Relaxed); + Item { a: r, count, b: r } + } + fn check(&self, id: u32) { + assert_eq!(self.a, self.b); + assert_eq!(self.a, id); + } + } + impl Drop for Item { + fn drop(&mut self) { + let _a = self.count.fetch_sub(1, Ordering::Relaxed); + assert_eq!(self.a, self.b); + } + } + + const POOL_U32_LEN: usize = (5 * 12) << 2; + static_pool!(StaticPool TestPools { + Pool, Pool + }); + + #[test] + fn usage() { + let num1 = TestPools::alloc(1u32); + let num2 = TestPools::alloc(2u32); + let num3 = TestPools::alloc(3u32); + let num4 = TestPools::alloc(4u32); + let num2_weak = num2.downgrade(); + + assert_eq!(*num2_weak.grab().unwrap(), 2); + drop(num2); + + assert_eq!(*num1, 1); + assert_eq!(*num3, 3); + assert_eq!(*num4, 4); + assert!(num2_weak.grab().is_none()); + } + #[test] + fn single_thread() { + let mut history = Vec::new(); + + let num1 = TestPools::alloc(1u32); + let num2 = TestPools::alloc(2u32); + let num3 = TestPools::alloc(3u32); + let num4 = TestPools::alloc(4u32); + let num2_weak = num2.downgrade(); + + for i in 0..1000 { + history.push(TestPools::alloc(i as u32)); + } + for i in 0..100 { + let arc = history.remove((i * 10) % history.len()); + assert!(*arc < 1000); + } + for i in 0..1000 { + history.push(TestPools::alloc(i as u32)); + } + + assert_eq!(*num2_weak.grab().unwrap(), 2); + drop(num2); + + assert_eq!(*num1, 1); + assert_eq!(*num3, 3); + assert_eq!(*num4, 4); + assert!(num2_weak.grab().is_none()); + } + + #[test] + fn multi_thread() { + const N: usize = 12345; + static COUNT: AtomicU64 = AtomicU64::new(0); + + let mut joins = Vec::new(); + for i in 0..32 { + joins.push(thread::spawn(move || { + let r = &mut (i + 1234); + + let mut items_dup = Vec::new(); + let mut items = Vec::new(); + for _ in 0..N { + let p = rand(r); + if p < prob(30) { + let id = rand(r); + let s = TestPools::alloc(Item::new(id, &COUNT)); + items.push((id, s.clone(), s.downgrade())); + s.check(id); + } else if p < prob(60) { + if let Some((id, s, w)) = rand_idx(&items, r) { + items_dup.push((*id, s.clone(), (*w).clone())); + s.check(*id); + } + } else if p < prob(80) { + if let Some(i) = rand_i(&items, r) { + let (id, s, w) = items.swap_remove(i); + w.grab().unwrap().check(id); + s.check(id); + } + } else if p < prob(100) { + if let Some(i) = rand_i(&items_dup, r) { + let (id, s, w) = items_dup.swap_remove(i); + w.grab().unwrap().check(id); + s.check(id); + } + } + } + for (id, s, w) in items_dup { + s.check(id); + w.grab().unwrap().check(id); + } + for (id, s, w) in items { + s.check(id); + w.grab().unwrap().check(id); + drop(s); + assert!(w.grab().is_none()) + } + })); + } + for j in joins { + j.join().unwrap(); + } + assert_eq!(COUNT.load(Ordering::Relaxed), 0); + } + + #[test] + fn multi_thread_swap() { + const N: usize = 1234; + static COUNT: AtomicU64 = AtomicU64::new(0); + + let s = Arc::new(PoolArcSwap::new(TestPools::alloc(Item::new(0, &COUNT)))); + + for _ in 0..123 { + let 
mut joins = Vec::new(); + for _ in 0..8 { + let swaps = s.clone(); + joins.push(thread::spawn(move || { + let r = &mut 1474; + let mut new = TestPools::alloc(Item::new(rand(r), &COUNT)); + for _ in 0..N { + new = swaps.swap(new); + } + })); + } + for j in joins { + j.join().unwrap(); + } + } + drop(s); + assert_eq!(COUNT.load(Ordering::Relaxed), 0); + } + + #[test] + fn multi_thread_swap_load() { + const N: usize = 12345; + static COUNT: AtomicU64 = AtomicU64::new(0); + + let s: Arc<[_; 8]> = Arc::new(std::array::from_fn(|i| PoolArcSwap::new(TestPools::alloc(Item::new(i as u32, &COUNT))))); + + let mut joins = Vec::new(); + + for i in 0..4 { + let swaps = s.clone(); + joins.push(thread::spawn(move || { + let r = &mut (i + 2783); + for _ in 0..N { + if let Some(s) = rand_idx(&swaps[..], r) { + let new = TestPools::alloc(Item::new(rand(r), &COUNT)); + let _a = s.swap(new); + } + } + })); + } + for i in 0..28 { + let swaps = s.clone(); + joins.push(thread::spawn(move || { + let r = &mut (i + 4136); + for _ in 0..N { + if let Some(s) = rand_idx(&swaps[..], r) { + let _a = s.load(); + assert_eq!(_a.a, _a.b); + } + } + })); + } + for j in joins { + j.join().unwrap(); + } + drop(s); + assert_eq!(COUNT.load(Ordering::Relaxed), 0); + } +} diff --git a/zssp/Cargo.toml b/zssp/Cargo.toml index d4888b865..1d381576a 100644 --- a/zssp/Cargo.toml +++ b/zssp/Cargo.toml @@ -25,3 +25,4 @@ doc = false zerotier-utils = { path = "../utils" } zerotier-crypto = { path = "../crypto" } pqc_kyber = { version = "0.4.0", default-features = false, features = ["kyber1024", "std"] } +hex-literal = "0.3.4" diff --git a/zssp/src/applicationlayer.rs b/zssp/src/applicationlayer.rs index 1c89ff72f..ef8df25eb 100644 --- a/zssp/src/applicationlayer.rs +++ b/zssp/src/applicationlayer.rs @@ -71,7 +71,7 @@ pub trait ApplicationLayer: Sized { /// /// A physical path could be an IP address or IP plus device in the case of UDP, a socket in the /// case of TCP, etc. - type PhysicalPath: PartialEq + Eq + Hash + Clone; + type PhysicalPath: Hash; /// Get a reference to this host's static public key blob. /// diff --git a/zssp/src/fragged.rs b/zssp/src/fragged.rs index 01a0c3b20..47370158b 100644 --- a/zssp/src/fragged.rs +++ b/zssp/src/fragged.rs @@ -3,6 +3,7 @@ use std::ptr::slice_from_raw_parts; /// Fast packet defragmenter pub struct Fragged { + count: u32, have: u64, counter: u64, frags: [MaybeUninit; MAX_FRAGMENTS], @@ -35,65 +36,57 @@ impl Fragged { // that the array of MaybeUninit can be freely cast into an array of // Fragment. They also check that the maximum number of fragments is not too large // for the fact that we use bits in a u64 to track which fragments are received. - assert!(MAX_FRAGMENTS <= 64); - assert_eq!(size_of::>(), size_of::()); - assert_eq!( + debug_assert!(MAX_FRAGMENTS <= 64); + debug_assert_eq!(size_of::>(), size_of::()); + debug_assert_eq!( size_of::<[MaybeUninit; MAX_FRAGMENTS]>(), size_of::<[Fragment; MAX_FRAGMENTS]>() ); unsafe { zeroed() } } + /// Returns the counter value associated with the packet currently being assembled. + /// If no packet is currently being assembled it returns 0. + #[inline(always)] + pub fn counter(&self) -> u64 { + self.counter + } /// Add a fragment and return an assembled packet container if all fragments have been received. /// /// When a fully assembled packet is returned the internal state is reset and this object can /// be reused to assemble another packet. 
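// Minimal sketch of driving the defragmenter above, assuming Fragment = Vec<u8> and
// MAX_FRAGMENTS = 8 (both arbitrary, for illustration) and that Assembled exposes its fragments
// as a slice via as_ref(), as the zssp code in this PR does. All fragments of one packet share
// the same counter value; when the last fragment arrives the assembled container is returned and
// the Fragged resets itself for reuse.
fn reassemble_two_fragments() {
    let mut defrag: Fragged<Vec<u8>, 8> = Fragged::new();
    assert!(defrag.assemble(42, b"hello ".to_vec(), 0, 2).is_none()); // fragment 0 of 2 received
    if let Some(assembled) = defrag.assemble(42, b"world".to_vec(), 1, 2) {
        let whole: Vec<u8> = assembled.as_ref().iter().flatten().copied().collect();
        assert_eq!(whole, b"hello world");
    }
    assert_eq!(defrag.counter(), 0); // internal state was reset when the packet completed
}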
- #[inline(always)] pub fn assemble(&mut self, counter: u64, fragment: Fragment, fragment_no: u8, fragment_count: u8) -> Option> { if fragment_no < fragment_count && (fragment_count as usize) <= MAX_FRAGMENTS { - let mut have = self.have; - // If the counter has changed, reset the structure to receive a new packet. if counter != self.counter { + self.drop_in_place(); + self.count = fragment_count as u32; self.counter = counter; - if needs_drop::() { - let mut i = 0; - while have != 0 { - if (have & 1) != 0 { - debug_assert!(i < MAX_FRAGMENTS); - unsafe { self.frags.get_unchecked_mut(i).assume_init_drop() }; - } - have = have.wrapping_shr(1); - i += 1; - } - } else { - have = 0; + } + + let got = 1u64.wrapping_shl(fragment_no as u32); + if got & self.have == 0 && self.count as u8 == fragment_count { + self.have |= got; + unsafe { + self.frags.get_unchecked_mut(fragment_no as usize).write(fragment); + } + if self.have == 1u64.wrapping_shl(self.count) - 1 { + self.have = 0; + self.count = 0; + self.counter = 0; + // Setting 'have' to 0 resets the state of this object, and the fragments + // are effectively moved into the Assembled<> container and returned. That + // container will drop them when it is dropped. + return Some(Assembled(unsafe { std::mem::transmute_copy(&self.frags) }, fragment_count as usize)); } - } - - unsafe { - self.frags.get_unchecked_mut(fragment_no as usize).write(fragment); - } - - let want = 0xffffffffffffffffu64.wrapping_shr((64 - fragment_count) as u32); - have |= 1u64.wrapping_shl(fragment_no as u32); - if (have & want) == want { - self.have = 0; - // Setting 'have' to 0 resets the state of this object, and the fragments - // are effectively moved into the Assembled<> container and returned. That - // container will drop them when it is dropped. - return Some(Assembled(unsafe { std::mem::transmute_copy(&self.frags) }, fragment_count as usize)); - } else { - self.have = have; } } return None; } -} -impl Drop for Fragged { + /// Drops any remaining fragments and resets this object. 
#[inline(always)] - fn drop(&mut self) { + pub fn drop_in_place(&mut self) { if needs_drop::() { let mut have = self.have; let mut i = 0; @@ -106,5 +99,15 @@ impl Drop for Fragged Drop for Fragged { + #[inline(always)] + fn drop(&mut self) { + self.drop_in_place(); } } diff --git a/zssp/src/main.rs b/zssp/src/main.rs index 8bf5a9684..0df766689 100644 --- a/zssp/src/main.rs +++ b/zssp/src/main.rs @@ -46,7 +46,7 @@ fn alice_main( alice_out: mpsc::SyncSender>, alice_in: mpsc::Receiver>, ) { - let context = zssp::Context::::new(16, TEST_MTU); + let context = zssp::Context::::new(TEST_MTU); let mut data_buf = [0u8; 65536]; let mut next_service = ms_monotonic() + 500; let mut last_ratchet_count = 0; @@ -157,7 +157,7 @@ fn bob_main( bob_out: mpsc::SyncSender>, bob_in: mpsc::Receiver>, ) { - let context = zssp::Context::::new(16, TEST_MTU); + let context = zssp::Context::::new(TEST_MTU); let mut data_buf = [0u8; 65536]; let mut data_buf_2 = [0u8; TEST_MTU]; let mut last_ratchet_count = 0; diff --git a/zssp/src/proto.rs b/zssp/src/proto.rs index 7952c9d30..06dc1dfaa 100644 --- a/zssp/src/proto.rs +++ b/zssp/src/proto.rs @@ -8,8 +8,9 @@ use std::mem::size_of; +use hex_literal::hex; use pqc_kyber::{KYBER_CIPHERTEXTBYTES, KYBER_PUBLICKEYBYTES}; -use zerotier_crypto::hash::SHA384_HASH_SIZE; +use zerotier_crypto::hash::SHA512_HASH_SIZE; use zerotier_crypto::p384::P384_PUBLIC_KEY_SIZE; use crate::error::Error; @@ -25,11 +26,13 @@ pub const MIN_TRANSPORT_MTU: usize = 128; pub const MAX_INIT_PAYLOAD_SIZE: usize = MAX_NOISE_HANDSHAKE_SIZE - ALICE_NOISE_XK_ACK_MIN_SIZE; /// Initial value of 'h' -/// echo -n 'Noise_XKpsk3_P384_AESGCM_SHA384_hybridKyber1024' | shasum -a 384 -pub(crate) const INITIAL_H: [u8; SHA384_HASH_SIZE] = [ - 0x35, 0x27, 0x16, 0x62, 0x58, 0x04, 0x0c, 0x7a, 0x99, 0xa8, 0x0b, 0x49, 0xb2, 0x6b, 0x25, 0xfb, 0xf5, 0x26, 0x2a, 0x26, 0xe7, 0xb3, 0x70, 0xcb, - 0x2c, 0x3c, 0xcb, 0x7f, 0xca, 0x20, 0x06, 0x91, 0x20, 0x55, 0x52, 0x8e, 0xd4, 0x3c, 0x97, 0xc3, 0xd5, 0x6c, 0xb4, 0x13, 0x02, 0x54, 0x83, 0x12, -]; +/// echo -n 'Noise_XKpsk3_P384_AESGCM_SHA512_hybridKyber1024' | shasum -a 512 +pub(crate) const INITIAL_H: [u8; SHA512_HASH_SIZE] = + hex!("12ae70954e8d93bf7f73d0fe48d487155666f541e532f9461af5ef52ab90c8fd9259ef9e48f5adcf9af63f869805a570004ae095655dcaddbc226a50623b2b25"); +/// Initial value of 'h' +/// echo -n 'Noise_KKpsk0_P384_AESGCM_SHA512' | shasum -a 512 +pub(crate) const INITIAL_H_REKEY: [u8; SHA512_HASH_SIZE] = + hex!("daeedd651ac9c5173f2eaaff996beebac6f3f1bfe9a70bb1cc54fa1fb2bf46260d71a3c4fb4d4ee36f654c31773a8a15e5d5be974a0668dc7db70f4e13ed172e"); /// Version 0: Noise_XK with NIST P-384 plus Kyber1024 hybrid exchange on session init. pub(crate) const SESSION_PROTOCOL_VERSION: u8 = 0x00; @@ -60,12 +63,13 @@ pub(crate) const KBKDF_KEY_USAGE_LABEL_AES_GCM_ALICE_TO_BOB: u8 = b'A'; // AES-G pub(crate) const KBKDF_KEY_USAGE_LABEL_AES_GCM_BOB_TO_ALICE: u8 = b'B'; // AES-GCM in B->A direction pub(crate) const KBKDF_KEY_USAGE_LABEL_RATCHET: u8 = b'R'; // Key used in derivatin of next session key +pub(crate) const MAX_INCOMPLETE_SESSION_QUEUE_SIZE: usize = 32; pub(crate) const MAX_FRAGMENTS: usize = 48; // hard protocol max: 63 pub(crate) const MAX_NOISE_HANDSHAKE_FRAGMENTS: usize = 16; // enough room for p384 + ZT identity + kyber1024 + tag/hmac/etc. pub(crate) const MAX_NOISE_HANDSHAKE_SIZE: usize = MAX_NOISE_HANDSHAKE_FRAGMENTS * MIN_TRANSPORT_MTU; /// Size of keys used during derivation, mixing, etc. process. 
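// Sanity-check sketch for the INITIAL_H constants above: each is the one-shot SHA-512 of the
// protocol name, mirroring the `shasum -a 512` commands quoted in their doc comments. SHA512::hash
// is the same helper this PR already uses for the rekey fingerprint; the exact return type is
// assumed to be [u8; SHA512_HASH_SIZE].
#[cfg(test)]
mod initial_h_check {
    use super::{INITIAL_H, INITIAL_H_REKEY};
    use zerotier_crypto::hash::SHA512;

    #[test]
    fn initial_h_matches_protocol_names() {
        assert_eq!(SHA512::hash(b"Noise_XKpsk3_P384_AESGCM_SHA512_hybridKyber1024"), INITIAL_H);
        assert_eq!(SHA512::hash(b"Noise_KKpsk0_P384_AESGCM_SHA512"), INITIAL_H_REKEY);
    }
}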
-pub(crate) const BASE_KEY_SIZE: usize = 64; +pub(crate) const NOISE_HASHLEN: usize = SHA512_HASH_SIZE; pub(crate) const AES_256_KEY_SIZE: usize = 32; pub(crate) const AES_HEADER_PROTECTION_KEY_SIZE: usize = 16; @@ -160,14 +164,14 @@ pub(crate) struct RekeyAck { pub session_protocol_version: u8, // -- start AES-GCM encrypted portion (using current key) pub bob_e: [u8; P384_PUBLIC_KEY_SIZE], - pub next_key_fingerprint: [u8; SHA384_HASH_SIZE], // SHA384(next secret) + pub next_key_fingerprint: [u8; SHA512_HASH_SIZE], // SHA384(next secret) // -- end AES-GCM encrypted portion pub gcm_tag: [u8; AES_GCM_TAG_SIZE], } impl RekeyAck { pub const ENC_START: usize = HEADER_SIZE + 1; - pub const AUTH_START: usize = Self::ENC_START + P384_PUBLIC_KEY_SIZE + SHA384_HASH_SIZE; + pub const AUTH_START: usize = Self::ENC_START + P384_PUBLIC_KEY_SIZE + SHA512_HASH_SIZE; pub const SIZE: usize = Self::AUTH_START + AES_GCM_TAG_SIZE; } diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 4db210db4..544368e6c 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -9,14 +9,17 @@ // ZSSP: ZeroTier Secure Session Protocol // FIPS compliant Noise_XK with Jedi powers (Kyber1024) and built-in attack-resistant large payload (fragmentation) support. -use std::collections::{HashMap, HashSet}; +use std::collections::hash_map::RandomState; +//use std::collections::hash_map::DefaultHasher; +use std::collections::HashMap; +use std::hash::{BuildHasher, Hash, Hasher}; use std::num::NonZeroU64; -use std::sync::atomic::{AtomicI64, AtomicU64, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak}; use zerotier_crypto::aes::{Aes, AesGcm}; -use zerotier_crypto::hash::{hmac_sha512_secret, SHA384, SHA384_HASH_SIZE}; -use zerotier_crypto::p384::{P384KeyPair, P384PublicKey, P384_ECDH_SHARED_SECRET_SIZE}; +use zerotier_crypto::hash::{hmac_sha512_secret, hmac_sha512_secret256, SHA512}; +use zerotier_crypto::p384::{P384KeyPair, P384PublicKey}; use zerotier_crypto::secret::Secret; use zerotier_crypto::{random, secure_eq}; @@ -36,17 +39,10 @@ const GCM_CIPHER_POOL_SIZE: usize = 4; /// Each application using ZSSP must create an instance of this to own sessions and /// defragment incoming packets that are not yet associated with a session. 
pub struct Context { - max_incomplete_session_queue_size: usize, default_physical_mtu: AtomicUsize, - defrag: Mutex< - HashMap< - (Application::PhysicalPath, u64), - Arc<( - Mutex>, - i64, // creation timestamp - )>, - >, - >, + defrag_salt: RandomState, + defrag_has_pending: AtomicBool, // Allowed to be falsely positive + defrag: [Mutex<(Fragged, i64)>; MAX_INCOMPLETE_SESSION_QUEUE_SIZE], sessions: RwLock>, } @@ -106,8 +102,8 @@ struct IncomingIncompleteSession { timestamp: i64, alice_session_id: SessionId, bob_session_id: SessionId, - noise_h: [u8; SHA384_HASH_SIZE], - noise_es_ee: Secret, + noise_h: [u8; NOISE_HASHLEN], + noise_ck_es_ee: Secret, hk: Secret, header_protection_key: Secret, bob_noise_e_secret: P384KeyPair, @@ -116,9 +112,9 @@ struct IncomingIncompleteSession { struct OutgoingSessionOffer { last_retry_time: AtomicI64, - psk: Secret, - noise_h: [u8; SHA384_HASH_SIZE], - noise_es: Secret, + psk: Secret, + noise_h: [u8; NOISE_HASHLEN], + noise_ck_es: Secret, alice_noise_e_secret: P384KeyPair, alice_hk_secret: Secret, metadata: Option>, @@ -139,7 +135,7 @@ enum Offer { } struct SessionKey { - ratchet_key: Secret, // Key used in derivation of the next session key + ratchet_key: Secret, // Key used in derivation of the next session key receive_cipher_pool: [Mutex>; GCM_CIPHER_POOL_SIZE], // Pool of reusable sending ciphers send_cipher_pool: [Mutex>; GCM_CIPHER_POOL_SIZE], // Pool of reusable receiving ciphers rekey_at_time: i64, // Rekey at or after this time (ticks) @@ -154,11 +150,12 @@ impl Context { /// Create a new session context. /// /// * `max_incomplete_session_queue_size` - Maximum number of incomplete sessions in negotiation phase - pub fn new(max_incomplete_session_queue_size: usize, default_physical_mtu: usize) -> Self { + pub fn new(default_physical_mtu: usize) -> Self { Self { - max_incomplete_session_queue_size, default_physical_mtu: AtomicUsize::new(default_physical_mtu), - defrag: Mutex::new(HashMap::new()), + defrag_salt: RandomState::new(), + defrag_has_pending: AtomicBool::new(false), + defrag: std::array::from_fn(|_| Mutex::new((Fragged::new(), i64::MAX))), sessions: RwLock::new(SessionsById { active: HashMap::with_capacity(64), incoming: HashMap::with_capacity(64), @@ -203,7 +200,7 @@ impl Context { PACKET_TYPE_ALICE_NOISE_XK_INIT, None, 0, - 1, + random::next_u64_secure(), None, ); } @@ -251,6 +248,23 @@ impl Context { } } } + // Only check for expiration if we have a pending packet. + // This check is allowed to have false positives for simplicity's sake. + if self.defrag_has_pending.swap(false, Ordering::Relaxed) { + let mut has_pending = false; + for m in &self.defrag { + let mut pending = m.lock().unwrap(); + if pending.1 <= negotiation_timeout_cutoff { + pending.1 = i64::MAX; + pending.0.drop_in_place(); + } else if pending.0.counter() != 0 { + has_pending = true; + } + } + if has_pending { + self.defrag_has_pending.store(true, Ordering::Relaxed); + } + } if !dead_active.is_empty() || !dead_pending.is_empty() { let mut sessions = self.sessions.write().unwrap(); @@ -262,9 +276,6 @@ impl Context { } } - // Delete any expired defragmentation queue items not associated with a session. 
- self.defrag.lock().unwrap().retain(|_, fragged| fragged.1 > negotiation_timeout_cutoff); - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS.min(Application::RETRY_INTERVAL) } @@ -289,7 +300,7 @@ impl Context { mtu: usize, remote_s_public_blob: &[u8], remote_s_public_p384: P384PublicKey, - psk: Secret, + psk: Secret, metadata: Option>, application_data: Application::Data, current_time: i64, @@ -301,9 +312,14 @@ impl Context { let alice_noise_e_secret = P384KeyPair::generate(); let alice_noise_e = alice_noise_e_secret.public_key_bytes().clone(); let noise_es = alice_noise_e_secret.agree(&remote_s_public_p384).ok_or(Error::InvalidParameter)?; + let noise_h = mix_hash(&mix_hash(&INITIAL_H, remote_s_public_blob), &alice_noise_e); + let noise_ck_es = hmac_sha512_secret(&INITIAL_H, noise_es.as_bytes()); let alice_hk_secret = pqc_kyber::keypair(&mut random::SecureRandom::default()); let header_protection_key: Secret = Secret(random::get_bytes_secure()); + // Init aesgcm before we move noise_ck_es + let mut gcm = AesGcm::new(&kbkdf256::(&noise_ck_es)); + let (local_session_id, session) = { let mut sessions = self.sessions.write().unwrap(); @@ -330,8 +346,8 @@ impl Context { outgoing_offer: Offer::NoiseXKInit(Box::new(OutgoingSessionOffer { last_retry_time: AtomicI64::new(current_time), psk, - noise_h: mix_hash(&mix_hash(&INITIAL_H, remote_s_public_blob), &alice_noise_e), - noise_es: noise_es.clone(), + noise_h, + noise_ck_es, alice_noise_e_secret, alice_hk_secret: Secret(alice_hk_secret.secret), metadata, @@ -366,7 +382,6 @@ impl Context { } // Encrypt and add authentication tag. - let mut gcm = AesGcm::new(&kbkdf::(noise_es.as_bytes())); gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_INIT, 1)); gcm.aad(&offer.noise_h); gcm.crypt_in_place(&mut init_packet[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]); @@ -382,7 +397,7 @@ impl Context { PACKET_TYPE_ALICE_NOISE_XK_INIT, None, 0, - 1, + random::next_u64_secure(), None, )?; } @@ -454,7 +469,7 @@ impl Context { let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); if session.check_receive_window(incoming_counter) { - let (assembled_packet, incoming_packet_buf_arr); + let assembled_packet; let incoming_packet = if fragment_count > 1 { assembled_packet = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO] .lock() @@ -466,8 +481,7 @@ impl Context { return Ok(ReceiveResult::Ok(Some(session))); } } else { - incoming_packet_buf_arr = [incoming_physical_packet_buf]; - &incoming_packet_buf_arr + std::array::from_ref(&incoming_physical_packet_buf) }; return self.process_complete_incoming_packet( @@ -495,7 +509,7 @@ impl Context { .decrypt_block_in_place(&mut incoming_physical_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); - let (assembled_packet, incoming_packet_buf_arr); + let assembled_packet; let incoming_packet = if fragment_count > 1 { assembled_packet = incoming.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO] .lock() @@ -507,8 +521,7 @@ impl Context { return Ok(ReceiveResult::Ok(None)); } } else { - incoming_packet_buf_arr = [incoming_physical_packet_buf]; - &incoming_packet_buf_arr + std::array::from_ref(&incoming_physical_packet_buf) }; return self.process_complete_incoming_packet( @@ -531,48 +544,57 @@ impl Context { } else { let (key_index, packet_type, fragment_count, 
fragment_no, incoming_counter) = parse_packet_header(&incoming_physical_packet); - let (assembled_packet, incoming_packet_buf_arr); + let assembled; let incoming_packet = if fragment_count > 1 { - assembled_packet = { - let mut defrag = self.defrag.lock().unwrap(); - let f = defrag - .entry((source.clone(), incoming_counter)) - .or_insert_with(|| Arc::new((Mutex::new(Fragged::new()), current_time))) - .clone(); + // incoming_counter is expected to be a random u64 generated by the remote peer. + // Using just incoming_counter to defragment would be good DOS resistance, + // but why not make it harder by hasing it with a random salt and the physical path as well. + let mut hasher = self.defrag_salt.build_hasher(); + source.hash(&mut hasher); + hasher.write_u64(incoming_counter); + let hashed_counter = hasher.finish(); + let idx0 = (hashed_counter as usize) % MAX_INCOMPLETE_SESSION_QUEUE_SIZE; + let idx1 = (hashed_counter as usize) / MAX_INCOMPLETE_SESSION_QUEUE_SIZE % MAX_INCOMPLETE_SESSION_QUEUE_SIZE; - // Anti-DOS overflow purge of the incoming defragmentation queue for packets not associated with known sessions. - if defrag.len() >= self.max_incomplete_session_queue_size { - // First, drop all entries that are timed out or whose physical source duplicates another entry. - let mut sources = HashSet::with_capacity(defrag.len()); - let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS; - defrag - .retain(|k, fragged| (fragged.1 > negotiation_timeout_cutoff && sources.insert(k.0.clone())) || Arc::ptr_eq(fragged, &f)); - - // Then, if we are still at or over the limit, drop 10% of remaining entries at random. - if defrag.len() >= self.max_incomplete_session_queue_size { - let mut rn = random::next_u32_secure(); - defrag.retain(|_, fragged| { - rn = prng32(rn); - rn > (u32::MAX / 10) || Arc::ptr_eq(fragged, &f) - }); - } + // Open hash lookup of just 2 slots. + // By only checking 2 slots we avoid a full table lookup while also minimizing the chance that 2 offers collide. + // To DOS, an adversary would either need to volumetrically spam the defrag table to keep all slots full + // or replay Alice's packet header from a spoofed physical path before Alice's packet is fully assembled. + // Volumetric spam is quite difficult since without the `defrag_salt` value an adversary cannot + // control which slots their fragments index to. And since Alice's packet header has a randomly + // generated counter value replaying it in time requires extreme amounts of network control. + let (slot0, timestamp0) = &mut *self.defrag[idx0].lock().unwrap(); + if slot0.counter() == hashed_counter { + assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + if assembled.is_some() { + *timestamp0 = i64::MAX; + } + } else { + let (slot1, timestamp1) = &mut *self.defrag[idx1].lock().unwrap(); + if slot1.counter() == hashed_counter { + assembled = slot1.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + if assembled.is_some() { + *timestamp1 = i64::MAX; + } + } else if slot0.counter() == 0 { + *timestamp0 = current_time; + self.defrag_has_pending.store(true, Ordering::Relaxed); + assembled = slot0.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count); + } else { + // slot1 is either occupied or empty so we overwrite whatever is there to make more room. 
+ *timestamp1 = current_time; + self.defrag_has_pending.store(true, Ordering::Relaxed); + assembled = slot1.assemble(hashed_counter, incoming_physical_packet_buf, fragment_no, fragment_count); } - - f } - .0 - .lock() - .unwrap() - .assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count); - if let Some(assembled_packet) = assembled_packet.as_ref() { - self.defrag.lock().unwrap().remove(&(source.clone(), incoming_counter)); + + if let Some(assembled_packet) = &assembled { assembled_packet.as_ref() } else { return Ok(ReceiveResult::Ok(None)); } } else { - incoming_packet_buf_arr = [incoming_physical_packet_buf]; - &incoming_packet_buf_arr + std::array::from_ref(&incoming_physical_packet_buf) }; return self.process_complete_incoming_packet( @@ -581,7 +603,7 @@ impl Context { &mut check_allow_incoming_session, &mut check_accept_session, data_buf, - incoming_counter, + 1, // The incoming_counter on init packets is only meant for DOS resistant defragmentation, we do not want to use it for anything noise related. incoming_packet, packet_type, None, @@ -720,7 +742,7 @@ impl Context { * identity, then responds with his ephemeral keys. */ - if incoming_counter != 1 || session.is_some() || incoming.is_some() { + if session.is_some() || incoming.is_some() { return Err(Error::OutOfSequence); } if pkt_assembled.len() != AliceNoiseXKInit::SIZE { @@ -735,9 +757,11 @@ impl Context { let noise_h = mix_hash(&mix_hash(&INITIAL_H, app.get_local_s_public_blob()), alice_noise_e.as_bytes()); let noise_h_next = mix_hash(&noise_h, &pkt_assembled[HEADER_SIZE..]); + let noise_ck_es = hmac_sha512_secret(&INITIAL_H, noise_es.as_bytes()); + drop(noise_es); // Decrypt and authenticate init packet, also proving that caller knows our static identity. - let mut gcm = AesGcm::new(&kbkdf::(noise_es.as_bytes())); + let mut gcm = AesGcm::new(&kbkdf256::(&noise_ck_es)); gcm.reset_init_gcm(&incoming_message_nonce); gcm.aad(&noise_h); gcm.crypt_in_place(&mut pkt_assembled[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]); @@ -754,12 +778,12 @@ impl Context { let alice_session_id = SessionId::new_from_array(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?; let header_protection_key = Secret(pkt.header_protection_key); - // Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create + // Create Bob's ephemeral keys and derive noise_ck by agreeing with Alice's. Also create // a Kyber ciphertext to send back to Alice. let bob_noise_e_secret = P384KeyPair::generate(); let bob_noise_e = bob_noise_e_secret.public_key_bytes().clone(); - let noise_es_ee = hmac_sha512_secret( - noise_es.as_bytes(), + let noise_ck_es_ee = hmac_sha512_secret( + noise_ck_es.as_bytes(), bob_noise_e_secret.agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?.as_bytes(), ); let (bob_hk_ciphertext, hk) = pqc_kyber::encapsulate(&pkt.alice_hk_public, &mut random::SecureRandom::default()) @@ -786,7 +810,7 @@ impl Context { ack.bob_hk_ciphertext = bob_hk_ciphertext; // Encrypt main section of reply and attach tag. - let mut gcm = AesGcm::new(&kbkdf::(noise_es_ee.as_bytes())); + let mut gcm = AesGcm::new(&kbkdf256::(&noise_ck_es_ee)); gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_BOB_NOISE_XK_ACK, 1)); gcm.aad(&noise_h_next); gcm.crypt_in_place(&mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]); @@ -795,7 +819,7 @@ impl Context { // If this queue is too big, we remove the latest entry and replace it. 
                 // is used because under flood conditions this is most likely to be another bogus
                 // entry. If we find one that is actually timed out, that one is replaced instead.
-                if sessions.incoming.len() >= self.max_incomplete_session_queue_size {
+                if sessions.incoming.len() >= MAX_INCOMPLETE_SESSION_QUEUE_SIZE {
                     let mut newest = i64::MIN;
                     let mut replace_id = None;
                     let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
@@ -819,7 +843,7 @@ impl Context {
                     alice_session_id,
                     bob_session_id,
                     noise_h: mix_hash(&mix_hash(&noise_h_next, &bob_noise_e), &ack_packet[HEADER_SIZE..]),
-                    noise_es_ee: noise_es_ee.clone(),
+                    noise_ck_es_ee,
                     hk,
                     bob_noise_e_secret,
                     header_protection_key: Secret(pkt.header_protection_key),
@@ -874,8 +898,8 @@ impl Context {
                 // Derive noise_es_ee from Bob's ephemeral public key.
                 let bob_noise_e = P384PublicKey::from_bytes(&pkt.bob_noise_e).ok_or(Error::FailedAuthentication)?;
-                let noise_es_ee = hmac_sha512_secret::(
-                    outgoing_offer.noise_es.as_bytes(),
+                let noise_ck_es_ee = hmac_sha512_secret(
+                    outgoing_offer.noise_ck_es.as_bytes(),
                     outgoing_offer
                         .alice_noise_e_secret
                         .agree(&bob_noise_e)
@@ -887,7 +911,7 @@ impl Context {
                 let noise_h_next = mix_hash(&mix_hash(&outgoing_offer.noise_h, bob_noise_e.as_bytes()), &pkt_assembled[HEADER_SIZE..]);
                 // Decrypt and authenticate Bob's reply.
-                let mut gcm = AesGcm::new(&kbkdf::(noise_es_ee.as_bytes()));
+                let mut gcm = AesGcm::new(&kbkdf256::(&noise_ck_es_ee));
                 gcm.reset_init_gcm(&incoming_message_nonce);
                 gcm.aad(&outgoing_offer.noise_h);
                 gcm.crypt_in_place(&mut pkt_assembled[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]);
@@ -909,9 +933,9 @@ impl Context {
                 // Packet fully authenticated
                 if session.update_receive_window(incoming_counter) {
-                    let noise_es_ee_se_hk_psk = hmac_sha512_secret::(
-                        hmac_sha512_secret::(noise_es_ee.as_bytes(), noise_se.as_bytes()).as_bytes(),
-                        hmac_sha512_secret::(outgoing_offer.psk.as_bytes(), hk.as_bytes()).as_bytes(),
+                    let noise_ck_es_ee_se_hk_psk = hmac_sha512_secret(
+                        hmac_sha512_secret(noise_ck_es_ee.as_bytes(), noise_se.as_bytes()).as_bytes(),
+                        hmac_sha512_secret(outgoing_offer.psk.as_bytes(), hk.as_bytes()).as_bytes(),
                     );
                     let reply_message_nonce = create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_ACK, 2);
@@ -928,9 +952,10 @@ impl Context {
                     let mut enc_start = ack_len;
                     ack_len = append_to_slice(&mut ack, ack_len, alice_s_public_blob)?;
-                    let mut gcm = AesGcm::new(&kbkdf::(
-                        hmac_sha512_secret::(noise_es_ee.as_bytes(), hk.as_bytes()).as_bytes(),
-                    ));
+                    let mut gcm = AesGcm::new(&kbkdf256::(&hmac_sha512_secret(
+                        noise_ck_es_ee.as_bytes(),
+                        hk.as_bytes(),
+                    )));
                     gcm.reset_init_gcm(&reply_message_nonce);
                     gcm.aad(&noise_h_next);
                     gcm.crypt_in_place(&mut ack[enc_start..ack_len]);
@@ -946,9 +971,7 @@ impl Context {
                     enc_start = ack_len;
                     ack_len = append_to_slice(&mut ack, ack_len, metadata)?;
-                    let mut gcm = AesGcm::new(&kbkdf::(
-                        noise_es_ee_se_hk_psk.as_bytes(),
-                    ));
+                    let mut gcm = AesGcm::new(&kbkdf256::(&noise_ck_es_ee_se_hk_psk));
                     gcm.reset_init_gcm(&reply_message_nonce);
                     gcm.aad(&noise_h_next);
                     gcm.crypt_in_place(&mut ack[enc_start..ack_len]);
@@ -961,7 +984,7 @@ impl Context {
                     let mut state = session.state.write().unwrap();
                     let _ = state.remote_session_id.insert(bob_session_id);
                     let _ = state.keys[0].insert(SessionKey::new::(
-                        noise_es_ee_se_hk_psk,
+                        noise_ck_es_ee_se_hk_psk,
                         1,
                         current_time,
                         2,
@@ -1034,9 +1057,10 @@ impl Context {
                 let alice_static_public_blob = r.read_decrypt_auth(
                     alice_static_public_blob_size,
-                    kbkdf::(
-                        hmac_sha512_secret::(incoming.noise_es_ee.as_bytes(), incoming.hk.as_bytes()).as_bytes(),
-                    ),
+                    kbkdf256::(&hmac_sha512_secret(
+                        incoming.noise_ck_es_ee.as_bytes(),
+                        incoming.hk.as_bytes(),
+                    )),
                     &incoming.noise_h,
                     &incoming_message_nonce,
                 )?;
@@ -1052,9 +1076,9 @@ impl Context {
                 let noise_h_next = mix_hash(&noise_h_next, psk.as_bytes());
                 // Complete Noise_XKpsk3 on Bob's side.
-                let noise_es_ee_se_hk_psk = hmac_sha512_secret::(
-                    hmac_sha512_secret::(
-                        incoming.noise_es_ee.as_bytes(),
+                let noise_ck_es_ee_se_hk_psk = hmac_sha512_secret(
+                    hmac_sha512_secret(
+                        incoming.noise_ck_es_ee.as_bytes(),
                         incoming
                             .bob_noise_e_secret
                             .agree(&alice_noise_s)
@@ -1062,7 +1086,7 @@ impl Context {
                         .as_bytes(),
                     )
                     .as_bytes(),
-                    hmac_sha512_secret::(psk.as_bytes(), incoming.hk.as_bytes()).as_bytes(),
+                    hmac_sha512_secret(psk.as_bytes(), incoming.hk.as_bytes()).as_bytes(),
                 );
                 // Decrypt meta-data and verify the final key in the process. Copy meta-data
@@ -1070,7 +1094,7 @@ impl Context {
                 let alice_meta_data_size = r.read_u16()? as usize;
                 let alice_meta_data = r.read_decrypt_auth(
                     alice_meta_data_size,
-                    kbkdf::(noise_es_ee_se_hk_psk.as_bytes()),
+                    kbkdf256::(&noise_ck_es_ee_se_hk_psk),
                     &noise_h_next,
                     &incoming_message_nonce,
                 )?;
@@ -1090,7 +1114,7 @@ impl Context {
                     physical_mtu: self.default_physical_mtu.load(Ordering::Relaxed),
                     remote_session_id: Some(incoming.alice_session_id),
                     keys: [
-                        Some(SessionKey::new::(noise_es_ee_se_hk_psk, 1, current_time, 2, true, true)),
+                        Some(SessionKey::new::(noise_ck_es_ee_se_hk_psk, 1, current_time, 2, true, true)),
                         None,
                     ],
                     current_key: 0,
@@ -1150,9 +1174,13 @@ impl Context {
                 let noise_es = app.get_local_s_keypair().agree(&alice_e).ok_or(Error::FailedAuthentication)?;
                 let noise_ee = bob_e_secret.agree(&alice_e).ok_or(Error::FailedAuthentication)?;
                 let noise_se = bob_e_secret.agree(&session.static_public_key).ok_or(Error::FailedAuthentication)?;
-                let noise_psk_se_ee_es = hmac_sha512_secret::(
-                    hmac_sha512_secret::(
-                        hmac_sha512_secret::(key.ratchet_key.as_bytes(), noise_es.as_bytes()).as_bytes(),
+                let noise_ck_psk_es_ee_se = hmac_sha512_secret(
+                    hmac_sha512_secret(
+                        hmac_sha512_secret(
+                            hmac_sha512_secret(&INITIAL_H_REKEY, key.ratchet_key.as_bytes()).as_bytes(),
+                            noise_es.as_bytes(),
+                        )
+                        .as_bytes(),
                         noise_ee.as_bytes(),
                     )
                     .as_bytes(),
@@ -1165,7 +1193,7 @@ impl Context {
                 let reply: &mut RekeyAck = byte_array_as_proto_buffer_mut(&mut reply_buf).unwrap();
                 reply.session_protocol_version = SESSION_PROTOCOL_VERSION;
                 reply.bob_e = *bob_e_secret.public_key_bytes();
-                reply.next_key_fingerprint = SHA384::hash(noise_psk_se_ee_es.as_bytes());
+                reply.next_key_fingerprint = SHA512::hash(noise_ck_psk_es_ee_se.as_bytes());
                 let counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
                 set_packet_header(
@@ -1197,7 +1225,7 @@ impl Context {
                 drop(state);
                 let mut state = session.state.write().unwrap();
                 let _ = state.keys[key_index ^ 1].replace(SessionKey::new::(
-                    noise_psk_se_ee_es,
+                    noise_ck_psk_es_ee_se,
                     next_ratchet_count,
                     current_time,
                     counter,
@@ -1248,9 +1276,13 @@ impl Context {
                 let noise_es = alice_e_secret.agree(&session.static_public_key).ok_or(Error::FailedAuthentication)?;
                 let noise_ee = alice_e_secret.agree(&bob_e).ok_or(Error::FailedAuthentication)?;
                 let noise_se = app.get_local_s_keypair().agree(&bob_e).ok_or(Error::FailedAuthentication)?;
-                let noise_psk_se_ee_es = hmac_sha512_secret::(
-                    hmac_sha512_secret::(
-                        hmac_sha512_secret::(key.ratchet_key.as_bytes(), noise_es.as_bytes()).as_bytes(),
+                let noise_ck_psk_es_ee_se = hmac_sha512_secret(
+                    hmac_sha512_secret(
+                        hmac_sha512_secret(
+                            hmac_sha512_secret(&INITIAL_H_REKEY, key.ratchet_key.as_bytes()).as_bytes(),
+                            noise_es.as_bytes(),
+                        )
+                        .as_bytes(),
                         noise_ee.as_bytes(),
                     )
                     .as_bytes(),
@@ -1259,7 +1291,7 @@ impl Context {
                 // We need to check that the key Bob is acknowledging matches the latest sent offer.
                 // Because of OOO, it might not, in which case this rekey must be cancelled and retried.
-                if secure_eq(&pkt.next_key_fingerprint, &SHA384::hash(noise_psk_se_ee_es.as_bytes())) {
+                if secure_eq(&pkt.next_key_fingerprint, &SHA512::hash(noise_ck_psk_es_ee_se.as_bytes())) {
                     if session.update_receive_window(incoming_counter) {
                         // The new "Alice" knows Bob has the key since this is an ACK, so she can go
                         // ahead and set current_key to the new key. Then when she sends something
@@ -1269,7 +1301,7 @@ impl Context {
                         let next_key_index = key_index ^ 1;
                         let mut state = session.state.write().unwrap();
                         let _ = state.keys[next_key_index].replace(SessionKey::new::(
-                            noise_psk_se_ee_es,
+                            noise_ck_psk_es_ee_se,
                             next_ratchet_count,
                             current_time,
                             session.send_counter.load(Ordering::Relaxed),
@@ -1390,10 +1422,10 @@ impl Session {
     }
     /// Get the ratchet count and a hash fingerprint of the current active key.
-    pub fn key_info(&self) -> Option<(u64, [u8; 48])> {
+    pub fn key_info(&self) -> Option<(u64, [u8; NOISE_HASHLEN])> {
         let state = self.state.read().unwrap();
         if let Some(key) = state.keys[state.current_key].as_ref() {
-            Some((key.ratchet_count, SHA384::hash(key.ratchet_key.as_bytes())))
+            Some((key.ratchet_count, SHA512::hash(key.ratchet_key.as_bytes())))
         } else {
            None
         }
@@ -1582,15 +1614,15 @@ fn assemble_fragments_into(fragments: &[A::IncomingPacketBu
 impl SessionKey {
     fn new(
-        key: Secret,
+        key: Secret,
         ratchet_count: u64,
         current_time: i64,
         current_counter: u64,
         bob: bool,
         confirmed: bool,
     ) -> Self {
-        let a2b = kbkdf::(key.as_bytes());
-        let b2a = kbkdf::(key.as_bytes());
+        let a2b = kbkdf256::(&key);
+        let b2a = kbkdf256::(&key);
         let (receive_key, send_key) = if bob {
             (a2b, b2a)
         } else {
@@ -1599,7 +1631,7 @@ impl SessionKey {
         let receive_cipher_pool = std::array::from_fn(|_| Mutex::new(AesGcm::new(&receive_key)));
         let send_cipher_pool = std::array::from_fn(|_| Mutex::new(AesGcm::new(&send_key)));
         Self {
-            ratchet_key: kbkdf::(key.as_bytes()),
+            ratchet_key: kbkdf512::(&key),
             receive_cipher_pool,
             send_cipher_pool,
             rekey_at_time: current_time
@@ -1679,8 +1711,8 @@ fn append_to_slice(s: &mut [u8], p: usize, d: &[u8]) -> Result {
 }
 /// MixHash to update 'h' during negotiation.
-fn mix_hash(h: &[u8; SHA384_HASH_SIZE], m: &[u8]) -> [u8; SHA384_HASH_SIZE] {
-    let mut hasher = SHA384::new();
+fn mix_hash(h: &[u8; NOISE_HASHLEN], m: &[u8]) -> [u8; NOISE_HASHLEN] {
+    let mut hasher = SHA512::new();
     hasher.update(h);
     hasher.update(m);
     hasher.finish()
@@ -1688,31 +1720,11 @@ fn mix_hash(h: &[u8; SHA384_HASH_SIZE], m: &[u8]) -> [u8; SHA384_HASH_SIZE] {
 /// HMAC-SHA512 key derivation based on: https://csrc.nist.gov/publications/detail/sp/800-108/final (page 7)
 /// Cryptographically this isn't meaningfully different from HMAC(key, [label]) but this is how NIST rolls.
-fn kbkdf(key: &[u8]) -> Secret {
-    //These are the values we have assigned to the 5 variables involved in https://csrc.nist.gov/publications/detail/sp/800-108/final:
-    // K_in = key, i = 0x01, Label = 'Z'||'T'||LABEL, Context = 0x00, L = (OUTPUT_BYTES * 8)
-    hmac_sha512_secret(
-        key,
-        &[
-            1,
-            b'Z',
-            b'T',
-            LABEL,
-            0x00,
-            0,
-            (((OUTPUT_BYTES * 8) >> 8) & 0xff) as u8,
-            ((OUTPUT_BYTES * 8) & 0xff) as u8,
-        ],
-    )
+/// These are the values we have assigned to the 5 variables involved in their KDF:
+/// K_in = key, i = 1u8, Label = b'Z'||b'T'||LABEL, Context = 0u8, L = 512u16 or 256u16
+fn kbkdf512(key: &Secret) -> Secret {
+    hmac_sha512_secret(key.as_bytes(), &[1, b'Z', b'T', LABEL, 0x00, 0, 2u8, 0u8])
 }
-
-fn prng32(mut x: u32) -> u32 {
-    // based on lowbias32 from https://nullprogram.com/blog/2018/07/31/
-    x = x.wrapping_add(1); // don't get stuck on 0
-    x ^= x.wrapping_shr(16);
-    x = x.wrapping_mul(0x7feb352d);
-    x ^= x.wrapping_shr(15);
-    x = x.wrapping_mul(0x846ca68b);
-    x ^= x.wrapping_shr(16);
-    x
+fn kbkdf256(key: &Secret) -> Secret<32> {
+    hmac_sha512_secret256(key.as_bytes(), &[1, b'Z', b'T', LABEL, 0x00, 0, 1u8, 0u8])
 }
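
For reference, a minimal standalone Rust sketch (not part of the patch above) of how the fixed 8-byte input block used by the new kbkdf512/kbkdf256 helpers encodes the NIST SP 800-108 fields. The helper name kbkdf_info and the example label byte are made up for illustration; only the byte layout mirrors the constants above, and the HMAC-SHA512 call itself is omitted.

    // Sketch only: builds the SP 800-108 info block [i, 'Z', 'T', LABEL, Context, 0, L_hi, L_lo]
    // that kbkdf512/kbkdf256 pass to HMAC-SHA512. `label` is a stand-in for the crate's LABEL constant.
    fn kbkdf_info(label: u8, output_bits: u16) -> [u8; 8] {
        let l = output_bits.to_be_bytes(); // L is a big-endian bit count
        [1, b'Z', b'T', label, 0x00, 0, l[0], l[1]]
    }

    fn main() {
        // 512-bit output (kbkdf512): L = 0x0200, so the block ends in [2, 0].
        assert_eq!(kbkdf_info(b'X', 512)[6..], [2u8, 0u8]);
        // 256-bit output (kbkdf256): L = 0x0100, so the block ends in [1, 0].
        assert_eq!(kbkdf_info(b'X', 256)[6..], [1u8, 0u8]);
    }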
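Similarly, a minimal sketch (also not part of the patch) of the salted two-slot index derivation used by the defragmentation change earlier in this diff. std's RandomState stands in for the context's defrag_salt, a string stands in for the physical source address, and the queue size is an assumed value; the Fragged slots, timestamps, and locking are omitted.

    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};

    const MAX_INCOMPLETE_SESSION_QUEUE_SIZE: usize = 32; // assumed size, for illustration only

    // Maps (physical source, random packet counter) to two candidate defrag slots.
    // Without knowledge of the salt an attacker cannot choose which slots a spoofed fragment lands in.
    fn defrag_slots(salt: &RandomState, source: &str, incoming_counter: u64) -> (usize, usize) {
        let mut hasher = salt.build_hasher();
        source.hash(&mut hasher);
        hasher.write_u64(incoming_counter);
        let hashed_counter = hasher.finish();
        let idx0 = (hashed_counter as usize) % MAX_INCOMPLETE_SESSION_QUEUE_SIZE;
        let idx1 = (hashed_counter as usize) / MAX_INCOMPLETE_SESSION_QUEUE_SIZE % MAX_INCOMPLETE_SESSION_QUEUE_SIZE;
        (idx0, idx1)
    }

    fn main() {
        let salt = RandomState::new(); // random per-context salt, like defrag_salt
        let (idx0, idx1) = defrag_slots(&salt, "198.51.100.7:9993", 0x243f6a8885a308d3);
        println!("fragments for this packet are tracked in slots {idx0} and {idx1}");
    }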