From 7efaab2af1a0880976f82dad590f0e577fd31107 Mon Sep 17 00:00:00 2001
From: Adam Ierymenko
Date: Thu, 30 Jul 2020 04:17:01 +0000
Subject: [PATCH] Add 4X parallel ARM AES so VTEC will kick in, yo. Seems to
 help on Graviton, not much on small chips but that's okay.

---
 core/AES.cpp | 180 ++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 134 insertions(+), 46 deletions(-)

diff --git a/core/AES.cpp b/core/AES.cpp
index 12a20736c..745940b4f 100644
--- a/core/AES.cpp
+++ b/core/AES.cpp
@@ -841,7 +841,7 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 #ifdef ZT_AES_NEON
 	if (Utils::ARMCAP.aes) {
-		uint8x16_t dd = vld1q_u8(reinterpret_cast<uint8_t *>(_ctr));
+		uint8x16_t dd = vrev32q_u8(vld1q_u8(reinterpret_cast<uint8_t *>(_ctr)));
 		const uint32x4_t one = {0,0,0,1};
 		uint8x16_t k0 = _aes._k.neon.ek[0];
@@ -864,36 +864,31 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 		if ((totalLen & 15U)) {
 			for (;;) {
 				if (unlikely(!len)) {
-					vst1q_u8(reinterpret_cast<uint8_t *>(_ctr), dd);
+					vst1q_u8(reinterpret_cast<uint8_t *>(_ctr), vrev32q_u8(dd));
 					_len = totalLen;
 					return;
 				}
 				--len;
 				out[totalLen++] = *(in++);
 				if (!(totalLen & 15U)) {
-					uint8x16_t tmp = dd;
-					dd = vrev32q_u8(dd);
+					uint8x16_t pt = vld1q_u8(out + (totalLen - 16));
+					uint8x16_t d0 = vrev32q_u8(dd);
 					dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
-					dd = vrev32q_u8(dd);
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k0));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k1));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k2));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k3));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k4));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k5));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k6));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k7));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k8));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k9));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k10));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k11));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k12));
-					tmp = veorq_u8(vaeseq_u8(tmp, k13), k14);
-					uint8x16_t pt = vld1q_u8(reinterpret_cast<const uint8_t *>(out + (totalLen - 16)));
-					vst1q_u8(reinterpret_cast<uint8_t *>(out + (totalLen - 16)), veorq_u8(pt, tmp));
-					//__m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16));
-					//const __m128i p0 = _mm_loadu_si128(outblk);
-					//_mm_storeu_si128(outblk, _mm_xor_si128(p0, d0));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k2));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k3));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k4));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k5));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k6));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k7));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k8));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k9));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k10));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k11));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k12));
+					d0 = veorq_u8(vaeseq_u8(d0, k13), k14);
+					vst1q_u8(out + (totalLen - 16), veorq_u8(pt, d0));
 					break;
 				}
 			}
@@ -902,29 +897,122 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 		out += totalLen;
 		_len = totalLen + len;
+		if (len >= 64) {
+			const uint32x4_t four = {0,0,0,4};
+			uint8x16_t dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
+			uint8x16_t dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, one);
+			uint8x16_t dd3 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, one);
+			for (;;) {
+				len -= 64;
+				uint8x16_t pt0 = vld1q_u8(in);
+				uint8x16_t pt1 = vld1q_u8(in + 16);
+				uint8x16_t pt2 = vld1q_u8(in + 32);
+				uint8x16_t pt3 = vld1q_u8(in + 48);
+				in += 64;
+
+				uint8x16_t d0 = vrev32q_u8(dd);
+				uint8x16_t d1 = vrev32q_u8(dd1);
+				uint8x16_t d2 = vrev32q_u8(dd2);
+				uint8x16_t d3 = vrev32q_u8(dd3);
+
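+				// All 14 AES-256 rounds run over the four counter blocks in lockstep:
+				// each round key is applied to d0..d3 before moving to the next, so the
+				// four AESE/AESMC dependency chains are independent and can overlap in
+				// a pipelined or multi-unit AES core.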
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k0));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k0));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k0));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k1));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k1));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k1));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k2));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k2));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k2));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k2));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k3));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k3));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k3));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k3));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k4));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k4));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k4));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k4));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k5));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k5));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k5));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k5));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k6));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k6));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k6));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k6));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k7));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k7));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k7));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k7));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k8));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k8));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k8));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k8));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k9));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k9));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k9));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k9));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k10));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k10));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k10));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k10));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k11));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k11));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k11));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k11));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k12));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k12));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k12));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k12));
+				d0 = veorq_u8(vaeseq_u8(d0, k13), k14);
+				d1 = veorq_u8(vaeseq_u8(d1, k13), k14);
+				d2 = veorq_u8(vaeseq_u8(d2, k13), k14);
+				d3 = veorq_u8(vaeseq_u8(d3, k13), k14);
+
+				d0 = veorq_u8(pt0, d0);
+				d1 = veorq_u8(pt1, d1);
+				d2 = veorq_u8(pt2, d2);
+				d3 = veorq_u8(pt3, d3);
+
+				vst1q_u8(out, d0);
+				vst1q_u8(out + 16, d1);
+				vst1q_u8(out + 32, d2);
+				vst1q_u8(out + 48, d3);
+				out += 64;
+
+				dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, four);
+				if (len < 64)
+					break;
+				dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, four);
+				dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, four);
+				dd3 = (uint8x16_t)vaddq_u32((uint32x4_t)dd3, four);
+			}
+		}
+
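+		// Remaining data (and buffers shorter than 64 bytes) falls through to the
+		// one-block-at-a-time path below, driven by the same counter sequence.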
 		while (len >= 16) {
-			uint8x16_t tmp = dd;
-			dd = vrev32q_u8(dd);
-			dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
-			dd = vrev32q_u8(dd);
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k0));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k1));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k2));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k3));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k4));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k5));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k6));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k7));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k8));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k9));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k10));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k11));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k12));
-			tmp = veorq_u8(vaeseq_u8(tmp, k13), k14);
-			uint8x16_t pt = vld1q_u8(reinterpret_cast<const uint8_t *>(in));
-			vst1q_u8(reinterpret_cast<uint8_t *>(out), veorq_u8(pt, tmp));
-			in += 16;
 			len -= 16;
+			uint8x16_t pt = vld1q_u8(in);
+			in += 16;
+			uint8x16_t d0 = vrev32q_u8(dd);
+			dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k2));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k3));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k4));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k5));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k6));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k7));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k8));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k9));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k10));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k11));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k12));
+			d0 = veorq_u8(vaeseq_u8(d0, k13), k14);
+			vst1q_u8(out, veorq_u8(pt, d0));
 			out += 16;
 		}
@@ -934,7 +1022,7 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 		for (unsigned int i = 0; i < len; ++i)
 			out[i] = in[i];
-		vst1q_u8(reinterpret_cast<uint8_t *>(_ctr), dd);
+		vst1q_u8(reinterpret_cast<uint8_t *>(_ctr), vrev32q_u8(dd));
 		return;
 	}
 #endif // ZT_AES_NEON
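
Note: the hot path above reads more easily as a standalone sketch. The helper below is not part of the patch or of ZeroTier's API; the name ctr_keystream_x4, the rk[15] expanded-key array, and the assumption of a little-endian AArch64 target with a 32-bit big-endian counter word in ctr[12..15] are all illustrative. It shows the same two tricks: keep the counter byte-swapped so a plain 32-bit lane add increments it, and run the AESE/AESMC round chains of four blocks in lockstep.

// Sketch only (not ZeroTier code). Compile with the ARMv8 Crypto Extensions
// enabled, e.g. -march=armv8-a+crypto, on a little-endian AArch64 target.
#include <arm_neon.h>
#include <stdint.h>

// Produce the next four AES-256 CTR keystream blocks and advance the 32-bit
// big-endian counter word assumed to live in ctr[12..15]. rk[] holds the 15
// expanded round keys.
static void ctr_keystream_x4(const uint8x16_t rk[15], uint8_t ctr[16], uint8_t out[64])
{
	// Keep the counter byte-swapped within each 32-bit lane so a plain lane add
	// increments the big-endian counter word (the vrev32q_u8 trick in the patch).
	uint8x16_t d = vrev32q_u8(vld1q_u8(ctr));
	const uint32x4_t one = {0, 0, 0, 1};

	uint8x16_t b[4];
	for (int i = 0; i < 4; ++i) {
		b[i] = vrev32q_u8(d); // back to the real byte order before encrypting
		d = vreinterpretq_u8_u32(vaddq_u32(vreinterpretq_u32_u8(d), one));
	}

	// Thirteen AESE+AESMC rounds, one round key at a time across all four
	// blocks, so the four dependency chains stay independent.
	for (int r = 0; r < 13; ++r) {
		for (int i = 0; i < 4; ++i)
			b[i] = vaesmcq_u8(vaeseq_u8(b[i], rk[r]));
	}

	// Final round: AESE with rk[13], then XOR with the last round key.
	for (int i = 0; i < 4; ++i) {
		b[i] = veorq_u8(vaeseq_u8(b[i], rk[13]), rk[14]);
		vst1q_u8(out + 16 * i, b[i]);
	}

	vst1q_u8(ctr, vrev32q_u8(d)); // counter is now advanced by 4
}

Feeding the same starting counter to a one-block-at-a-time loop and comparing the 64 bytes of keystream is a quick way to sanity-check the interleaved path.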