From 61b72d42b883aacc0db04e71507a3610fdf465d4 Mon Sep 17 00:00:00 2001
From: Adam Ierymenko
Date: Mon, 24 Feb 2020 13:30:35 -0800
Subject: [PATCH] More AES tweaks

---
 node/AES.cpp | 117 +++++++++++++++++++++++++--------------------------
 1 file changed, 57 insertions(+), 60 deletions(-)

diff --git a/node/AES.cpp b/node/AES.cpp
index 3df1474a4..e27fe904b 100644
--- a/node/AES.cpp
+++ b/node/AES.cpp
@@ -477,19 +477,26 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 	uint64_t c0 = _ctr[0];
 	uint64_t c1 = Utils::ntoh(_ctr[1]);
 
-	// There are 16 XMM registers. We can reserve six of them for the
-	// first six parts of the expanded AES key. The rest are used for
-	// other key material, counter, or data depending on the chunk size.
-	const __m128i k0 = _aes._k.ni.k[0];
-	const __m128i k1 = _aes._k.ni.k[1];
-	const __m128i k2 = _aes._k.ni.k[2];
-	const __m128i k3 = _aes._k.ni.k[3];
-	const __m128i k4 = _aes._k.ni.k[4];
-	const __m128i k5 = _aes._k.ni.k[5];
+	// This uses some spare XMM registers to hold some of the key.
+	const __m128i *const k = _aes._k.ni.k;
+	const __m128i k0 = k[0];
+	const __m128i k1 = k[1];
+	const __m128i k2 = k[2];
+	const __m128i k3 = k[3];
+	const __m128i k4 = k[4];
+	const __m128i k5 = k[5];
 
 	// Complete any unfinished blocks from previous calls to crypt().
 	unsigned int totalLen = _len;
 	if ((totalLen & 15U)) {
+		const __m128i k7 = k[7];
+		const __m128i k8 = k[8];
+		const __m128i k9 = k[9];
+		const __m128i k10 = k[10];
+		const __m128i k11 = k[11];
+		const __m128i k12 = k[12];
+		const __m128i k13 = k[13];
+		const __m128i k14 = k[14];
 		for (;;) {
 			if (!len) {
 				_ctr[0] = c0;
@@ -503,30 +510,21 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 			__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
 			d0 = _mm_xor_si128(d0,k0);
 			d0 = _mm_aesenc_si128(d0,k1);
-			__m128i ka = _aes._k.ni.k[6];
 			d0 = _mm_aesenc_si128(d0,k2);
-			__m128i kb = _aes._k.ni.k[7];
 			d0 = _mm_aesenc_si128(d0,k3);
-			__m128i kc = _aes._k.ni.k[8];
 			d0 = _mm_aesenc_si128(d0,k4);
-			__m128i kd = _aes._k.ni.k[9];
 			d0 = _mm_aesenc_si128(d0,k5);
-			__m128i ke = _aes._k.ni.k[10];
-			d0 = _mm_aesenc_si128(d0,ka);
-			__m128i kf = _aes._k.ni.k[11];
-			d0 = _mm_aesenc_si128(d0,kb);
-			__m128i kg = _aes._k.ni.k[12];
-			d0 = _mm_aesenc_si128(d0,kc);
-			__m128i kh = _aes._k.ni.k[13];
-			d0 = _mm_aesenc_si128(d0,kd);
-			ka = _aes._k.ni.k[14];
-			d0 = _mm_aesenc_si128(d0,ke);
+			d0 = _mm_aesenc_si128(d0,k[6]);
+			d0 = _mm_aesenc_si128(d0,k7);
+			d0 = _mm_aesenc_si128(d0,k8);
+			d0 = _mm_aesenc_si128(d0,k9);
+			d0 = _mm_aesenc_si128(d0,k10);
 			__m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16));
-			d0 = _mm_aesenc_si128(d0,kf);
+			d0 = _mm_aesenc_si128(d0,k11);
 			const __m128i p0 = _mm_loadu_si128(outblk);
-			d0 = _mm_aesenc_si128(d0,kg);
-			d0 = _mm_aesenc_si128(d0,kh);
-			d0 = _mm_aesenclast_si128(d0,ka);
+			d0 = _mm_aesenc_si128(d0,k12);
+			d0 = _mm_aesenc_si128(d0,k13);
+			d0 = _mm_aesenclast_si128(d0,k14);
 			_mm_storeu_si128(outblk,_mm_xor_si128(p0,d0));
 			if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 			break;
@@ -564,47 +562,47 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 			d1 = _mm_aesenc_si128(d1,k1);
 			d2 = _mm_aesenc_si128(d2,k1);
 			d3 = _mm_aesenc_si128(d3,k1);
-			__m128i ka = _aes._k.ni.k[6];
+			__m128i ka = k[6];
 			d0 = _mm_aesenc_si128(d0,k2);
 			d1 = _mm_aesenc_si128(d1,k2);
 			d2 = _mm_aesenc_si128(d2,k2);
 			d3 = _mm_aesenc_si128(d3,k2);
-			__m128i kb = _aes._k.ni.k[7];
+			__m128i kb = k[7];
 			d0 = _mm_aesenc_si128(d0,k3);
 			d1 = _mm_aesenc_si128(d1,k3);
 			d2 = _mm_aesenc_si128(d2,k3);
 			d3 = _mm_aesenc_si128(d3,k3);
-			__m128i kc = _aes._k.ni.k[8];
+			__m128i kc = k[8];
 			d0 = _mm_aesenc_si128(d0,k4);
 			d1 = _mm_aesenc_si128(d1,k4);
 			d2 = _mm_aesenc_si128(d2,k4);
 			d3 = _mm_aesenc_si128(d3,k4);
-			__m128i kd = _aes._k.ni.k[9];
+			__m128i kd = k[9];
 			d0 = _mm_aesenc_si128(d0,k5);
 			d1 = _mm_aesenc_si128(d1,k5);
 			d2 = _mm_aesenc_si128(d2,k5);
 			d3 = _mm_aesenc_si128(d3,k5);
-			__m128i ke = _aes._k.ni.k[10];
+			__m128i ke = k[10];
 			d0 = _mm_aesenc_si128(d0,ka);
 			d1 = _mm_aesenc_si128(d1,ka);
 			d2 = _mm_aesenc_si128(d2,ka);
 			d3 = _mm_aesenc_si128(d3,ka);
-			__m128i kf = _aes._k.ni.k[11];
+			__m128i kf = k[11];
 			d0 = _mm_aesenc_si128(d0,kb);
 			d1 = _mm_aesenc_si128(d1,kb);
 			d2 = _mm_aesenc_si128(d2,kb);
 			d3 = _mm_aesenc_si128(d3,kb);
-			ka = _aes._k.ni.k[12];
+			ka = k[12];
 			d0 = _mm_aesenc_si128(d0,kc);
 			d1 = _mm_aesenc_si128(d1,kc);
 			d2 = _mm_aesenc_si128(d2,kc);
 			d3 = _mm_aesenc_si128(d3,kc);
-			kb = _aes._k.ni.k[13];
+			kb = k[13];
 			d0 = _mm_aesenc_si128(d0,kd);
 			d1 = _mm_aesenc_si128(d1,kd);
 			d2 = _mm_aesenc_si128(d2,kd);
 			d3 = _mm_aesenc_si128(d3,kd);
-			kc = _aes._k.ni.k[14];
+			kc = k[14];
 			d0 = _mm_aesenc_si128(d0,ke);
 			d1 = _mm_aesenc_si128(d1,ke);
 			d2 = _mm_aesenc_si128(d2,ke);
@@ -644,41 +642,40 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 	}
 
 	{
-		__m128i ka = _aes._k.ni.k[6];
-		__m128i kb = _aes._k.ni.k[7];
-		const __m128i kc = _aes._k.ni.k[8];
-		const __m128i kd = _aes._k.ni.k[9];
-		const __m128i ke = _aes._k.ni.k[10];
-		const __m128i kf = _aes._k.ni.k[11];
-		const __m128i kg = _aes._k.ni.k[12];
-		const __m128i kh = _aes._k.ni.k[13];
+		const __m128i k7 = k[7];
+		const __m128i k8 = k[8];
+		const __m128i k9 = k[9];
+		const __m128i k10 = k[10];
+		const __m128i k11 = k[11];
+		const __m128i k12 = k[12];
+		const __m128i k13 = k[13];
+		const __m128i k14 = k[14];
 		while (len >= 16) {
-			__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
+			__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1++),(long long)c0);
+			if (unlikely(c1 == 0)) {
+				c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
+				d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
+			}
 			d0 = _mm_xor_si128(d0,k0);
 			d0 = _mm_aesenc_si128(d0,k1);
 			d0 = _mm_aesenc_si128(d0,k2);
 			d0 = _mm_aesenc_si128(d0,k3);
 			d0 = _mm_aesenc_si128(d0,k4);
 			d0 = _mm_aesenc_si128(d0,k5);
-			d0 = _mm_aesenc_si128(d0,ka);
-			d0 = _mm_aesenc_si128(d0,kb);
-			d0 = _mm_aesenc_si128(d0,kc);
-			d0 = _mm_aesenc_si128(d0,kd);
-			ka = _aes._k.ni.k[14];
-			d0 = _mm_aesenc_si128(d0,ke);
-			d0 = _mm_aesenc_si128(d0,kf);
-			d0 = _mm_aesenc_si128(d0,kg);
-			d0 = _mm_aesenc_si128(d0,kh);
-			kb = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
-			d0 = _mm_aesenclast_si128(d0,ka);
-			kb = _mm_xor_si128(d0,kb);
-			_mm_storeu_si128(reinterpret_cast<__m128i *>(out),kb);
+			d0 = _mm_aesenc_si128(d0,k[6]);
+			d0 = _mm_aesenc_si128(d0,k7);
+			d0 = _mm_aesenc_si128(d0,k8);
+			d0 = _mm_aesenc_si128(d0,k9);
+			d0 = _mm_aesenc_si128(d0,k10);
+			d0 = _mm_aesenc_si128(d0,k11);
+			d0 = _mm_aesenc_si128(d0,k12);
+			d0 = _mm_aesenc_si128(d0,k13);
+			d0 = _mm_aesenclast_si128(d0,k14);
+			_mm_storeu_si128(reinterpret_cast<__m128i *>(out),_mm_xor_si128(d0,_mm_loadu_si128(reinterpret_cast<const __m128i *>(in))));
 			in += 16;
 			len -= 16;
 			out += 16;
-
-			if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 		}
 	}
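
Design note on the patch: x86-64 exposes 16 XMM registers, so hoisting k0-k5 (and, in the single-block paths, k7-k14) into named locals lets the compiler keep the round keys register-resident across blocks instead of reloading them from the key schedule; k[6] is the one round key still read through the pointer per block in those paths. The four-block path interleaves each round across d0-d3, presumably to overlap the latency of independent AESENC instructions. The rewritten tail loop also increments c1 at the point where the counter block is packed and handles the 64-bit wrap into c0 immediately, instead of at the bottom of the iteration.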
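
Note: the per-block pattern the patch settles on (whitening XOR with the first round key, thirteen AESENC rounds, one AESENCLAST, then XOR against the data) is plain AES-256 in CTR mode. Below is a minimal, hypothetical sketch of that pattern, not ZeroTier's API: the ctr_block_xor name, the rk[0..14] parameter, and __builtin_bswap64 standing in for Utils::hton (assuming a little-endian GCC/Clang x86-64 host) are all illustrative; the AES-256 key expansion is omitted. Build with AES-NI enabled (e.g. -maes).

#include <immintrin.h>
#include <stdint.h>

// Hypothetical helper: rk[0..14] is an already-expanded AES-256 key schedule,
// c0be is the big-endian high 64 bits of the counter and c1 the low 64 bits in
// host byte order, mirroring the c0/c1 split in AES::CTR::crypt().
static inline void ctr_block_xor(const __m128i rk[15],uint64_t c0be,uint64_t &c1,const uint8_t *in,uint8_t *out) noexcept
{
	// Pack the 128-bit counter block: c0be goes to the low lane (bytes 0-7),
	// c1 is byte-swapped to big-endian and goes to the high lane (bytes 8-15).
	__m128i d0 = _mm_set_epi64x((long long)__builtin_bswap64(c1),(long long)c0be);

	// AES-256: whitening XOR, 13 full rounds, 1 final round.
	d0 = _mm_xor_si128(d0,rk[0]);
	for (int r=1;r<14;++r)
		d0 = _mm_aesenc_si128(d0,rk[r]);
	d0 = _mm_aesenclast_si128(d0,rk[14]);

	// XOR the keystream block over 16 bytes of input.
	const __m128i p = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
	_mm_storeu_si128(reinterpret_cast<__m128i *>(out),_mm_xor_si128(p,d0));

	// Advance the counter; the caller must carry into c0be when c1 wraps,
	// as the patch does with Utils::hton(Utils::ntoh(c0) + 1ULL).
	++c1;
}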