Small AES optimizations on ARM64.

Author: Adam Ierymenko
Date: 2020-08-03 23:14:02 +00:00
parent 4273d89373
commit d0cc3ac333


@@ -22,17 +22,19 @@ namespace {
 #ifdef ZT_AES_NEON
-ZT_INLINE uint8x16_t s_clmul_armneon_crypto(uint8x16_t a8, const uint8x16_t y, const uint8_t b[16]) noexcept
+ZT_INLINE uint8x16_t s_clmul_armneon_crypto(uint8x16_t h, uint8x16_t y, const uint8_t b[16]) noexcept
 {
-const uint8x16_t p = vreinterpretq_u8_u64(vdupq_n_u64(0x0000000000000087));
-const uint8x16_t z = vdupq_n_u8(0);
-uint8x16_t b8 = vrbitq_u8(veorq_u8(vld1q_u8(b), y));
 uint8x16_t r0, r1, t0, t1;
-__asm__ __volatile__("pmull %0.1q, %1.1d, %2.1d \n\t" : "=w" (r0) : "w" (a8), "w" (b8));
-__asm__ __volatile__("pmull2 %0.1q, %1.2d, %2.2d \n\t" :"=w" (r1) : "w" (a8), "w" (b8));
-t0 = vextq_u8(b8, b8, 8);
-__asm__ __volatile__("pmull %0.1q, %1.1d, %2.1d \n\t" : "=w" (t1) : "w" (a8), "w" (t0));
-__asm__ __volatile__("pmull2 %0.1q, %1.2d, %2.2d \n\t" :"=w" (t0) : "w" (a8), "w" (t0));
+r0 = vld1q_u8(b);
+const uint8x16_t z = veorq_u8(h, h);
+y = veorq_u8(r0, y);
+y = vrbitq_u8(y);
+const uint8x16_t p = vreinterpretq_u8_u64(vdupq_n_u64(0x0000000000000087));
+t0 = vextq_u8(y, y, 8);
+__asm__ __volatile__("pmull %0.1q, %1.1d, %2.1d \n\t" : "=w" (r0) : "w" (h), "w" (y));
+__asm__ __volatile__("pmull2 %0.1q, %1.2d, %2.2d \n\t" :"=w" (r1) : "w" (h), "w" (y));
+__asm__ __volatile__("pmull %0.1q, %1.1d, %2.1d \n\t" : "=w" (t1) : "w" (h), "w" (t0));
+__asm__ __volatile__("pmull2 %0.1q, %1.2d, %2.2d \n\t" :"=w" (t0) : "w" (h), "w" (t0));
 t0 = veorq_u8(t0, t1);
 t1 = vextq_u8(z, t0, 8);
 r0 = veorq_u8(r0, t1);
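Note: this hunk keeps the same GF(2^128) carry-less multiply for GHASH; it renames `a8` to `h`, drops `const` from `y` so it can be updated in place, reuses `r0` for the load of `b`, and builds the zero vector with `veorq_u8(h, h)` instead of `vdupq_n_u8(0)`. For reference, a minimal sketch of the two product pairs the `pmull`/`pmull2` asm computes, written with the ACLE intrinsics instead (a hypothetical helper, assuming a compiler targeting armv8-a+crypto; the commit keeps inline asm):

#include <arm_neon.h>

/* Hypothetical sketch: the low and high 64x64 -> 128-bit carry-less
 * products that the pmull/pmull2 inline asm above computes, expressed
 * with ACLE intrinsics. Requires crypto extensions to be enabled. */
static inline void clmul_lo_hi(poly64x2_t h, poly64x2_t y, poly128_t *lo, poly128_t *hi)
{
	*lo = vmull_p64(vgetq_lane_p64(h, 0), vgetq_lane_p64(y, 0)); /* pmull  */
	*hi = vmull_high_p64(h, y);                                  /* pmull2 */
}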
@@ -871,9 +873,9 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 --len;
 out[totalLen++] = *(in++);
 if (!(totalLen & 15U)) {
-uint8x16_t pt = vld1q_u8(out + (totalLen - 16));
+uint8_t *const otmp = out + (totalLen - 16);
 uint8x16_t d0 = vrev32q_u8(dd);
-dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
+uint8x16_t pt = vld1q_u8(otmp);
 d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
 d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
 d0 = vaesmcq_u8(vaeseq_u8(d0, k2));
@@ -888,7 +890,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 d0 = vaesmcq_u8(vaeseq_u8(d0, k11));
 d0 = vaesmcq_u8(vaeseq_u8(d0, k12));
 d0 = veorq_u8(vaeseq_u8(d0, k13), k14);
-vst1q_u8(out + (totalLen - 16), veorq_u8(pt, d0));
+vst1q_u8(otmp, veorq_u8(pt, d0));
+dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
 break;
 }
 }
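These two hunks compute the block address once (`otmp`) and move the counter increment after the store, off the critical path of the AES rounds. The surrounding code keeps the CTR counter `dd` in native-endian 32-bit lanes; a small sketch of that counter discipline (hypothetical helper name, mirroring the cast idiom and the {0,0,0,1} `one` vector visible in the diff):

#include <arm_neon.h>

/* Hypothetical sketch: the counter lives in host-order 32-bit lanes and
 * is bumped with a plain integer add; vrev32q_u8 byte-swaps each lane to
 * yield the big-endian counter block actually fed to the AES rounds. */
static inline uint8x16_t next_ctr_block(uint8x16_t *dd)
{
	const uint32x4_t one = {0, 0, 0, 1};
	uint8x16_t block = vrev32q_u8(*dd);                /* big-endian view */
	*dd = (uint8x16_t)vaddq_u32((uint32x4_t)*dd, one); /* bump low 32-bit word */
	return block;
}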
@@ -898,23 +901,18 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 _len = totalLen + len;
 if (likely(len >= 64)) {
-const uint32x4_t four = {0,0,0,4};
+const uint32x4_t four = vshlq_n_u32(one, 2);
 uint8x16_t dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
 uint8x16_t dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, one);
 uint8x16_t dd3 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, one);
 for (;;) {
 len -= 64;
-uint8x16_t pt0 = vld1q_u8(in);
-uint8x16_t pt1 = vld1q_u8(in + 16);
-uint8x16_t pt2 = vld1q_u8(in + 32);
-uint8x16_t pt3 = vld1q_u8(in + 48);
-in += 64;
 uint8x16_t d0 = vrev32q_u8(dd);
 uint8x16_t d1 = vrev32q_u8(dd1);
 uint8x16_t d2 = vrev32q_u8(dd2);
 uint8x16_t d3 = vrev32q_u8(dd3);
+uint8x16_t pt0 = vld1q_u8(in);
+in += 16;
 d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
 d1 = vaesmcq_u8(vaeseq_u8(d1, k0));
 d2 = vaesmcq_u8(vaeseq_u8(d2, k0));
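Two changes here: the `{0,0,0,4}` literal becomes `vshlq_n_u32(one, 2)`, deriving the stride-4 increment from the `one` vector already in a register rather than materializing a second constant, and the four back-to-back plaintext loads are broken up, with the first load moved after the counter byte-reversals and the rest staggered through the rounds in the hunks below. A tiny standalone sketch of the constant derivation:

#include <arm_neon.h>

/* Sketch: shifting each 32-bit lane of {0,0,0,1} left by 2 yields
 * {0,0,0,4}, so the stride-4 increment reuses the existing 'one' value. */
static inline uint32x4_t make_four(uint32x4_t one)
{
	return vshlq_n_u32(one, 2);
}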
@@ -927,6 +925,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 d1 = vaesmcq_u8(vaeseq_u8(d1, k2));
 d2 = vaesmcq_u8(vaeseq_u8(d2, k2));
 d3 = vaesmcq_u8(vaeseq_u8(d3, k2));
+uint8x16_t pt1 = vld1q_u8(in);
+in += 16;
 d0 = vaesmcq_u8(vaeseq_u8(d0, k3));
 d1 = vaesmcq_u8(vaeseq_u8(d1, k3));
 d2 = vaesmcq_u8(vaeseq_u8(d2, k3));
@@ -939,6 +939,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 d1 = vaesmcq_u8(vaeseq_u8(d1, k5));
 d2 = vaesmcq_u8(vaeseq_u8(d2, k5));
 d3 = vaesmcq_u8(vaeseq_u8(d3, k5));
+uint8x16_t pt2 = vld1q_u8(in);
+in += 16;
 d0 = vaesmcq_u8(vaeseq_u8(d0, k6));
 d1 = vaesmcq_u8(vaeseq_u8(d1, k6));
 d2 = vaesmcq_u8(vaeseq_u8(d2, k6));
@@ -951,6 +953,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 d1 = vaesmcq_u8(vaeseq_u8(d1, k8));
 d2 = vaesmcq_u8(vaeseq_u8(d2, k8));
 d3 = vaesmcq_u8(vaeseq_u8(d3, k8));
+uint8x16_t pt3 = vld1q_u8(in);
+in += 16;
 d0 = vaesmcq_u8(vaeseq_u8(d0, k9));
 d1 = vaesmcq_u8(vaeseq_u8(d1, k9));
 d2 = vaesmcq_u8(vaeseq_u8(d2, k9));
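With this hunk the restructuring of the 4-block loop is complete: the four 16-byte loads that used to be issued back-to-back at the top of the loop are now interleaved one at a time between batches of rounds (after the vrev32q_u8 lines and after the k2, k5, and k8 rounds), so each load's latency overlaps with vaese/vaesmc work already in flight instead of stalling up front.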
@@ -984,7 +988,7 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 out += 64;
 dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, four);
-if (len < 64)
+if (unlikely(len < 64))
 break;
 dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, four);
 dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, four);
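`unlikely()` marks the loop exit as the cold path so the compiler keeps the loop body on the fall-through path. A sketch of the usual GCC/Clang definitions these macros wrap (not the verbatim ZeroTier source, which defines them along these lines):

/* Sketch of the common likely()/unlikely() hints: __builtin_expect tells
 * the compiler which way the branch usually goes. */
#ifndef likely
#define likely(x) __builtin_expect(!!(x), 1)
#endif
#ifndef unlikely
#define unlikely(x) __builtin_expect(!!(x), 0)
#endif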
@@ -994,9 +998,9 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 while (len >= 16) {
 len -= 16;
+uint8x16_t d0 = vrev32q_u8(dd);
 uint8x16_t pt = vld1q_u8(in);
 in += 16;
-uint8x16_t d0 = vrev32q_u8(dd);
 dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
 d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
 d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
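For context, the single-block pattern this loop and the tail path above both use: AES-256 takes fifteen round keys, so thirteen AESE/AESMC pairs with k0..k12, then a final AESE with k13 whose omitted MixColumns and last AddRoundKey collapse into a plain XOR with k14. A compact sketch with the round keys in an array (hypothetical helper; the real code unrolls with k0..k14 in individual registers):

#include <arm_neon.h>

/* Hypothetical sketch of one AES-256 block: vaeseq_u8 performs
 * AddRoundKey + SubBytes + ShiftRows, vaesmcq_u8 performs MixColumns.
 * The last round skips MixColumns, and the final AddRoundKey becomes a
 * plain XOR with k[14]. */
static inline uint8x16_t aes256_encrypt_block(uint8x16_t d0, const uint8x16_t k[15])
{
	for (int r = 0; r < 13; ++r)
		d0 = vaesmcq_u8(vaeseq_u8(d0, k[r]));
	return veorq_u8(vaeseq_u8(d0, k[13]), k[14]);
}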