diff --git a/core/AES.cpp b/core/AES.cpp
index 7589be88e..e0f4bb0e9 100644
--- a/core/AES.cpp
+++ b/core/AES.cpp
@@ -22,17 +22,19 @@ namespace {
 
 #ifdef ZT_AES_NEON
 
-ZT_INLINE uint8x16_t s_clmul_armneon_crypto(uint8x16_t a8, const uint8x16_t y, const uint8_t b[16]) noexcept
+ZT_INLINE uint8x16_t s_clmul_armneon_crypto(uint8x16_t h, uint8x16_t y, const uint8_t b[16]) noexcept
 {
-	const uint8x16_t p = vreinterpretq_u8_u64(vdupq_n_u64(0x0000000000000087));
-	const uint8x16_t z = vdupq_n_u8(0);
-	uint8x16_t b8 = vrbitq_u8(veorq_u8(vld1q_u8(b), y));
 	uint8x16_t r0, r1, t0, t1;
-	__asm__ __volatile__("pmull %0.1q, %1.1d, %2.1d \n\t" : "=w" (r0) : "w" (a8), "w" (b8));
-	__asm__ __volatile__("pmull2 %0.1q, %1.2d, %2.2d \n\t" :"=w" (r1) : "w" (a8), "w" (b8));
-	t0 = vextq_u8(b8, b8, 8);
-	__asm__ __volatile__("pmull %0.1q, %1.1d, %2.1d \n\t" : "=w" (t1) : "w" (a8), "w" (t0));
-	__asm__ __volatile__("pmull2 %0.1q, %1.2d, %2.2d \n\t" :"=w" (t0) : "w" (a8), "w" (t0));
+	r0 = vld1q_u8(b);
+	const uint8x16_t z = veorq_u8(h, h);
+	y = veorq_u8(r0, y);
+	y = vrbitq_u8(y);
+	const uint8x16_t p = vreinterpretq_u8_u64(vdupq_n_u64(0x0000000000000087));
+	t0 = vextq_u8(y, y, 8);
+	__asm__ __volatile__("pmull %0.1q, %1.1d, %2.1d \n\t" : "=w" (r0) : "w" (h), "w" (y));
+	__asm__ __volatile__("pmull2 %0.1q, %1.2d, %2.2d \n\t" :"=w" (r1) : "w" (h), "w" (y));
+	__asm__ __volatile__("pmull %0.1q, %1.1d, %2.1d \n\t" : "=w" (t1) : "w" (h), "w" (t0));
+	__asm__ __volatile__("pmull2 %0.1q, %1.2d, %2.2d \n\t" :"=w" (t0) : "w" (h), "w" (t0));
 	t0 = veorq_u8(t0, t1);
 	t1 = vextq_u8(z, t0, 8);
 	r0 = veorq_u8(r0, t1);
@@ -871,9 +873,9 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 			--len;
 			out[totalLen++] = *(in++);
 			if (!(totalLen & 15U)) {
-				uint8x16_t pt = vld1q_u8(out + (totalLen - 16));
+				uint8_t *const otmp = out + (totalLen - 16);
 				uint8x16_t d0 = vrev32q_u8(dd);
-				dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
+				uint8x16_t pt = vld1q_u8(otmp);
 				d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
 				d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
 				d0 = vaesmcq_u8(vaeseq_u8(d0, k2));
@@ -888,7 +890,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 				d0 = vaesmcq_u8(vaeseq_u8(d0, k11));
 				d0 = vaesmcq_u8(vaeseq_u8(d0, k12));
 				d0 = veorq_u8(vaeseq_u8(d0, k13), k14);
-				vst1q_u8(out + (totalLen - 16), veorq_u8(pt, d0));
+				vst1q_u8(otmp, veorq_u8(pt, d0));
+				dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
 				break;
 			}
 		}
@@ -898,23 +901,18 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 	_len = totalLen + len;
 	if (likely(len >= 64)) {
-		const uint32x4_t four = {0,0,0,4};
+		const uint32x4_t four = vshlq_n_u32(one, 2);
 		uint8x16_t dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
 		uint8x16_t dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, one);
 		uint8x16_t dd3 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, one);
 		for (;;) {
 			len -= 64;
-			uint8x16_t pt0 = vld1q_u8(in);
-			uint8x16_t pt1 = vld1q_u8(in + 16);
-			uint8x16_t pt2 = vld1q_u8(in + 32);
-			uint8x16_t pt3 = vld1q_u8(in + 48);
-			in += 64;
-
 			uint8x16_t d0 = vrev32q_u8(dd);
 			uint8x16_t d1 = vrev32q_u8(dd1);
 			uint8x16_t d2 = vrev32q_u8(dd2);
 			uint8x16_t d3 = vrev32q_u8(dd3);
-
+			uint8x16_t pt0 = vld1q_u8(in);
+			in += 16;
 			d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
 			d1 = vaesmcq_u8(vaeseq_u8(d1, k0));
 			d2 = vaesmcq_u8(vaeseq_u8(d2, k0));
 			d3 = vaesmcq_u8(vaeseq_u8(d3, k0));
@@ -927,6 +925,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 			d1 = vaesmcq_u8(vaeseq_u8(d1, k2));
 			d2 = vaesmcq_u8(vaeseq_u8(d2, k2));
 			d3 = vaesmcq_u8(vaeseq_u8(d3, k2));
+			uint8x16_t pt1 = vld1q_u8(in);
+			in += 16;
 			d0 = vaesmcq_u8(vaeseq_u8(d0, k3));
 			d1 = vaesmcq_u8(vaeseq_u8(d1, k3));
 			d2 = vaesmcq_u8(vaeseq_u8(d2, k3));
@@ -939,6 +939,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 			d1 = vaesmcq_u8(vaeseq_u8(d1, k5));
 			d2 = vaesmcq_u8(vaeseq_u8(d2, k5));
 			d3 = vaesmcq_u8(vaeseq_u8(d3, k5));
+			uint8x16_t pt2 = vld1q_u8(in);
+			in += 16;
 			d0 = vaesmcq_u8(vaeseq_u8(d0, k6));
 			d1 = vaesmcq_u8(vaeseq_u8(d1, k6));
 			d2 = vaesmcq_u8(vaeseq_u8(d2, k6));
@@ -951,6 +953,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 			d1 = vaesmcq_u8(vaeseq_u8(d1, k8));
 			d2 = vaesmcq_u8(vaeseq_u8(d2, k8));
 			d3 = vaesmcq_u8(vaeseq_u8(d3, k8));
+			uint8x16_t pt3 = vld1q_u8(in);
+			in += 16;
 			d0 = vaesmcq_u8(vaeseq_u8(d0, k9));
 			d1 = vaesmcq_u8(vaeseq_u8(d1, k9));
 			d2 = vaesmcq_u8(vaeseq_u8(d2, k9));
@@ -984,7 +988,7 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 
 			out += 64;
 			dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, four);
-			if (len < 64)
+			if (unlikely(len < 64))
 				break;
 			dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, four);
 			dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, four);
@@ -994,9 +998,9 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 
 	while (len >= 16) {
 		len -= 16;
+		uint8x16_t d0 = vrev32q_u8(dd);
 		uint8x16_t pt = vld1q_u8(in);
 		in += 16;
-		uint8x16_t d0 = vrev32q_u8(dd);
 		dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
 		d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
 		d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
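Note on the CTR hunks above: the pt0..pt3 plaintext loads are no longer issued as one block up front; each `vld1q_u8` is pushed down between AES round groups so the loads overlap the dependent `vaeseq_u8`/`vaesmcq_u8` chains, the counter bump in the tail path moves after the store, and `four` is now derived from `one` via `vshlq_n_u32`. The sketch below is a minimal single-block illustration of the pattern being reordered, not the ZeroTier implementation: `ctr_block_sketch` and the round-key array `k` are hypothetical names, AES-256 (13 aese+aesmc rounds plus a final round) is assumed, and it requires an AArch64 build with the crypto extensions (i.e. the `ZT_AES_NEON` path).

```cpp
#include <arm_neon.h>
#include <cstdint>

// Minimal sketch, not the ZeroTier code: one 16-byte block of AES-256-CTR with
// the same intrinsics the diff reorders. 'k' is a hypothetical array of the 15
// expanded round keys; 'dd' is the counter block kept with its 32-bit words in
// host order so it can be bumped with a plain vector add (cf. 'one'/'four').
static inline void ctr_block_sketch(const uint8_t *in, uint8_t *out,
                                    uint8x16_t &dd, const uint8x16_t k[15])
{
	const uint32x4_t one = {0, 0, 0, 1};
	uint8x16_t d0 = vrev32q_u8(dd);       // byte-swap each word to the big-endian form AES encrypts
	uint8x16_t pt = vld1q_u8(in);         // load plaintext early so the load overlaps the AES rounds
	dd = vreinterpretq_u8_u32(            // bump the counter word for the next block
	    vaddq_u32(vreinterpretq_u32_u8(dd), one));
	for (int r = 0; r < 13; ++r)          // AES-256: 13 aese+aesmc rounds with k[0]..k[12]
		d0 = vaesmcq_u8(vaeseq_u8(d0, k[r]));
	d0 = veorq_u8(vaeseq_u8(d0, k[13]), k[14]); // final round, then XOR in the last round key
	vst1q_u8(out, veorq_u8(pt, d0));      // keystream XOR plaintext -> ciphertext
}
```

The 64-byte loop in the diff is essentially four of these interleaved, with the pt1/pt2/pt3 loads staggered between round groups rather than clustered ahead of them.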