fixed multithreading bug

This commit is contained in:
mamoniot 2023-03-09 13:46:32 -05:00
parent d33b8e50cd
commit 02ea954329
No known key found for this signature in database
GPG key ID: ADCCDBBE0E3D3B3B
5 changed files with 27 additions and 30 deletions

View file

@ -97,7 +97,7 @@ impl<const ENCRYPT: bool> AesGcm<ENCRYPT> {
} }
#[inline(always)] #[inline(always)]
pub fn reset_init_gcm(&self, iv: &[u8]) { pub fn reset_init_gcm(&mut self, iv: &[u8]) {
assert_eq!(iv.len(), 12); assert_eq!(iv.len(), 12);
unsafe { unsafe {
assert_eq!(CCCryptorGCMReset(self.0), 0); assert_eq!(CCCryptorGCMReset(self.0), 0);
@ -106,14 +106,14 @@ impl<const ENCRYPT: bool> AesGcm<ENCRYPT> {
} }
#[inline(always)] #[inline(always)]
pub fn aad(&self, aad: &[u8]) { pub fn aad(&mut self, aad: &[u8]) {
unsafe { unsafe {
assert_eq!(CCCryptorGCMAddAAD(self.0, aad.as_ptr().cast(), aad.len()), 0); assert_eq!(CCCryptorGCMAddAAD(self.0, aad.as_ptr().cast(), aad.len()), 0);
} }
} }
#[inline(always)] #[inline(always)]
pub fn crypt(&self, input: &[u8], output: &mut [u8]) { pub fn crypt(&mut self, input: &[u8], output: &mut [u8]) {
unsafe { unsafe {
assert_eq!(input.len(), output.len()); assert_eq!(input.len(), output.len());
if ENCRYPT { if ENCRYPT {
@ -131,7 +131,7 @@ impl<const ENCRYPT: bool> AesGcm<ENCRYPT> {
} }
#[inline(always)] #[inline(always)]
pub fn crypt_in_place(&self, data: &mut [u8]) { pub fn crypt_in_place(&mut self, data: &mut [u8]) {
unsafe { unsafe {
if ENCRYPT { if ENCRYPT {
assert_eq!(CCCryptorGCMEncrypt(self.0, data.as_ptr().cast(), data.len(), data.as_mut_ptr().cast()), 0); assert_eq!(CCCryptorGCMEncrypt(self.0, data.as_ptr().cast(), data.len(), data.as_mut_ptr().cast()), 0);
@ -142,7 +142,7 @@ impl<const ENCRYPT: bool> AesGcm<ENCRYPT> {
} }
#[inline(always)] #[inline(always)]
fn finish(&self) -> [u8; 16] { fn finish(&mut self) -> [u8; 16] {
let mut tag = 0_u128.to_ne_bytes(); let mut tag = 0_u128.to_ne_bytes();
unsafe { unsafe {
let mut tag_len = 16; let mut tag_len = 16;
@ -159,14 +159,14 @@ impl<const ENCRYPT: bool> AesGcm<ENCRYPT> {
impl AesGcm<true> { impl AesGcm<true> {
/// Produce the gcm authentication tag. /// Produce the gcm authentication tag.
#[inline(always)] #[inline(always)]
pub fn finish_encrypt(&self) -> [u8; 16] { pub fn finish_encrypt(&mut self) -> [u8; 16] {
self.finish() self.finish()
} }
} }
impl AesGcm<false> { impl AesGcm<false> {
/// Check the gcm authentication tag. Outputs true if it matches the just decrypted message, outputs false otherwise. /// Check the gcm authentication tag. Outputs true if it matches the just decrypted message, outputs false otherwise.
#[inline(always)] #[inline(always)]
pub fn finish_decrypt(&self, expected_tag: &[u8]) -> bool { pub fn finish_decrypt(&mut self, expected_tag: &[u8]) -> bool {
secure_eq(&self.finish(), expected_tag) secure_eq(&self.finish(), expected_tag)
} }
} }
@ -229,7 +229,7 @@ impl Aes {
} }
#[inline(always)] #[inline(always)]
pub fn encrypt_block_in_place(&self, data: &mut [u8]) { pub fn encrypt_block_in_place(&mut self, data: &mut [u8]) {
assert_eq!(data.len(), 16); assert_eq!(data.len(), 16);
unsafe { unsafe {
let mut data_out_written = 0; let mut data_out_written = 0;
@ -238,7 +238,7 @@ impl Aes {
} }
#[inline(always)] #[inline(always)]
pub fn decrypt_block_in_place(&self, data: &mut [u8]) { pub fn decrypt_block_in_place(&mut self, data: &mut [u8]) {
assert_eq!(data.len(), 16); assert_eq!(data.len(), 16);
unsafe { unsafe {
let mut data_out_written = 0; let mut data_out_written = 0;

View file

@ -34,7 +34,7 @@ impl<const ENCRYPT: bool> AesGcm<ENCRYPT> {
/// Set the IV of this AesGcm context. This call resets the IV but leaves the key and encryption algorithm alone. /// Set the IV of this AesGcm context. This call resets the IV but leaves the key and encryption algorithm alone.
/// This method must be called before any other method on AesGcm. /// This method must be called before any other method on AesGcm.
/// `iv` must be exactly 12 bytes in length, because that is what Aes supports. /// `iv` must be exactly 12 bytes in length, because that is what Aes supports.
pub fn reset_init_gcm(&self, iv: &[u8]) { pub fn reset_init_gcm(&mut self, iv: &[u8]) {
debug_assert_eq!(iv.len(), 12, "Aes IV must be 12 bytes long"); debug_assert_eq!(iv.len(), 12, "Aes IV must be 12 bytes long");
unsafe { unsafe {
self.0.cipher_init::<ENCRYPT>(ptr::null(), ptr::null(), iv.as_ptr()).unwrap(); self.0.cipher_init::<ENCRYPT>(ptr::null(), ptr::null(), iv.as_ptr()).unwrap();
@ -43,20 +43,20 @@ impl<const ENCRYPT: bool> AesGcm<ENCRYPT> {
/// Add additional authentication data to AesGcm (same operation with CTR mode). /// Add additional authentication data to AesGcm (same operation with CTR mode).
#[inline(always)] #[inline(always)]
pub fn aad(&self, aad: &[u8]) { pub fn aad(&mut self, aad: &[u8]) {
unsafe { self.0.update::<ENCRYPT>(aad, ptr::null_mut()).unwrap() }; unsafe { self.0.update::<ENCRYPT>(aad, ptr::null_mut()).unwrap() };
} }
/// Encrypt or decrypt (same operation with CTR mode) /// Encrypt or decrypt (same operation with CTR mode)
#[inline(always)] #[inline(always)]
pub fn crypt(&self, input: &[u8], output: &mut [u8]) { pub fn crypt(&mut self, input: &[u8], output: &mut [u8]) {
debug_assert!(output.len() >= input.len(), "output buffer must fit the size of the input buffer"); debug_assert!(output.len() >= input.len(), "output buffer must fit the size of the input buffer");
unsafe { self.0.update::<ENCRYPT>(input, output.as_mut_ptr()).unwrap() }; unsafe { self.0.update::<ENCRYPT>(input, output.as_mut_ptr()).unwrap() };
} }
/// Encrypt or decrypt in place (same operation with CTR mode). /// Encrypt or decrypt in place (same operation with CTR mode).
#[inline(always)] #[inline(always)]
pub fn crypt_in_place(&self, data: &mut [u8]) { pub fn crypt_in_place(&mut self, data: &mut [u8]) {
let ptr = data.as_mut_ptr(); let ptr = data.as_mut_ptr();
unsafe { self.0.update::<ENCRYPT>(data, ptr).unwrap() } unsafe { self.0.update::<ENCRYPT>(data, ptr).unwrap() }
} }
@ -64,7 +64,7 @@ impl<const ENCRYPT: bool> AesGcm<ENCRYPT> {
impl AesGcm<true> { impl AesGcm<true> {
/// Produce the gcm authentication tag. /// Produce the gcm authentication tag.
#[inline(always)] #[inline(always)]
pub fn finish_encrypt(&self) -> [u8; 16] { pub fn finish_encrypt(&mut self) -> [u8; 16] {
unsafe { unsafe {
let mut tag = MaybeUninit::<[u8; 16]>::uninit(); let mut tag = MaybeUninit::<[u8; 16]>::uninit();
self.0.finalize::<true>(tag.as_mut_ptr().cast()).unwrap(); self.0.finalize::<true>(tag.as_mut_ptr().cast()).unwrap();
@ -76,7 +76,7 @@ impl AesGcm<true> {
impl AesGcm<false> { impl AesGcm<false> {
/// Check the gcm authentication tag. Outputs true if it matches the just decrypted message, outputs false otherwise. /// Check the gcm authentication tag. Outputs true if it matches the just decrypted message, outputs false otherwise.
#[inline(always)] #[inline(always)]
pub fn finish_decrypt(&self, expected_tag: &[u8]) -> bool { pub fn finish_decrypt(&mut self, expected_tag: &[u8]) -> bool {
debug_assert_eq!(expected_tag.len(), 16); debug_assert_eq!(expected_tag.len(), 16);
if self.0.set_tag(expected_tag).is_ok() { if self.0.set_tag(expected_tag).is_ok() {
unsafe { self.0.finalize::<false>(ptr::null_mut()).is_ok() } unsafe { self.0.finalize::<false>(ptr::null_mut()).is_ok() }
@ -116,14 +116,14 @@ impl Aes {
/// Do not ever encrypt the same plaintext twice. Make sure data is always different between calls. /// Do not ever encrypt the same plaintext twice. Make sure data is always different between calls.
#[inline(always)] #[inline(always)]
pub fn encrypt_block_in_place(&self, data: &mut [u8]) { pub fn encrypt_block_in_place(&mut self, data: &mut [u8]) {
debug_assert_eq!(data.len(), AES_BLOCK_SIZE, "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing."); debug_assert_eq!(data.len(), AES_BLOCK_SIZE, "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing.");
let ptr = data.as_mut_ptr(); let ptr = data.as_mut_ptr();
unsafe { self.0.update::<true>(data, ptr).unwrap() } unsafe { self.0.update::<true>(data, ptr).unwrap() }
} }
/// Do not ever encrypt the same plaintext twice. Make sure data is always different between calls. /// Do not ever encrypt the same plaintext twice. Make sure data is always different between calls.
#[inline(always)] #[inline(always)]
pub fn decrypt_block_in_place(&self, data: &mut [u8]) { pub fn decrypt_block_in_place(&mut self, data: &mut [u8]) {
debug_assert_eq!(data.len(), AES_BLOCK_SIZE, "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing."); debug_assert_eq!(data.len(), AES_BLOCK_SIZE, "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing.");
let ptr = data.as_mut_ptr(); let ptr = data.as_mut_ptr();
unsafe { self.1.update::<false>(data, ptr).unwrap() } unsafe { self.1.update::<false>(data, ptr).unwrap() }

View file

@ -12,8 +12,8 @@ mod test {
fn aes_256_gcm() { fn aes_256_gcm() {
init(); init();
let key = Secret::move_bytes([1u8; 32]); let key = Secret::move_bytes([1u8; 32]);
let enc = AesGcm::<true>::new(&key); let mut enc = AesGcm::<true>::new(&key);
let dec = AesGcm::<false>::new(&key); let mut dec = AesGcm::<false>::new(&key);
let plain = [2u8; 127]; let plain = [2u8; 127];
let iv0 = [3u8; 12]; let iv0 = [3u8; 12];
@ -59,7 +59,7 @@ mod test {
} }
let iv = [1_u8; 12]; let iv = [1_u8; 12];
let c = AesGcm::<true>::new(&Secret::move_bytes([1_u8; 32])); let mut c = AesGcm::<true>::new(&Secret::move_bytes([1_u8; 32]));
let benchmark_iterations: usize = 80000; let benchmark_iterations: usize = 80000;
let start = SystemTime::now(); let start = SystemTime::now();
@ -73,7 +73,7 @@ mod test {
(((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64() (((benchmark_iterations * buf.len()) as f64) / 1048576.0) / duration.as_secs_f64()
); );
let c = AesGcm::<false>::new(&Secret::move_bytes([1_u8; 32])); let mut c = AesGcm::<false>::new(&Secret::move_bytes([1_u8; 32]));
let start = SystemTime::now(); let start = SystemTime::now();
for _ in 0..benchmark_iterations { for _ in 0..benchmark_iterations {
@ -91,7 +91,7 @@ mod test {
fn aes_gcm_test_vectors() { fn aes_gcm_test_vectors() {
// Even though we are just wrapping other implementations, it's still good to test thoroughly! // Even though we are just wrapping other implementations, it's still good to test thoroughly!
for tv in NIST_AES_GCM_TEST_VECTORS.iter() { for tv in NIST_AES_GCM_TEST_VECTORS.iter() {
let gcm = AesGcm::new(unsafe { &Secret::<32>::from_bytes(tv.key) }); let mut gcm = AesGcm::new(unsafe { &Secret::<32>::from_bytes(tv.key) });
gcm.reset_init_gcm(tv.nonce); gcm.reset_init_gcm(tv.nonce);
gcm.aad(tv.aad); gcm.aad(tv.aad);
let mut ciphertext = Vec::new(); let mut ciphertext = Vec::new();
@ -101,7 +101,7 @@ mod test {
assert!(tag.eq(tv.tag)); assert!(tag.eq(tv.tag));
assert!(ciphertext.as_slice().eq(tv.ciphertext)); assert!(ciphertext.as_slice().eq(tv.ciphertext));
let gcm = AesGcm::new(unsafe { &Secret::<32>::from_bytes(tv.key) }); let mut gcm = AesGcm::new(unsafe { &Secret::<32>::from_bytes(tv.key) });
gcm.reset_init_gcm(tv.nonce); gcm.reset_init_gcm(tv.nonce);
gcm.aad(tv.aad); gcm.aad(tv.aad);
let mut ct_copy = ciphertext.clone(); let mut ct_copy = ciphertext.clone();

View file

@ -281,7 +281,7 @@ pub fn hmac_sha512_secret<const C: usize>(key: &[u8], msg: &[u8]) -> Secret<C> {
let mut hm = HMACSHA512::new(key); let mut hm = HMACSHA512::new(key);
hm.update(msg); hm.update(msg);
let buff = hm.finish(); let buff = hm.finish();
unsafe { Secret::from_bytes(&buff) } unsafe { Secret::from_bytes(&buff[..C]) }
} }
#[inline(always)] #[inline(always)]

View file

@ -35,12 +35,9 @@ impl<const L: usize> Secret<L> {
/// Copy bytes into secret, then nuke the previous value, will panic if the slice does not match the size of this secret. /// Copy bytes into secret, then nuke the previous value, will panic if the slice does not match the size of this secret.
#[inline(always)] #[inline(always)]
pub fn from_bytes_then_nuke(b: &mut [u8]) -> Self { pub fn from_bytes_then_nuke(b: &mut [u8]) -> Self {
let mut k = [0u8; L]; let ret = Self (b.try_into().unwrap());
k.copy_from_slice(b); unsafe { OPENSSL_cleanse(b.as_mut_ptr().cast(), L) };
unsafe { ret
OPENSSL_cleanse(b.as_mut_ptr().cast(), L)
};
Self (k)
} }
#[inline(always)] #[inline(always)]
pub unsafe fn from_bytes(b: &[u8]) -> Self { pub unsafe fn from_bytes(b: &[u8]) -> Self {