diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index a82dbe682..536f3c8b7 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -602,7 +602,7 @@ impl Context { if let Some(session) = session { let state = session.state.read().unwrap(); if let Some(key) = state.keys[key_index].as_ref() { - let mut c = key.get_receive_cipher(); + let mut c = key.get_receive_cipher(incoming_counter); c.reset_init_gcm(&incoming_message_nonce); let mut data_len = 0; @@ -1116,7 +1116,7 @@ impl Context { // Only the current "Alice" accepts rekeys initiated by the current "Bob." These roles // flip with each rekey event. if !key.bob { - let mut c = key.get_receive_cipher(); + let mut c = key.get_receive_cipher(incoming_counter); c.reset_init_gcm(&incoming_message_nonce); c.crypt_in_place(&mut pkt_assembled[RekeyInit::ENC_START..RekeyInit::AUTH_START]); let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyInit::AUTH_START..]); @@ -1203,7 +1203,7 @@ impl Context { if let Some(key) = state.keys[key_index].as_ref() { // Only the current "Bob" initiates rekeys and expects this ACK. if key.bob { - let mut c = key.get_receive_cipher(); + let mut c = key.get_receive_cipher(incoming_counter); c.reset_init_gcm(&incoming_message_nonce); c.crypt_in_place(&mut pkt_assembled[RekeyAck::ENC_START..RekeyAck::AUTH_START]); let aead_authentication_ok = c.finish_decrypt(&pkt_assembled[RekeyAck::AUTH_START..]); @@ -1575,24 +1575,20 @@ impl SessionKey { fn get_send_cipher<'a>(&'a self, counter: u64) -> Result>, Error> { if counter < self.expire_at_counter { - for mutex in &self.send_cipher_pool { - if let Ok(guard) = mutex.try_lock() { - return Ok(guard); - } - } - Ok(self.send_cipher_pool[0].lock().unwrap()) + //for i in 0..self.send_cipher_pool.len() { + // let mutex = &self.send_cipher_pool[(counter as usize + i)%self.send_cipher_pool.len()]; + // if let Ok(guard) = mutex.try_lock() { + // return Ok(guard); + // } + //} + Ok(self.send_cipher_pool[(counter as usize)%self.send_cipher_pool.len()].lock().unwrap()) } else { Err(Error::MaxKeyLifetimeExceeded) } } - fn get_receive_cipher<'a>(&'a self) -> MutexGuard<'a, AesGcm> { - for mutex in &self.receive_cipher_pool { - if let Ok(guard) = mutex.try_lock() { - return guard; - } - } - self.receive_cipher_pool[0].lock().unwrap() + fn get_receive_cipher<'a>(&'a self, counter: u64) -> MutexGuard<'a, AesGcm> { + self.receive_cipher_pool[(counter as usize)%self.receive_cipher_pool.len()].lock().unwrap() } }