From 1f02250dd81c822850961b7498265afc33743721 Mon Sep 17 00:00:00 2001
From: Adam Ierymenko
Date: Mon, 24 Feb 2020 11:56:37 -0800
Subject: [PATCH] Ridiculously fast AES-CTR

---
 node/AES.cpp | 232 ++++++++++++++++++++++++++++-----------------------
 1 file changed, 127 insertions(+), 105 deletions(-)

diff --git a/node/AES.cpp b/node/AES.cpp
index 318e4bf52..4b2cbb149 100644
--- a/node/AES.cpp
+++ b/node/AES.cpp
@@ -477,6 +477,15 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 	uint64_t c0 = _ctr[0];
 	uint64_t c1 = Utils::ntoh(_ctr[1]);
 
+	// There are 16 XMM registers. We can reserve six of them for the
+	// first six parts of the expanded AES key.
+	const __m128i k0 = _aes._k.ni.k[0];
+	const __m128i k1 = _aes._k.ni.k[1];
+	const __m128i k2 = _aes._k.ni.k[2];
+	const __m128i k3 = _aes._k.ni.k[3];
+	const __m128i k4 = _aes._k.ni.k[4];
+	const __m128i k5 = _aes._k.ni.k[5];
+
 	// Complete any unfinished blocks from previous calls to crypt().
 	unsigned int totalLen = _len;
 	if ((totalLen & 15U)) {
@@ -491,23 +500,33 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 			out[totalLen++] = *(in++);
 			if (!(totalLen & 15U)) {
 				__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				d0 = _mm_xor_si128(d0,_aes._k.ni.k[0]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[1]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[2]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[3]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[4]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[5]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[6]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[7]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[8]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[9]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[10]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[11]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[12]);
-				d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[13]);
-				d0 = _mm_aesenclast_si128(d0,_aes._k.ni.k[14]);
+				d0 = _mm_xor_si128(d0,k0);
+				d0 = _mm_aesenc_si128(d0,k1);
+				__m128i ka = _aes._k.ni.k[6];
+				d0 = _mm_aesenc_si128(d0,k2);
+				__m128i kb = _aes._k.ni.k[7];
+				d0 = _mm_aesenc_si128(d0,k3);
+				__m128i kc = _aes._k.ni.k[8];
+				d0 = _mm_aesenc_si128(d0,k4);
+				__m128i kd = _aes._k.ni.k[9];
+				d0 = _mm_aesenc_si128(d0,k5);
+				__m128i ke = _aes._k.ni.k[10];
+				d0 = _mm_aesenc_si128(d0,ka);
+				__m128i kf = _aes._k.ni.k[11];
+				d0 = _mm_aesenc_si128(d0,kb);
+				__m128i kg = _aes._k.ni.k[12];
+				d0 = _mm_aesenc_si128(d0,kc);
+				__m128i kh = _aes._k.ni.k[13];
+				d0 = _mm_aesenc_si128(d0,kd);
+				ka = _aes._k.ni.k[14];
+				d0 = _mm_aesenc_si128(d0,ke);
 				__m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16));
-				_mm_storeu_si128(outblk,_mm_xor_si128(_mm_loadu_si128(outblk),d0));
+				d0 = _mm_aesenc_si128(d0,kf);
+				const __m128i p0 = _mm_loadu_si128(outblk);
+				d0 = _mm_aesenc_si128(d0,kg);
+				d0 = _mm_aesenc_si128(d0,kh);
+				d0 = _mm_aesenclast_si128(d0,ka);
+				_mm_storeu_si128(outblk,_mm_xor_si128(p0,d0));
 				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				break;
 			}
@@ -536,10 +555,6 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 			if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 		}
 
-		__m128i k0 = _aes._k.ni.k[0];
-		__m128i k1 = _aes._k.ni.k[1];
-		__m128i k2 = _aes._k.ni.k[2];
-		__m128i k3 = _aes._k.ni.k[3];
 		d0 = _mm_xor_si128(d0,k0);
 		d1 = _mm_xor_si128(d1,k0);
 		d2 = _mm_xor_si128(d2,k0);
@@ -548,82 +563,79 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 		d1 = _mm_aesenc_si128(d1,k1);
 		d2 = _mm_aesenc_si128(d2,k1);
 		d3 = _mm_aesenc_si128(d3,k1);
-		k0 = _aes._k.ni.k[4];
-		k1 = _aes._k.ni.k[5];
+		__m128i ka = _aes._k.ni.k[6];
 		d0 = _mm_aesenc_si128(d0,k2);
 		d1 = _mm_aesenc_si128(d1,k2);
 		d2 = _mm_aesenc_si128(d2,k2);
 		d3 = _mm_aesenc_si128(d3,k2);
+		__m128i kb = _aes._k.ni.k[7];
 		d0 = _mm_aesenc_si128(d0,k3);
 		d1 = _mm_aesenc_si128(d1,k3);
 		d2 = _mm_aesenc_si128(d2,k3);
 		d3 = _mm_aesenc_si128(d3,k3);
-		k2 = _aes._k.ni.k[6];
-		k3 = _aes._k.ni.k[7];
-		d0 = _mm_aesenc_si128(d0,k0);
-		d1 = _mm_aesenc_si128(d1,k0);
-		d2 = _mm_aesenc_si128(d2,k0);
-		d3 = _mm_aesenc_si128(d3,k0);
-		d0 = _mm_aesenc_si128(d0,k1);
-		d1 = _mm_aesenc_si128(d1,k1);
-		d2 = _mm_aesenc_si128(d2,k1);
-		d3 = _mm_aesenc_si128(d3,k1);
-		k0 = _aes._k.ni.k[8];
-		k1 = _aes._k.ni.k[9];
-		d0 = _mm_aesenc_si128(d0,k2);
-		d1 = _mm_aesenc_si128(d1,k2);
-		d2 = _mm_aesenc_si128(d2,k2);
-		d3 = _mm_aesenc_si128(d3,k2);
-		d0 = _mm_aesenc_si128(d0,k3);
-		d1 = _mm_aesenc_si128(d1,k3);
-		d2 = _mm_aesenc_si128(d2,k3);
-		d3 = _mm_aesenc_si128(d3,k3);
-		k2 = _aes._k.ni.k[10];
-		k3 = _aes._k.ni.k[11];
-		d0 = _mm_aesenc_si128(d0,k0);
-		d1 = _mm_aesenc_si128(d1,k0);
-		d2 = _mm_aesenc_si128(d2,k0);
-		d3 = _mm_aesenc_si128(d3,k0);
-		d0 = _mm_aesenc_si128(d0,k1);
-		d1 = _mm_aesenc_si128(d1,k1);
-		d2 = _mm_aesenc_si128(d2,k1);
-		d3 = _mm_aesenc_si128(d3,k1);
-		k0 = _aes._k.ni.k[12];
-		k1 = _aes._k.ni.k[13];
-		d0 = _mm_aesenc_si128(d0,k2);
-		d1 = _mm_aesenc_si128(d1,k2);
-		d2 = _mm_aesenc_si128(d2,k2);
-		d3 = _mm_aesenc_si128(d3,k2);
-		__m128i p0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
-		__m128i p1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 16));
-		d0 = _mm_aesenc_si128(d0,k3);
-		d1 = _mm_aesenc_si128(d1,k3);
-		d2 = _mm_aesenc_si128(d2,k3);
-		d3 = _mm_aesenc_si128(d3,k3);
-		k2 = _aes._k.ni.k[14];
-		d0 = _mm_aesenc_si128(d0,k0);
-		d1 = _mm_aesenc_si128(d1,k0);
-		d2 = _mm_aesenc_si128(d2,k0);
-		d3 = _mm_aesenc_si128(d3,k0);
-		__m128i p2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 32));
-		__m128i p3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 48));
-		d0 = _mm_aesenc_si128(d0,k1);
-		d1 = _mm_aesenc_si128(d1,k1);
-		d2 = _mm_aesenc_si128(d2,k1);
-		d3 = _mm_aesenc_si128(d3,k1);
-		d0 = _mm_aesenclast_si128(d0,k2);
-		d1 = _mm_aesenclast_si128(d1,k2);
-		d2 = _mm_aesenclast_si128(d2,k2);
-		d3 = _mm_aesenclast_si128(d3,k2);
-
-		p0 = _mm_xor_si128(d0,p0);
-		p1 = _mm_xor_si128(d1,p1);
-		p2 = _mm_xor_si128(d2,p2);
-		p3 = _mm_xor_si128(d3,p3);
-		_mm_storeu_si128(reinterpret_cast<__m128i *>(out),p0);
-		_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 16),p1);
-		_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 32),p2);
-		_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 48),p3);
+		__m128i kc = _aes._k.ni.k[8];
+		d0 = _mm_aesenc_si128(d0,k4);
+		d1 = _mm_aesenc_si128(d1,k4);
+		d2 = _mm_aesenc_si128(d2,k4);
+		d3 = _mm_aesenc_si128(d3,k4);
+		__m128i kd = _aes._k.ni.k[9];
+		d0 = _mm_aesenc_si128(d0,k5);
+		d1 = _mm_aesenc_si128(d1,k5);
+		d2 = _mm_aesenc_si128(d2,k5);
+		d3 = _mm_aesenc_si128(d3,k5);
+		__m128i ke = _aes._k.ni.k[10];
+		d0 = _mm_aesenc_si128(d0,ka);
+		d1 = _mm_aesenc_si128(d1,ka);
+		d2 = _mm_aesenc_si128(d2,ka);
+		d3 = _mm_aesenc_si128(d3,ka);
+		__m128i kf = _aes._k.ni.k[11];
+		d0 = _mm_aesenc_si128(d0,kb);
+		d1 = _mm_aesenc_si128(d1,kb);
+		d2 = _mm_aesenc_si128(d2,kb);
+		d3 = _mm_aesenc_si128(d3,kb);
+		ka = _aes._k.ni.k[12];
+		d0 = _mm_aesenc_si128(d0,kc);
+		d1 = _mm_aesenc_si128(d1,kc);
+		d2 = _mm_aesenc_si128(d2,kc);
+		d3 = _mm_aesenc_si128(d3,kc);
+		kb = _aes._k.ni.k[13];
+		d0 = _mm_aesenc_si128(d0,kd);
+		d1 = _mm_aesenc_si128(d1,kd);
+		d2 = _mm_aesenc_si128(d2,kd);
+		d3 = _mm_aesenc_si128(d3,kd);
+		kc = _aes._k.ni.k[14];
+		d0 = _mm_aesenc_si128(d0,ke);
+		d1 = _mm_aesenc_si128(d1,ke);
+		d2 = _mm_aesenc_si128(d2,ke);
+		d3 = _mm_aesenc_si128(d3,ke);
+		kd = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
+		d0 = _mm_aesenc_si128(d0,kf);
+		d1 = _mm_aesenc_si128(d1,kf);
+		d2 = _mm_aesenc_si128(d2,kf);
+		d3 = _mm_aesenc_si128(d3,kf);
+		ke = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 16));
+		d0 = _mm_aesenc_si128(d0,ka);
+		d1 = _mm_aesenc_si128(d1,ka);
+		d2 = _mm_aesenc_si128(d2,ka);
+		d3 = _mm_aesenc_si128(d3,ka);
+		kf = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 32));
+		d0 = _mm_aesenc_si128(d0,kb);
+		d1 = _mm_aesenc_si128(d1,kb);
+		d2 = _mm_aesenc_si128(d2,kb);
+		d3 = _mm_aesenc_si128(d3,kb);
+		ka = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 48));
+		d0 = _mm_aesenclast_si128(d0,kc);
+		d1 = _mm_aesenclast_si128(d1,kc);
+		d2 = _mm_aesenclast_si128(d2,kc);
+		d3 = _mm_aesenclast_si128(d3,kc);
+		kd = _mm_xor_si128(d0,kd);
+		ke = _mm_xor_si128(d1,ke);
+		kf = _mm_xor_si128(d2,kf);
+		ka = _mm_xor_si128(d3,ka);
+		_mm_storeu_si128(reinterpret_cast<__m128i *>(out),kd);
+		_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 16),ke);
+		_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 32),kf);
+		_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 48),ka);
 
 		in += 64;
 		len -= 64;
@@ -632,23 +644,31 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 
 	while (len >= 16) {
 		__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-		d0 = _mm_xor_si128(d0,_aes._k.ni.k[0]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[1]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[2]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[3]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[4]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[5]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[6]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[7]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[8]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[9]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[10]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[11]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[12]);
-		d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[13]);
-		d0 = _mm_aesenclast_si128(d0,_aes._k.ni.k[14]);
-
+		d0 = _mm_xor_si128(d0,k0);
+		d0 = _mm_aesenc_si128(d0,k1);
+		__m128i ka = _aes._k.ni.k[6];
+		d0 = _mm_aesenc_si128(d0,k2);
+		__m128i kb = _aes._k.ni.k[7];
+		d0 = _mm_aesenc_si128(d0,k3);
+		__m128i kc = _aes._k.ni.k[8];
+		d0 = _mm_aesenc_si128(d0,k4);
+		__m128i kd = _aes._k.ni.k[9];
+		d0 = _mm_aesenc_si128(d0,k5);
+		__m128i ke = _aes._k.ni.k[10];
+		d0 = _mm_aesenc_si128(d0,ka);
+		__m128i kf = _aes._k.ni.k[11];
+		d0 = _mm_aesenc_si128(d0,kb);
+		__m128i kg = _aes._k.ni.k[12];
+		d0 = _mm_aesenc_si128(d0,kc);
 		__m128i p0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
+		d0 = _mm_aesenc_si128(d0,kd);
+		__m128i kh = _aes._k.ni.k[13];
+		d0 = _mm_aesenc_si128(d0,ke);
+		ka = _aes._k.ni.k[14];
+		d0 = _mm_aesenc_si128(d0,kf);
+		d0 = _mm_aesenc_si128(d0,kg);
+		d0 = _mm_aesenc_si128(d0,kh);
+		d0 = _mm_aesenclast_si128(d0,ka);
 		p0 = _mm_xor_si128(d0,p0);
 		_mm_storeu_si128(reinterpret_cast<__m128i *>(out),p0);
@@ -678,8 +698,10 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 	unsigned int totalLen = _len;
 	if ((totalLen & 15U)) {
 		for (;;) {
-			if (!len)
+			if (!len) {
+				_len = (totalLen + len);
 				return;
+			}
 			--len;
 			out[totalLen++] = *(in++);
 			if (!(totalLen & 15U)) {
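
---

Note on the technique (commentary, not part of the patch): the rewrite pins the first six expanded round keys in XMM registers for the whole call, stages the remaining round keys into temporaries between encryption rounds, and pushes four counter blocks through AESENC side by side so independent instructions and the plaintext loads overlap the multi-cycle AESENC latency. The sketch below is a minimal, self-contained illustration of that 4-wide AES-256 CTR pattern; the function name ctr4_aesni, the rk[15] round-key array, and the plain little-endian counter increment are illustrative assumptions and not the patched code's internal API (the patched crypt() keeps the low counter word big-endian via Utils::hton, and fully unrolls the rounds).

	// Hypothetical sketch: encrypt/decrypt four consecutive 16-byte CTR blocks
	// with AES-NI, given 15 expanded AES-256 round keys in rk[0..14].
	// Build with AES-NI enabled (e.g. -maes or -march=native).
	#include <immintrin.h>
	#include <stdint.h>

	static void ctr4_aesni(const __m128i rk[15], uint64_t ctrHi, uint64_t ctrLo,
	                       const uint8_t *in, uint8_t *out)
	{
		// Four counter blocks (simple little-endian low word; the real code
		// stores the low word big-endian).
		__m128i d0 = _mm_set_epi64x((long long)(ctrLo),     (long long)ctrHi);
		__m128i d1 = _mm_set_epi64x((long long)(ctrLo + 1), (long long)ctrHi);
		__m128i d2 = _mm_set_epi64x((long long)(ctrLo + 2), (long long)ctrHi);
		__m128i d3 = _mm_set_epi64x((long long)(ctrLo + 3), (long long)ctrHi);

		// Whitening, then the 13 middle rounds of AES-256. Running all four
		// blocks per round gives the CPU independent AESENC instructions to
		// overlap; the patch unrolls this loop and reloads keys in between.
		d0 = _mm_xor_si128(d0, rk[0]);
		d1 = _mm_xor_si128(d1, rk[0]);
		d2 = _mm_xor_si128(d2, rk[0]);
		d3 = _mm_xor_si128(d3, rk[0]);
		for (int r = 1; r <= 13; ++r) {
			const __m128i k = rk[r];
			d0 = _mm_aesenc_si128(d0, k);
			d1 = _mm_aesenc_si128(d1, k);
			d2 = _mm_aesenc_si128(d2, k);
			d3 = _mm_aesenc_si128(d3, k);
		}

		// Load the input while the final round is still in flight, finish with
		// AESENCLAST, then XOR the keystream into the data.
		const __m128i p0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
		const __m128i p1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 16));
		const __m128i p2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 32));
		const __m128i p3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 48));
		d0 = _mm_aesenclast_si128(d0, rk[14]);
		d1 = _mm_aesenclast_si128(d1, rk[14]);
		d2 = _mm_aesenclast_si128(d2, rk[14]);
		d3 = _mm_aesenclast_si128(d3, rk[14]);
		_mm_storeu_si128(reinterpret_cast<__m128i *>(out),      _mm_xor_si128(d0, p0));
		_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 16), _mm_xor_si128(d1, p1));
		_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 32), _mm_xor_si128(d2, p2));
		_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 48), _mm_xor_si128(d3, p3));
	}

Because CTR mode is symmetric, the same routine encrypts and decrypts; correctness of a sketch like this can be checked against any published AES-256-CTR test vector once the counter byte order is matched.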