Leverage type parameters and traits to serialize data

This utilizes two traits, AsBytes and FromBytes from the `zerocopy`
crate to enable type parameters on the hashed value.
This commit is contained in:
Erik Hollensbe 2022-04-15 14:50:58 -07:00
parent 7223de6fbb
commit 1b2485b277
No known key found for this signature in database
GPG key ID: 4BB0E241A863B389
2 changed files with 124 additions and 47 deletions

View file

@ -13,5 +13,9 @@ panic = 'abort'
[dependencies]
crc32fast = "^1"
zerocopy = { version = "0.6.1", features = ["alloc"] }
[dev-dependencies]
rand = ">=0"
[lib]

View file

@ -8,28 +8,38 @@
use std::borrow::Cow;
use zerocopy::{AsBytes, FromBytes};
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64", target_arch = "powerpc64")))]
#[inline(always)]
fn xor_with<const L: usize>(x: &mut [u8; L], y: &[u8; L]) {
x.iter_mut().zip(y.iter()).for_each(|(a, b)| *a ^= *b);
fn xor_with<T>(x: &mut T, y: &T)
where
T: FromBytes + AsBytes + Sized,
{
x.as_bytes_mut().iter_mut().zip(y.as_bytes().iter()).for_each(|(a, b)| *a ^= *b);
}
#[cfg(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64", target_arch = "powerpc64"))]
#[inline(always)]
fn xor_with<const L: usize>(x: &mut [u8; L], y: &[u8; L]) {
if L >= 16 {
for i in 0..(L / 16) {
unsafe { *x.as_mut_ptr().cast::<u128>().add(i) ^= *y.as_ptr().cast::<u128>().add(i) };
fn xor_with<T>(x: &mut T, y: &T)
where
T: FromBytes + AsBytes + Sized,
{
let size = std::mem::size_of::<T>();
if size >= 16 {
for i in 0..(size / 16) {
unsafe { *x.as_bytes_mut().as_mut_ptr().cast::<u128>().add(i) ^= *y.as_bytes().as_ptr().cast::<u128>().add(i) };
}
for i in (L - (L % 16))..L {
unsafe { *x.as_mut_ptr().add(i) ^= *y.as_ptr().add(i) };
for i in (size - (size % 16))..size {
unsafe { *x.as_bytes_mut().as_mut_ptr().add(i) ^= *y.as_bytes().as_ptr().add(i) };
}
} else {
for i in 0..(L / 8) {
unsafe { *x.as_mut_ptr().cast::<u64>().add(i) ^= *y.as_ptr().cast::<u64>().add(i) };
for i in 0..(size / 8) {
unsafe { *x.as_bytes_mut().as_mut_ptr().cast::<u64>().add(i) ^= *y.as_bytes().as_ptr().cast::<u64>().add(i) };
}
for i in (L - (L % 8))..L {
unsafe { *x.as_mut_ptr().add(i) ^= *y.as_ptr().add(i) };
for i in (size - (size % 8))..size {
unsafe { *x.as_bytes_mut().as_mut_ptr().add(i) ^= *y.as_bytes().as_ptr().add(i) };
}
}
}
@ -61,14 +71,19 @@ fn murmurhash32_mix32(mut x: u32) -> u32 {
///
/// The best value for HASHES seems to be 3 for an optimal fill of 75%.
#[repr(C)]
pub struct IBLT<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> {
pub struct IBLT<T, const BUCKETS: usize, const HASHES: usize>
where
T: FromBytes + AsBytes + Sized + Clone,
{
check_hash: [u32; BUCKETS],
count: [i8; BUCKETS],
key: [[u8; ITEM_BYTES]; BUCKETS],
key: [T; BUCKETS],
}
impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> Clone for IBLT<BUCKETS, ITEM_BYTES, HASHES> {
#[inline(always)]
impl<T, const BUCKETS: usize, const HASHES: usize> Clone for IBLT<T, BUCKETS, HASHES>
where
T: FromBytes + AsBytes + Sized + Clone,
{
fn clone(&self) -> Self {
unsafe {
let mut tmp: Self = std::mem::MaybeUninit::uninit().assume_init();
@ -78,9 +93,12 @@ impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> Clone f
}
}
impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> IBLT<BUCKETS, ITEM_BYTES, HASHES> {
impl<T, const BUCKETS: usize, const HASHES: usize> IBLT<T, BUCKETS, HASHES>
where
T: FromBytes + AsBytes + Sized + Clone,
{
/// Number of bytes each bucket consumes (not congituously, but doesn't matter).
const BUCKET_SIZE_BYTES: usize = ITEM_BYTES + 4 + 1;
const BUCKET_SIZE_BYTES: usize = std::mem::size_of::<T>() + 4 + 1;
/// Number of buckets in this IBLT.
#[allow(unused)]
@ -97,8 +115,6 @@ impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> IBLT<BU
unsafe { std::mem::zeroed() }
}
/// Get this IBLT as a byte slice (free cast operation).
/// The returned slice is always SIZE_BYTES in length.
#[inline(always)]
pub fn as_bytes(&self) -> &[u8] {
unsafe { &*std::ptr::slice_from_raw_parts((self as *const Self).cast::<u8>(), Self::SIZE_BYTES) }
@ -136,37 +152,37 @@ impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> IBLT<BU
unsafe { std::ptr::write_bytes((self as *mut Self).cast::<u8>(), 0, std::mem::size_of::<Self>()) };
}
pub(crate) fn ins_rem(&mut self, key: &[u8; ITEM_BYTES], delta: i8) {
let check_hash = crc32fast::hash(key);
pub(crate) fn ins_rem(&mut self, key: T, delta: i8) {
let check_hash = crc32fast::hash(key.as_bytes());
let mut iteration_index = u32::from_le(check_hash).wrapping_add(1);
for _ in 0..(HASHES as u64) {
iteration_index = murmurhash32_mix32(iteration_index);
let i = (iteration_index as usize) % BUCKETS;
self.check_hash[i] ^= check_hash;
self.count[i] = self.count[i].wrapping_add(delta);
xor_with(&mut self.key[i], key);
xor_with(&mut self.key[i], &key);
}
}
/// Insert a set item into this set.
/// This will panic if the slice is smaller than ITEM_BYTES.
#[inline(always)]
pub fn insert(&mut self, key: &[u8; ITEM_BYTES]) {
self.ins_rem(unsafe { &*key.as_ptr().cast() }, 1);
pub fn insert(&mut self, key: T) {
self.ins_rem(key, 1);
}
/// Insert a set item into this set.
/// This will panic if the slice is smaller than ITEM_BYTES.
#[inline(always)]
pub fn remove(&mut self, key: &[u8; ITEM_BYTES]) {
self.ins_rem(unsafe { &*key.as_ptr().cast() }, -1);
pub fn remove(&mut self, key: T) {
self.ins_rem(key, -1);
}
/// Subtract another IBLT from this one to get a set difference.
pub fn subtract(&mut self, other: &Self) {
self.check_hash.iter_mut().zip(other.check_hash.iter()).for_each(|(a, b)| *a ^= *b);
self.count.iter_mut().zip(other.count.iter()).for_each(|(a, b)| *a = a.wrapping_sub(*b));
self.key.iter_mut().zip(other.key.iter()).for_each(|(a, b)| xor_with(a, b));
self.key.iter_mut().zip(other.key.iter()).for_each(|(a, b)| xor_with(a, &b));
}
/// List as many entries in this IBLT as can be extracted.
@ -183,12 +199,12 @@ impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> IBLT<BU
/// Due to the small check hash sizes used in this IBLT there is a very small chance this will list
/// bogus items that were never added. This is not an issue with this protocol as it would just result
/// in an unsatisfied record request.
pub fn list<F: FnMut([u8; ITEM_BYTES], bool)>(mut self, mut f: F) -> bool {
pub fn list<F: FnMut(T, bool)>(&mut self, mut f: F) -> bool {
let mut queue: Vec<u32> = Vec::with_capacity(BUCKETS);
for i in 0..BUCKETS {
let count = self.count[i];
if (count == 1 || count == -1) && crc32fast::hash(&self.key[i]) == self.check_hash[i] {
if (count == 1 || count == -1) && crc32fast::hash(&self.key[i].as_bytes()) == self.check_hash[i] {
queue.push(i as u32);
}
}
@ -204,7 +220,7 @@ impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> IBLT<BU
let check_hash = self.check_hash[i];
let count = self.count[i];
let key = &self.key[i];
if (count == 1 || count == -1) && check_hash == crc32fast::hash(key) {
if (count == 1 || count == -1) && check_hash == crc32fast::hash(key.as_bytes()) {
let key = key.clone();
let mut iteration_index = u32::from_le(check_hash).wrapping_add(1);
@ -217,7 +233,7 @@ impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> IBLT<BU
self.check_hash[i2] = check_hash2;
self.count[i2] = count2;
xor_with(key2, &key);
if (count2 == 1 || count2 == -1) && check_hash2 == crc32fast::hash(key2) {
if (count2 == 1 || count2 == -1) && check_hash2 == crc32fast::hash(key2.as_bytes()) {
if queue.len() > BUCKETS {
// sanity check, should be impossible
break 'list_main;
@ -226,7 +242,7 @@ impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> IBLT<BU
}
}
f(key, count == 1);
f(key.clone(), count == 1);
}
}
@ -234,14 +250,17 @@ impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> IBLT<BU
}
}
impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> PartialEq for IBLT<BUCKETS, ITEM_BYTES, HASHES> {
impl<T, const BUCKETS: usize, const HASHES: usize> PartialEq for IBLT<T, BUCKETS, HASHES>
where
T: AsBytes + FromBytes + Clone,
{
#[inline(always)]
fn eq(&self, other: &Self) -> bool {
self.as_bytes().eq(other.as_bytes())
}
}
impl<const BUCKETS: usize, const ITEM_BYTES: usize, const HASHES: usize> Eq for IBLT<BUCKETS, ITEM_BYTES, HASHES> {}
impl<T, const BUCKETS: usize, const HASHES: usize> Eq for IBLT<T, BUCKETS, HASHES> where T: AsBytes + FromBytes + Clone {}
#[cfg(test)]
mod tests {
@ -296,28 +315,28 @@ mod tests {
#[test]
fn struct_packing() {
// Typical case
let mut tmp = IBLT::<64, 16, 3>::new();
let mut tmp = IBLT::<[u8; 64], 16, 3>::new();
tmp.check_hash.fill(0x01010101);
tmp.count.fill(1);
tmp.key.iter_mut().for_each(|x| x.fill(1));
assert!(tmp.as_bytes().iter().all(|x| *x == 1));
// Pathological alignment case #1
let mut tmp = IBLT::<17, 13, 3>::new();
let mut tmp = IBLT::<[u8; 17], 13, 3>::new();
tmp.check_hash.fill(0x01010101);
tmp.count.fill(1);
tmp.key.iter_mut().for_each(|x| x.fill(1));
assert!(tmp.as_bytes().iter().all(|x| *x == 1));
// Pathological alignment case #2
let mut tmp = IBLT::<17, 8, 3>::new();
let mut tmp = IBLT::<[u8; 17], 8, 3>::new();
tmp.check_hash.fill(0x01010101);
tmp.count.fill(1);
tmp.key.iter_mut().for_each(|x| x.fill(1));
assert!(tmp.as_bytes().iter().all(|x| *x == 1));
// Pathological alignment case #3
let mut tmp = IBLT::<16, 7, 3>::new();
let mut tmp = IBLT::<[u8; 16], 7, 3>::new();
tmp.check_hash.fill(0x01010101);
tmp.count.fill(1);
tmp.key.iter_mut().for_each(|x| x.fill(1));
@ -326,18 +345,19 @@ mod tests {
#[test]
fn fill_list_performance() {
const LENGTH: usize = 16;
const CAPACITY: usize = 4096;
let mut rn: u128 = 0xd3b07384d113edec49eaa6238ad5ff00;
let mut expected: HashSet<u128> = HashSet::with_capacity(4096);
let mut count = 64;
let mut count = LENGTH;
while count <= CAPACITY {
let mut test = IBLT::<CAPACITY, 16, HASHES>::new();
let mut test = IBLT::<[u8; LENGTH], CAPACITY, HASHES>::new();
expected.clear();
for _ in 0..count {
rn = rn.wrapping_add(splitmix64(rn as u64) as u128);
expected.insert(rn);
test.insert(&rn.to_le_bytes());
test.insert(rn.to_le_bytes());
}
let mut list_count = 0;
@ -348,7 +368,7 @@ mod tests {
});
println!("inserted: {}\tlisted: {}\tcapacity: {}\tscore: {:.4}\tfill: {:.4}", count, list_count, CAPACITY, (list_count as f64) / (count as f64), (count as f64) / (CAPACITY as f64));
count += 64;
count += LENGTH;
}
}
@ -357,6 +377,7 @@ mod tests {
const CAPACITY: usize = 4096; // previously 16384;
const REMOTE_SIZE: usize = 1024 * 1024 * 2;
const STEP: usize = 1024;
const LENGTH: usize = 16;
let mut rn: u128 = 0xd3b07384d113edec49eaa6238ad5ff00;
let mut missing_count = 1024;
let mut missing: HashSet<u128> = HashSet::with_capacity(CAPACITY * 2);
@ -364,19 +385,19 @@ mod tests {
while missing_count <= CAPACITY {
missing.clear();
all.clear();
let mut local = IBLT::<CAPACITY, 16, HASHES>::new();
let mut remote = IBLT::<CAPACITY, 16, HASHES>::new();
let mut local = IBLT::<[u8; LENGTH], CAPACITY, HASHES>::new();
let mut remote = IBLT::<[u8; LENGTH], CAPACITY, HASHES>::new();
let mut k = 0;
while k < REMOTE_SIZE {
rn = rn.wrapping_add(splitmix64(rn as u64) as u128);
if all.insert(rn) {
if k >= missing_count {
local.insert(&rn.to_le_bytes());
local.insert(rn.to_le_bytes());
} else {
missing.insert(rn);
}
remote.insert(&rn.to_le_bytes());
remote.insert(rn.to_le_bytes());
k += 1;
}
}
@ -400,4 +421,56 @@ mod tests {
missing_count += STEP;
}
}
#[derive(Eq, PartialEq, Clone, AsBytes, FromBytes, Debug)]
#[repr(C)]
struct TestType {
thing: [u8; 256],
other_thing: [u8; 32],
}
impl TestType {
pub fn zeroed() -> Self {
unsafe { std::mem::zeroed() }
}
pub fn new() -> Self {
let mut newtype = Self::zeroed();
newtype.thing.fill_with(|| rand::random());
newtype.other_thing.fill_with(|| rand::random());
newtype
}
}
#[test]
fn test_polymorphism() {
const CAPACITY: usize = 512;
let mut full = IBLT::<TestType, CAPACITY, HASHES>::new();
let mut zero = IBLT::<TestType, CAPACITY, HASHES>::new();
for _ in 0..CAPACITY {
zero.insert(TestType::zeroed());
full.insert(TestType::new());
}
full.subtract(&zero);
zero.list(|item, new| {
if !new {
assert_eq!(item, TestType::zeroed());
}
});
zero.reset();
for _ in 0..CAPACITY {
zero.insert(TestType::zeroed());
}
zero.subtract(&full);
full.list(|item, new| {
if !new {
assert_ne!(item, TestType::zeroed());
}
});
}
}