More AES tweaks

This commit is contained in:
Adam Ierymenko 2020-02-24 13:30:35 -08:00
parent 56bf504ec2
commit 61b72d42b8
No known key found for this signature in database
GPG key ID: C8877CF2D7A5D7F3

View file

@ -477,19 +477,26 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
uint64_t c0 = _ctr[0]; uint64_t c0 = _ctr[0];
uint64_t c1 = Utils::ntoh(_ctr[1]); uint64_t c1 = Utils::ntoh(_ctr[1]);
// There are 16 XMM registers. We can reserve six of them for the // This uses some spare XMM registers to hold some of the key.
// first six parts of the expanded AES key. The rest are used for const __m128i *const k = _aes._k.ni.k;
// other key material, counter, or data depending on the chunk size. const __m128i k0 = k[0];
const __m128i k0 = _aes._k.ni.k[0]; const __m128i k1 = k[1];
const __m128i k1 = _aes._k.ni.k[1]; const __m128i k2 = k[2];
const __m128i k2 = _aes._k.ni.k[2]; const __m128i k3 = k[3];
const __m128i k3 = _aes._k.ni.k[3]; const __m128i k4 = k[4];
const __m128i k4 = _aes._k.ni.k[4]; const __m128i k5 = k[5];
const __m128i k5 = _aes._k.ni.k[5];
// Complete any unfinished blocks from previous calls to crypt(). // Complete any unfinished blocks from previous calls to crypt().
unsigned int totalLen = _len; unsigned int totalLen = _len;
if ((totalLen & 15U)) { if ((totalLen & 15U)) {
const __m128i k7 = k[7];
const __m128i k8 = k[8];
const __m128i k9 = k[9];
const __m128i k10 = k[10];
const __m128i k11 = k[11];
const __m128i k12 = k[12];
const __m128i k13 = k[13];
const __m128i k14 = k[14];
for (;;) { for (;;) {
if (!len) { if (!len) {
_ctr[0] = c0; _ctr[0] = c0;
@ -503,30 +510,21 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0); __m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
d0 = _mm_xor_si128(d0,k0); d0 = _mm_xor_si128(d0,k0);
d0 = _mm_aesenc_si128(d0,k1); d0 = _mm_aesenc_si128(d0,k1);
__m128i ka = _aes._k.ni.k[6];
d0 = _mm_aesenc_si128(d0,k2); d0 = _mm_aesenc_si128(d0,k2);
__m128i kb = _aes._k.ni.k[7];
d0 = _mm_aesenc_si128(d0,k3); d0 = _mm_aesenc_si128(d0,k3);
__m128i kc = _aes._k.ni.k[8];
d0 = _mm_aesenc_si128(d0,k4); d0 = _mm_aesenc_si128(d0,k4);
__m128i kd = _aes._k.ni.k[9];
d0 = _mm_aesenc_si128(d0,k5); d0 = _mm_aesenc_si128(d0,k5);
__m128i ke = _aes._k.ni.k[10]; d0 = _mm_aesenc_si128(d0,k[6]);
d0 = _mm_aesenc_si128(d0,ka); d0 = _mm_aesenc_si128(d0,k7);
__m128i kf = _aes._k.ni.k[11]; d0 = _mm_aesenc_si128(d0,k8);
d0 = _mm_aesenc_si128(d0,kb); d0 = _mm_aesenc_si128(d0,k9);
__m128i kg = _aes._k.ni.k[12]; d0 = _mm_aesenc_si128(d0,k10);
d0 = _mm_aesenc_si128(d0,kc);
__m128i kh = _aes._k.ni.k[13];
d0 = _mm_aesenc_si128(d0,kd);
ka = _aes._k.ni.k[14];
d0 = _mm_aesenc_si128(d0,ke);
__m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16)); __m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16));
d0 = _mm_aesenc_si128(d0,kf); d0 = _mm_aesenc_si128(d0,k11);
const __m128i p0 = _mm_loadu_si128(outblk); const __m128i p0 = _mm_loadu_si128(outblk);
d0 = _mm_aesenc_si128(d0,kg); d0 = _mm_aesenc_si128(d0,k12);
d0 = _mm_aesenc_si128(d0,kh); d0 = _mm_aesenc_si128(d0,k13);
d0 = _mm_aesenclast_si128(d0,ka); d0 = _mm_aesenclast_si128(d0,k14);
_mm_storeu_si128(outblk,_mm_xor_si128(p0,d0)); _mm_storeu_si128(outblk,_mm_xor_si128(p0,d0));
if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL); if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
break; break;
@ -564,47 +562,47 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
d1 = _mm_aesenc_si128(d1,k1); d1 = _mm_aesenc_si128(d1,k1);
d2 = _mm_aesenc_si128(d2,k1); d2 = _mm_aesenc_si128(d2,k1);
d3 = _mm_aesenc_si128(d3,k1); d3 = _mm_aesenc_si128(d3,k1);
__m128i ka = _aes._k.ni.k[6]; __m128i ka = k[6];
d0 = _mm_aesenc_si128(d0,k2); d0 = _mm_aesenc_si128(d0,k2);
d1 = _mm_aesenc_si128(d1,k2); d1 = _mm_aesenc_si128(d1,k2);
d2 = _mm_aesenc_si128(d2,k2); d2 = _mm_aesenc_si128(d2,k2);
d3 = _mm_aesenc_si128(d3,k2); d3 = _mm_aesenc_si128(d3,k2);
__m128i kb = _aes._k.ni.k[7]; __m128i kb = k[7];
d0 = _mm_aesenc_si128(d0,k3); d0 = _mm_aesenc_si128(d0,k3);
d1 = _mm_aesenc_si128(d1,k3); d1 = _mm_aesenc_si128(d1,k3);
d2 = _mm_aesenc_si128(d2,k3); d2 = _mm_aesenc_si128(d2,k3);
d3 = _mm_aesenc_si128(d3,k3); d3 = _mm_aesenc_si128(d3,k3);
__m128i kc = _aes._k.ni.k[8]; __m128i kc = k[8];
d0 = _mm_aesenc_si128(d0,k4); d0 = _mm_aesenc_si128(d0,k4);
d1 = _mm_aesenc_si128(d1,k4); d1 = _mm_aesenc_si128(d1,k4);
d2 = _mm_aesenc_si128(d2,k4); d2 = _mm_aesenc_si128(d2,k4);
d3 = _mm_aesenc_si128(d3,k4); d3 = _mm_aesenc_si128(d3,k4);
__m128i kd = _aes._k.ni.k[9]; __m128i kd = k[9];
d0 = _mm_aesenc_si128(d0,k5); d0 = _mm_aesenc_si128(d0,k5);
d1 = _mm_aesenc_si128(d1,k5); d1 = _mm_aesenc_si128(d1,k5);
d2 = _mm_aesenc_si128(d2,k5); d2 = _mm_aesenc_si128(d2,k5);
d3 = _mm_aesenc_si128(d3,k5); d3 = _mm_aesenc_si128(d3,k5);
__m128i ke = _aes._k.ni.k[10]; __m128i ke = k[10];
d0 = _mm_aesenc_si128(d0,ka); d0 = _mm_aesenc_si128(d0,ka);
d1 = _mm_aesenc_si128(d1,ka); d1 = _mm_aesenc_si128(d1,ka);
d2 = _mm_aesenc_si128(d2,ka); d2 = _mm_aesenc_si128(d2,ka);
d3 = _mm_aesenc_si128(d3,ka); d3 = _mm_aesenc_si128(d3,ka);
__m128i kf = _aes._k.ni.k[11]; __m128i kf = k[11];
d0 = _mm_aesenc_si128(d0,kb); d0 = _mm_aesenc_si128(d0,kb);
d1 = _mm_aesenc_si128(d1,kb); d1 = _mm_aesenc_si128(d1,kb);
d2 = _mm_aesenc_si128(d2,kb); d2 = _mm_aesenc_si128(d2,kb);
d3 = _mm_aesenc_si128(d3,kb); d3 = _mm_aesenc_si128(d3,kb);
ka = _aes._k.ni.k[12]; ka = k[12];
d0 = _mm_aesenc_si128(d0,kc); d0 = _mm_aesenc_si128(d0,kc);
d1 = _mm_aesenc_si128(d1,kc); d1 = _mm_aesenc_si128(d1,kc);
d2 = _mm_aesenc_si128(d2,kc); d2 = _mm_aesenc_si128(d2,kc);
d3 = _mm_aesenc_si128(d3,kc); d3 = _mm_aesenc_si128(d3,kc);
kb = _aes._k.ni.k[13]; kb = k[13];
d0 = _mm_aesenc_si128(d0,kd); d0 = _mm_aesenc_si128(d0,kd);
d1 = _mm_aesenc_si128(d1,kd); d1 = _mm_aesenc_si128(d1,kd);
d2 = _mm_aesenc_si128(d2,kd); d2 = _mm_aesenc_si128(d2,kd);
d3 = _mm_aesenc_si128(d3,kd); d3 = _mm_aesenc_si128(d3,kd);
kc = _aes._k.ni.k[14]; kc = k[14];
d0 = _mm_aesenc_si128(d0,ke); d0 = _mm_aesenc_si128(d0,ke);
d1 = _mm_aesenc_si128(d1,ke); d1 = _mm_aesenc_si128(d1,ke);
d2 = _mm_aesenc_si128(d2,ke); d2 = _mm_aesenc_si128(d2,ke);
@ -644,41 +642,40 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
} }
{ {
__m128i ka = _aes._k.ni.k[6]; const __m128i k7 = k[7];
__m128i kb = _aes._k.ni.k[7]; const __m128i k8 = k[8];
const __m128i kc = _aes._k.ni.k[8]; const __m128i k9 = k[9];
const __m128i kd = _aes._k.ni.k[9]; const __m128i k10 = k[10];
const __m128i ke = _aes._k.ni.k[10]; const __m128i k11 = k[11];
const __m128i kf = _aes._k.ni.k[11]; const __m128i k12 = k[12];
const __m128i kg = _aes._k.ni.k[12]; const __m128i k13 = k[13];
const __m128i kh = _aes._k.ni.k[13]; const __m128i k14 = k[14];
while (len >= 16) { while (len >= 16) {
__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0); __m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1++),(long long)c0);
if (unlikely(c1 == 0)) {
c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
}
d0 = _mm_xor_si128(d0,k0); d0 = _mm_xor_si128(d0,k0);
d0 = _mm_aesenc_si128(d0,k1); d0 = _mm_aesenc_si128(d0,k1);
d0 = _mm_aesenc_si128(d0,k2); d0 = _mm_aesenc_si128(d0,k2);
d0 = _mm_aesenc_si128(d0,k3); d0 = _mm_aesenc_si128(d0,k3);
d0 = _mm_aesenc_si128(d0,k4); d0 = _mm_aesenc_si128(d0,k4);
d0 = _mm_aesenc_si128(d0,k5); d0 = _mm_aesenc_si128(d0,k5);
d0 = _mm_aesenc_si128(d0,ka); d0 = _mm_aesenc_si128(d0,k[6]);
d0 = _mm_aesenc_si128(d0,kb); d0 = _mm_aesenc_si128(d0,k7);
d0 = _mm_aesenc_si128(d0,kc); d0 = _mm_aesenc_si128(d0,k8);
d0 = _mm_aesenc_si128(d0,kd); d0 = _mm_aesenc_si128(d0,k9);
ka = _aes._k.ni.k[14]; d0 = _mm_aesenc_si128(d0,k10);
d0 = _mm_aesenc_si128(d0,ke); d0 = _mm_aesenc_si128(d0,k11);
d0 = _mm_aesenc_si128(d0,kf); d0 = _mm_aesenc_si128(d0,k12);
d0 = _mm_aesenc_si128(d0,kg); d0 = _mm_aesenc_si128(d0,k13);
d0 = _mm_aesenc_si128(d0,kh); d0 = _mm_aesenclast_si128(d0,k14);
kb = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in)); _mm_storeu_si128(reinterpret_cast<__m128i *>(out),_mm_xor_si128(d0,_mm_loadu_si128(reinterpret_cast<const __m128i *>(in))));
d0 = _mm_aesenclast_si128(d0,ka);
kb = _mm_xor_si128(d0,kb);
_mm_storeu_si128(reinterpret_cast<__m128i *>(out),kb);
in += 16; in += 16;
len -= 16; len -= 16;
out += 16; out += 16;
if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
} }
} }