diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 536f3c8b7..ab6324f78 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -130,12 +130,13 @@ enum Offer { RekeyInit(P384KeyPair, i64), } +const AES_POOL_SIZE: usize = 4; struct SessionKey { - ratchet_key: Secret, // Key used in derivation of the next session key + ratchet_key: Secret, // Key used in derivation of the next session key //receive_key: Secret, // Receive side AES-GCM key //send_key: Secret, // Send side AES-GCM key - receive_cipher_pool: [Mutex>; 4], // Pool of reusable sending ciphers - send_cipher_pool: [Mutex>; 4], // Pool of reusable receiving ciphers + receive_cipher_pool: [Mutex>; AES_POOL_SIZE], // Pool of reusable sending ciphers + send_cipher_pool: [Mutex>; AES_POOL_SIZE], // Pool of reusable receiving ciphers rekey_at_time: i64, // Rekey at or after this time (ticks) created_at_counter: u64, // Counter at which session was created rekey_at_counter: u64, // Rekey at or after this counter @@ -1575,20 +1576,24 @@ impl SessionKey { fn get_send_cipher<'a>(&'a self, counter: u64) -> Result>, Error> { if counter < self.expire_at_counter { - //for i in 0..self.send_cipher_pool.len() { - // let mutex = &self.send_cipher_pool[(counter as usize + i)%self.send_cipher_pool.len()]; - // if let Ok(guard) = mutex.try_lock() { - // return Ok(guard); - // } - //} - Ok(self.send_cipher_pool[(counter as usize)%self.send_cipher_pool.len()].lock().unwrap()) + for i in 0..(AES_POOL_SIZE - 1) { + if let Ok(p) = self.send_cipher_pool[(counter as usize).wrapping_add(i)%AES_POOL_SIZE].try_lock() { + return Ok(p); + } + } + Ok(self.send_cipher_pool[(counter as usize)%AES_POOL_SIZE].lock().unwrap()) } else { Err(Error::MaxKeyLifetimeExceeded) } } fn get_receive_cipher<'a>(&'a self, counter: u64) -> MutexGuard<'a, AesGcm> { - self.receive_cipher_pool[(counter as usize)%self.receive_cipher_pool.len()].lock().unwrap() + for i in 0..(AES_POOL_SIZE - 1) { + if let Ok(p) = self.receive_cipher_pool[(counter as usize).wrapping_add(i)%AES_POOL_SIZE].try_lock() { + return p; + } + } + self.receive_cipher_pool[(counter as usize)%AES_POOL_SIZE].lock().unwrap() } }