diff --git a/crypto/src/aes_fruity.rs b/crypto/src/aes_fruity.rs index 01b505b49..5a0d5355c 100644 --- a/crypto/src/aes_fruity.rs +++ b/crypto/src/aes_fruity.rs @@ -3,7 +3,7 @@ // MacOS implementation of AES primitives since CommonCrypto seems to be faster than OpenSSL, especially on ARM64. use std::os::raw::{c_int, c_void}; use std::ptr::{null, null_mut}; -use std::sync::Mutex; +use std::sync::atomic::AtomicPtr; use crate::secret::Secret; use crate::secure_eq; @@ -172,14 +172,26 @@ impl AesGcm { } } -pub struct Aes(Mutex<*mut c_void>, Mutex<*mut c_void>); +pub struct Aes(AtomicPtr, AtomicPtr); impl Drop for Aes { #[inline(always)] fn drop(&mut self) { unsafe { - CCCryptorRelease(*self.0.lock().unwrap()); - CCCryptorRelease(*self.1.lock().unwrap()); + loop { + let p = self.0.load(std::sync::atomic::Ordering::Acquire); + if !p.is_null() { + CCCryptorRelease(p); + break; + } + } + loop { + let p = self.1.load(std::sync::atomic::Ordering::Acquire); + if !p.is_null() { + CCCryptorRelease(p); + break; + } + } } } } @@ -191,7 +203,7 @@ impl Aes { KEY_SIZE == 32 || KEY_SIZE == 24 || KEY_SIZE == 16, "AES supports 128, 192, or 256 bits keys" ); - let aes: Self = std::mem::zeroed(); + let (mut p0, mut p1) = (null_mut(), null_mut()); assert_eq!( CCCryptorCreateWithMode( kCCEncrypt, @@ -205,7 +217,7 @@ impl Aes { 0, 0, kCCOptionECBMode, - &mut *aes.0.lock().unwrap() + &mut p0, ), 0 ); @@ -222,11 +234,11 @@ impl Aes { 0, 0, kCCOptionECBMode, - &mut *aes.1.lock().unwrap() + &mut p1, ), 0 ); - aes + Self(AtomicPtr::new(p0), AtomicPtr::new(p1)) } } @@ -235,8 +247,16 @@ impl Aes { assert_eq!(data.len(), 16); unsafe { let mut data_out_written = 0; - let e = self.0.lock().unwrap(); - CCCryptorUpdate(*e, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + loop { + let p = self.0.load(std::sync::atomic::Ordering::Acquire); + if !p.is_null() { + CCCryptorUpdate(p, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + self.0.store(p, std::sync::atomic::Ordering::Release); + break; + } else { + std::thread::yield_now(); + } + } } } @@ -245,8 +265,16 @@ impl Aes { assert_eq!(data.len(), 16); unsafe { let mut data_out_written = 0; - let d = self.1.lock().unwrap(); - CCCryptorUpdate(*d, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + loop { + let p = self.1.load(std::sync::atomic::Ordering::Acquire); + if !p.is_null() { + CCCryptorUpdate(p, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + self.1.store(p, std::sync::atomic::Ordering::Release); + break; + } else { + std::thread::yield_now(); + } + } } } } diff --git a/crypto/src/aes_openssl.rs b/crypto/src/aes_openssl.rs index 98cb433e2..f0b4fe552 100644 --- a/crypto/src/aes_openssl.rs +++ b/crypto/src/aes_openssl.rs @@ -1,16 +1,16 @@ // (c) 2020-2022 ZeroTier, Inc. -- currently proprietary pending actual release and licensing. See LICENSE.md. -use std::{ptr, mem::MaybeUninit}; +use std::{mem::MaybeUninit, ptr}; use foreign_types::ForeignType; -use crate::{secret::Secret, cipher_ctx::CipherCtx}; +use crate::{cipher_ctx::CipherCtx, secret::Secret}; /// An OpenSSL AES_GCM context. Automatically frees itself on drop. /// The current interface is custom made for ZeroTier, but could easily be adapted for other uses. /// Whether `ENCRYPT` is true or false decides respectively whether this context encrypts or decrypts. /// Even though OpenSSL lets you set this dynamically almost no operations work when you do this without resetting the context. -pub struct AesGcm (CipherCtx); +pub struct AesGcm(CipherCtx); impl AesGcm { /// Create an AesGcm context with the given key, key must be 16, 24 or 32 bytes long. @@ -22,7 +22,7 @@ impl AesGcm { 16 => ffi::EVP_aes_128_gcm(), 24 => ffi::EVP_aes_192_gcm(), 32 => ffi::EVP_aes_256_gcm(), - _ => panic!("Aes KEY_SIZE must be 16, 24 or 32") + _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), }; ctx.cipher_init::(t, key.as_ptr(), ptr::null()).unwrap(); ffi::EVP_CIPHER_CTX_set_padding(ctx.as_ptr(), 0); @@ -103,7 +103,7 @@ impl Aes { 16 => ffi::EVP_aes_128_ecb(), 24 => ffi::EVP_aes_192_ecb(), 32 => ffi::EVP_aes_256_ecb(), - _ => panic!("Aes KEY_SIZE must be 16, 24 or 32") + _ => panic!("Aes KEY_SIZE must be 16, 24 or 32"), }; ctx0.cipher_init::(t, key.as_ptr(), ptr::null()).unwrap(); ffi::EVP_CIPHER_CTX_set_padding(ctx0.as_ptr(), 0); @@ -117,14 +117,23 @@ impl Aes { /// Do not ever encrypt the same plaintext twice. Make sure data is always different between calls. #[inline(always)] pub fn encrypt_block_in_place(&self, data: &mut [u8]) { - debug_assert_eq!(data.len(), AES_BLOCK_SIZE, "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing."); + debug_assert_eq!( + data.len(), + AES_BLOCK_SIZE, + "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing." + ); let ptr = data.as_mut_ptr(); unsafe { self.0.update::(data, ptr).unwrap() } } + /// Do not ever encrypt the same plaintext twice. Make sure data is always different between calls. #[inline(always)] pub fn decrypt_block_in_place(&self, data: &mut [u8]) { - debug_assert_eq!(data.len(), AES_BLOCK_SIZE, "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing."); + debug_assert_eq!( + data.len(), + AES_BLOCK_SIZE, + "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing." + ); let ptr = data.as_mut_ptr(); unsafe { self.1.update::(data, ptr).unwrap() } } diff --git a/crypto/src/cipher_ctx.rs b/crypto/src/cipher_ctx.rs index e46eb3353..f5cba9441 100644 --- a/crypto/src/cipher_ctx.rs +++ b/crypto/src/cipher_ctx.rs @@ -1,8 +1,7 @@ - use std::ptr; -use crate::error::{ErrorStack, cvt_p, cvt}; -use foreign_types::{ForeignType, foreign_type, ForeignTypeRef}; +use crate::error::{cvt, cvt_p, ErrorStack}; +use foreign_types::{foreign_type, ForeignType, ForeignTypeRef}; use libc::c_int; foreign_type! { @@ -22,23 +21,19 @@ impl CipherCtx { } } impl CipherCtxRef { - /// Initializes the context for encryption or decryption. /// All pointer fields can be null, in which case the corresponding field in the context is not updated. - pub unsafe fn cipher_init(&self, t: *const ffi::EVP_CIPHER, key: *const u8, iv: *const u8) -> Result<(), ErrorStack>{ - let evp_f = if ENCRYPT { ffi::EVP_EncryptInit_ex } else { ffi::EVP_DecryptInit_ex }; + pub unsafe fn cipher_init(&self, t: *const ffi::EVP_CIPHER, key: *const u8, iv: *const u8) -> Result<(), ErrorStack> { + let evp_f = if ENCRYPT { + ffi::EVP_EncryptInit_ex + } else { + ffi::EVP_DecryptInit_ex + }; - cvt(evp_f( - self.as_ptr(), - t, - ptr::null_mut(), - key, - iv, - ))?; + cvt(evp_f(self.as_ptr(), t, ptr::null_mut(), key, iv))?; Ok(()) } - /// Writes data into the context. /// /// Providing no output buffer will cause the input to be considered additional authenticated data (AAD). @@ -54,27 +49,20 @@ impl CipherCtxRef { /// ciphers the output buffer size should be at least as big as /// the input buffer. For block ciphers the size of the output /// buffer depends on the state of partially updated blocks. - pub unsafe fn update( - &self, - input: &[u8], - output: *mut u8, - ) -> Result<(), ErrorStack> { - let evp_f = if ENCRYPT { ffi::EVP_EncryptUpdate } else { ffi::EVP_DecryptUpdate }; + pub unsafe fn update(&self, input: &[u8], output: *mut u8) -> Result<(), ErrorStack> { + let evp_f = if ENCRYPT { + ffi::EVP_EncryptUpdate + } else { + ffi::EVP_DecryptUpdate + }; let mut outlen = 0; - cvt(evp_f( - self.as_ptr(), - output, - &mut outlen, - input.as_ptr(), - input.len() as c_int, - ))?; + cvt(evp_f(self.as_ptr(), output, &mut outlen, input.as_ptr(), input.len() as c_int))?; Ok(()) } - /// Finalizes the encryption or decryption process. /// /// Any remaining data will be written to the output buffer. @@ -88,18 +76,15 @@ impl CipherCtxRef { /// large enough to contain correct number of bytes. For streaming /// ciphers the output buffer can be empty, for block ciphers the /// output buffer should be at least as big as the block. - pub unsafe fn finalize( - &self, - output: *mut u8, - ) -> Result<(), ErrorStack> { - let evp_f = if ENCRYPT { ffi::EVP_EncryptFinal_ex } else { ffi::EVP_DecryptFinal_ex }; + pub unsafe fn finalize(&self, output: *mut u8) -> Result<(), ErrorStack> { + let evp_f = if ENCRYPT { + ffi::EVP_EncryptFinal_ex + } else { + ffi::EVP_DecryptFinal_ex + }; let mut outl = 0; - cvt(evp_f( - self.as_ptr(), - output, - &mut outl, - ))?; + cvt(evp_f(self.as_ptr(), output, &mut outl))?; Ok(()) } @@ -111,7 +96,6 @@ impl CipherCtxRef { /// The size of the buffer indicates the size of the tag. While some ciphers support a range of tag sizes, it is /// recommended to pick the maximum size. pub fn tag(&self, tag: &mut [u8]) -> Result<(), ErrorStack> { - unsafe { cvt(ffi::EVP_CIPHER_CTX_ctrl( self.as_ptr(), @@ -125,6 +109,7 @@ impl CipherCtxRef { } /// Sets the authentication tag for verification during decryption. + #[allow(unused)] pub fn set_tag(&self, tag: &[u8]) -> Result<(), ErrorStack> { unsafe { cvt(ffi::EVP_CIPHER_CTX_ctrl( @@ -141,8 +126,8 @@ impl CipherCtxRef { #[cfg(test)] mod test { - use crate::init; use super::*; + use crate::init; #[test] fn aes_128_ecb() { diff --git a/crypto/src/lib.rs b/crypto/src/lib.rs index 8bfdc0569..56d47b4be 100644 --- a/crypto/src/lib.rs +++ b/crypto/src/lib.rs @@ -1,24 +1,26 @@ - -mod error; -mod cipher_ctx; mod bn; +mod cipher_ctx; mod ec; +mod error; -pub mod secret; -pub mod random; pub mod hash; pub mod mimcvdf; pub mod p384; +pub mod random; +pub mod secret; pub mod poly1305; pub mod salsa; pub mod typestate; pub mod x25519; +#[cfg(target_os = "macos")] pub mod aes_fruity; -pub mod aes_openssl; #[cfg(target_os = "macos")] pub use aes_fruity as aes; + +#[cfg(not(target_os = "macos"))] +pub mod aes_openssl; #[cfg(not(target_os = "macos"))] pub use aes_openssl as aes; @@ -31,9 +33,6 @@ pub use aes_gmac_siv_fruity as aes_gmac_siv; #[cfg(not(target_os = "macos"))] pub use aes_gmac_siv_openssl as aes_gmac_siv; - - - /// This must be called before using any function from this library. pub fn init() { ffi::init(); diff --git a/zssp/src/main.rs b/zssp/src/main.rs index 402505e06..8a15beee3 100644 --- a/zssp/src/main.rs +++ b/zssp/src/main.rs @@ -46,7 +46,7 @@ fn alice_main( alice_out: mpsc::SyncSender>, alice_in: mpsc::Receiver>, ) { - let context = zssp::Context::::new(16); + let context = zssp::Context::::new(16, TEST_MTU); let mut data_buf = [0u8; 65536]; let mut next_service = ms_monotonic() + 500; let mut last_ratchet_count = 0; @@ -88,7 +88,6 @@ fn alice_main( &0, &mut data_buf, pkt, - TEST_MTU, current_time, ) { Ok(zssp::ReceiveResult::Ok(_)) => { @@ -144,7 +143,6 @@ fn alice_main( |_, b| { let _ = alice_out.send(b.to_vec()); }, - TEST_MTU, current_time, ); } @@ -159,7 +157,7 @@ fn bob_main( bob_out: mpsc::SyncSender>, bob_in: mpsc::Receiver>, ) { - let context = zssp::Context::::new(16); + let context = zssp::Context::::new(16, TEST_MTU); let mut data_buf = [0u8; 65536]; let mut data_buf_2 = [0u8; TEST_MTU]; let mut last_ratchet_count = 0; @@ -186,7 +184,6 @@ fn bob_main( &0, &mut data_buf, pkt, - TEST_MTU, current_time, ) { Ok(zssp::ReceiveResult::Ok(_)) => { @@ -246,7 +243,6 @@ fn bob_main( |_, b| { let _ = bob_out.send(b.to_vec()); }, - TEST_MTU, current_time, ); } diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index a82dbe682..591eed60a 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -11,7 +11,7 @@ use std::collections::{HashMap, HashSet}; use std::num::NonZeroU64; -use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicI64, AtomicU64, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak}; use zerotier_crypto::aes::{Aes, AesGcm}; @@ -28,12 +28,15 @@ use crate::fragged::Fragged; use crate::proto::*; use crate::sessionid::SessionId; +const GCM_CIPHER_POOL_SIZE: usize = 4; + /// Session context for local application. /// /// Each application using ZSSP must create an instance of this to own sessions and /// defragment incoming packets that are not yet associated with a session. pub struct Context { max_incomplete_session_queue_size: usize, + default_physical_mtu: AtomicUsize, defrag: Mutex< HashMap< (Application::PhysicalPath, u64), @@ -52,7 +55,7 @@ struct SessionsById { active: HashMap>>, // Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout. - incoming: HashMap>, + incoming: HashMap>, } /// Result generated by the context packet receive function, with possible payloads. @@ -80,7 +83,6 @@ pub struct Session { /// An arbitrary application defined object associated with each session pub application_data: Application::Data, - psk: Secret, send_counter: AtomicU64, receive_window: [AtomicU64; COUNTER_WINDOW_MAX_OOO], header_protection_cipher: Aes, @@ -90,13 +92,14 @@ pub struct Session { /// Most of the mutable parts of a session state. struct State { + physical_mtu: usize, remote_session_id: Option, keys: [Option; 2], current_key: usize, current_offer: Offer, } -struct BobIncomingIncompleteSessionState { +struct IncomingIncompleteSession { timestamp: i64, alice_session_id: SessionId, bob_session_id: SessionId, @@ -107,8 +110,9 @@ struct BobIncomingIncompleteSessionState { bob_noise_e_secret: P384KeyPair, } -struct AliceOutgoingIncompleteSessionState { +struct OutgoingSessionOffer { last_retry_time: AtomicI64, + psk: Secret, noise_h: [u8; SHA384_HASH_SIZE], noise_es: Secret, alice_noise_e_secret: P384KeyPair, @@ -125,34 +129,33 @@ struct OutgoingSessionAck { enum Offer { None, - NoiseXKInit(Box), + NoiseXKInit(Box), NoiseXKAck(Box), RekeyInit(P384KeyPair, i64), } struct SessionKey { - ratchet_key: Secret, // Key used in derivation of the next session key - //receive_key: Secret, // Receive side AES-GCM key - //send_key: Secret, // Send side AES-GCM key - receive_cipher_pool: [Mutex>; 4], // Pool of reusable sending ciphers - send_cipher_pool: [Mutex>; 4], // Pool of reusable receiving ciphers - rekey_at_time: i64, // Rekey at or after this time (ticks) - created_at_counter: u64, // Counter at which session was created - rekey_at_counter: u64, // Rekey at or after this counter - expire_at_counter: u64, // Hard error when this counter value is reached or exceeded - ratchet_count: u64, // Number of rekey events - bob: bool, // Was this side "Bob" in this exchange? - confirmed: bool, // Is this key confirmed by the other side? + ratchet_key: Secret, // Key used in derivation of the next session key + receive_cipher_pool: [Mutex>; GCM_CIPHER_POOL_SIZE], // Pool of reusable sending ciphers + send_cipher_pool: [Mutex>; GCM_CIPHER_POOL_SIZE], // Pool of reusable receiving ciphers + rekey_at_time: i64, // Rekey at or after this time (ticks) + created_at_counter: u64, // Counter at which session was created + rekey_at_counter: u64, // Rekey at or after this counter + expire_at_counter: u64, // Hard error when this counter value is reached or exceeded + ratchet_count: u64, // Number of rekey events + initiate_rekey: bool, // My turn to initiate rekey next? + confirmed: bool, // Is this key confirmed by the other side yet? } impl Context { /// Create a new session context. /// /// * `max_incomplete_session_queue_size` - Maximum number of incomplete sessions in negotiation phase - pub fn new(max_incomplete_session_queue_size: usize) -> Self { + pub fn new(max_incomplete_session_queue_size: usize, default_physical_mtu: usize) -> Self { zerotier_crypto::init(); Self { max_incomplete_session_queue_size, + default_physical_mtu: AtomicUsize::new(default_physical_mtu), defrag: Mutex::new(HashMap::new()), sessions: RwLock::new(SessionsById { active: HashMap::with_capacity(64), @@ -163,12 +166,13 @@ impl Context { /// Perform periodic background service and cleanup tasks. /// - /// This returns the number of milliseconds until it should be called again. + /// This returns the number of milliseconds until it should be called again. The caller should + /// try to satisfy this but small variations in timing of up to +/- a second or two are not + /// a problem. /// /// * `send` - Function to send packets to remote sessions - /// * `mtu` - Physical MTU /// * `current_time` - Current monotonic time in milliseconds - pub fn service>, &mut [u8])>(&self, mut send: SendFunction, mtu: usize, current_time: i64) -> i64 { + pub fn service>, &mut [u8])>(&self, mut send: SendFunction, current_time: i64) -> i64 { let mut dead_active = Vec::new(); let mut dead_pending = Vec::new(); let retry_cutoff = current_time - Application::RETRY_INTERVAL; @@ -193,7 +197,7 @@ impl Context { let _ = send_with_fragmentation( |b| send(&session, b), &mut (offer.init_packet.clone()), - mtu, + state.physical_mtu, PACKET_TYPE_ALICE_NOISE_XK_INIT, None, 0, @@ -211,7 +215,7 @@ impl Context { let _ = send_with_fragmentation( |b| send(&session, b), &mut (ack.ack.clone())[..ack.ack_size], - mtu, + state.physical_mtu, PACKET_TYPE_ALICE_NOISE_XK_ACK, state.remote_session_id, 0, @@ -226,7 +230,8 @@ impl Context { // Check whether we need to rekey if there is no pending offer or if the last rekey // offer was before retry_cutoff (checked in the 'match' above). if let Some(key) = state.keys[state.current_key].as_ref() { - if key.bob && (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) + if key.initiate_rekey + && (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter) { drop(state); session.initiate_rekey(|b| send(&session, b), current_time); @@ -268,7 +273,7 @@ impl Context { /// /// * `app` - Application layer instance /// * `send` - User-supplied packet sending function - /// * `mtu` - Physical MTU for calls to send() + /// * `mtu` - Physical MTU for calls to send() for this session (can be changed later) /// * `remote_s_public_blob` - Remote side's opaque static public blob (which must contain remote_s_public_p384) /// * `remote_s_public_p384` - Remote side's static public NIST P-384 key /// * `psk` - Pre-shared key (use all zero if none) @@ -311,16 +316,17 @@ impl Context { let session = Arc::new(Session { id: local_session_id, application_data, - psk, send_counter: AtomicU64::new(3), // 1 and 2 are reserved for init and final ack receive_window: std::array::from_fn(|_| AtomicU64::new(0)), header_protection_cipher: Aes::new(&header_protection_key), state: RwLock::new(State { + physical_mtu: mtu, remote_session_id: None, keys: [None, None], current_key: 0, - current_offer: Offer::NoiseXKInit(Box::new(AliceOutgoingIncompleteSessionState { + current_offer: Offer::NoiseXKInit(Box::new(OutgoingSessionOffer { last_retry_time: AtomicI64::new(current_time), + psk, noise_h: mix_hash(&mix_hash(&INITIAL_H, remote_s_public_blob), &alice_noise_e), noise_es: noise_es.clone(), alice_noise_e_secret, @@ -385,8 +391,7 @@ impl Context { /// /// The send function may be called one or more times to send packets. If the packet is associated /// wtth an active session this session is supplied, otherwise this parameter is None and the packet - /// should be a reply to the current incoming packet. The size of packets to be sent will not exceed - /// the supplied mtu. + /// should be a reply to the current incoming packet. /// /// The check_allow_incoming_session function is called when an initial Noise_XK init message is /// received. This is before anything is known about the caller. A return value of true proceeds @@ -411,7 +416,6 @@ impl Context { /// * `send` - Function to call to send packets /// * `data_buf` - Buffer to receive decrypted and authenticated object data (an error is returned if too small) /// * `incoming_packet_buf` - Buffer containing incoming wire packet (receive() takes ownership) - /// * `mtu` - Physical wire MTU for sending packets /// * `current_time` - Current monotonic time in milliseconds pub fn receive< 'b, @@ -427,7 +431,6 @@ impl Context { source: &Application::PhysicalPath, data_buf: &'b mut [u8], mut incoming_packet_buf: Application::IncomingPacketBuffer, - mtu: usize, current_time: i64, ) -> Result, Error> { let incoming_packet: &mut [u8] = incoming_packet_buf.as_mut(); @@ -462,7 +465,6 @@ impl Context { Some(session), None, key_index, - mtu, current_time, ); } else { @@ -482,7 +484,6 @@ impl Context { Some(session), None, key_index, - mtu, current_time, ); } @@ -547,7 +548,6 @@ impl Context { None, incoming, key_index, - mtu, current_time, ); } @@ -564,7 +564,6 @@ impl Context { None, incoming, key_index, - mtu, current_time, ); } @@ -588,9 +587,8 @@ impl Context { fragments: &[Application::IncomingPacketBuffer], packet_type: u8, session: Option>>, - incoming: Option>, + incoming: Option>, key_index: usize, - mtu: usize, current_time: i64, ) -> Result, Error> { debug_assert!(fragments.len() >= 1); @@ -602,7 +600,7 @@ impl Context { if let Some(session) = session { let state = session.state.read().unwrap(); if let Some(key) = state.keys[key_index].as_ref() { - let mut c = key.get_receive_cipher(); + let mut c = key.get_receive_cipher(incoming_counter); c.reset_init_gcm(&incoming_message_nonce); let mut data_len = 0; @@ -804,7 +802,7 @@ impl Context { // Reserve session ID on this side and record incomplete session state. sessions.incoming.insert( bob_session_id, - Arc::new(BobIncomingIncompleteSessionState { + Arc::new(IncomingIncompleteSession { timestamp: current_time, alice_session_id, bob_session_id, @@ -823,7 +821,7 @@ impl Context { send_with_fragmentation( |b| send(None, b), &mut ack_packet, - mtu, + self.default_physical_mtu.load(Ordering::Relaxed), PACKET_TYPE_BOB_NOISE_XK_ACK, Some(alice_session_id), 0, @@ -903,7 +901,7 @@ impl Context { .as_bytes(), ) .as_bytes(), - hmac_sha512_secret::(session.psk.as_bytes(), hk.as_bytes()).as_bytes(), + hmac_sha512_secret::(outgoing_offer.psk.as_bytes(), hk.as_bytes()).as_bytes(), ); let reply_message_nonce = create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_ACK, 2); @@ -933,7 +931,10 @@ impl Context { assert!(metadata.len() <= (u16::MAX as usize)); reply_len = append_to_slice(&mut reply_buffer, reply_len, &(metadata.len() as u16).to_le_bytes())?; - let noise_h_next = mix_hash(&mix_hash(&noise_h_next, &reply_buffer[HEADER_SIZE..reply_len]), session.psk.as_bytes()); + let noise_h_next = mix_hash( + &mix_hash(&noise_h_next, &reply_buffer[HEADER_SIZE..reply_len]), + outgoing_offer.psk.as_bytes(), + ); enc_start = reply_len; reply_len = append_to_slice(&mut reply_buffer, reply_len, metadata)?; @@ -946,6 +947,8 @@ impl Context { gcm.crypt_in_place(&mut reply_buffer[enc_start..reply_len]); reply_len = append_to_slice(&mut reply_buffer, reply_len, &gcm.finish_encrypt())?; + let mtu = state.physical_mtu; + drop(state); { let mut state = session.state.write().unwrap(); @@ -1063,11 +1066,11 @@ impl Context { let session = Arc::new(Session { id: incoming.bob_session_id, application_data, - psk, send_counter: AtomicU64::new(2), // 1 was already used during negotiation receive_window: std::array::from_fn(|_| AtomicU64::new(0)), header_protection_cipher: Aes::new(&incoming.header_protection_key), state: RwLock::new(State { + physical_mtu: self.default_physical_mtu.load(Ordering::Relaxed), remote_session_id: Some(incoming.alice_session_id), keys: [ Some(SessionKey::new::(noise_es_ee_se_hk_psk, 1, current_time, 2, true, true)), @@ -1115,8 +1118,8 @@ impl Context { if let Some(key) = state.keys[key_index].as_ref() { // Only the current "Alice" accepts rekeys initiated by the current "Bob." These roles // flip with each rekey event. - if !key.bob { - let mut c = key.get_receive_cipher(); + if !key.initiate_rekey { + let mut c = key.get_receive_cipher(incoming_counter); c.reset_init_gcm(&incoming_message_nonce); c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]); let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]); @@ -1202,8 +1205,8 @@ impl Context { if let Offer::RekeyInit(alice_e_secret, _) = &state.current_offer { if let Some(key) = state.keys[key_index].as_ref() { // Only the current "Bob" initiates rekeys and expects this ACK. - if key.bob { - let mut c = key.get_receive_cipher(); + if key.initiate_rekey { + let mut c = key.get_receive_cipher(incoming_counter); c.reset_init_gcm(&incoming_message_nonce); c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]); let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]); @@ -1555,8 +1558,6 @@ impl SessionKey { let send_cipher_pool = std::array::from_fn(|_| Mutex::new(AesGcm::new(&send_key))); Self { ratchet_key: kbkdf::(key.as_bytes()), - //receive_key, - //send_key, receive_cipher_pool, send_cipher_pool, rekey_at_time: current_time @@ -1568,31 +1569,21 @@ impl SessionKey { rekey_at_counter: current_counter.checked_add(Application::REKEY_AFTER_USES).unwrap(), expire_at_counter: current_counter.checked_add(Application::EXPIRE_AFTER_USES).unwrap(), ratchet_count, - bob, + initiate_rekey: bob, confirmed, } } fn get_send_cipher<'a>(&'a self, counter: u64) -> Result>, Error> { if counter < self.expire_at_counter { - for mutex in &self.send_cipher_pool { - if let Ok(guard) = mutex.try_lock() { - return Ok(guard); - } - } - Ok(self.send_cipher_pool[0].lock().unwrap()) + Ok(self.send_cipher_pool[(counter as usize) % GCM_CIPHER_POOL_SIZE].lock().unwrap()) } else { Err(Error::MaxKeyLifetimeExceeded) } } - fn get_receive_cipher<'a>(&'a self) -> MutexGuard<'a, AesGcm> { - for mutex in &self.receive_cipher_pool { - if let Ok(guard) = mutex.try_lock() { - return guard; - } - } - self.receive_cipher_pool[0].lock().unwrap() + fn get_receive_cipher<'a>(&'a self, counter: u64) -> MutexGuard<'a, AesGcm> { + self.receive_cipher_pool[(counter as usize) % GCM_CIPHER_POOL_SIZE].lock().unwrap() } }