diff --git a/node/AES.cpp b/node/AES.cpp
index e0e01c092..83c1d1676 100644
--- a/node/AES.cpp
+++ b/node/AES.cpp
@@ -18,22 +18,176 @@ namespace ZeroTier {
 
 // GMAC ---------------------------------------------------------------------------------------------------------------
 
+namespace {
+
+#if (defined(__GNUC__) || defined(__clang)) && (defined(__amd64) || defined(__amd64__) || defined(__x86_64) || defined(__x86_64__) || defined(__AMD64) || defined(__AMD64__) || defined(_M_X64) || defined(__aarch64__))
+
+#if defined(__SIZEOF_INT128__)
+typedef unsigned __int128 uint128_t;
+#else
+typedef unsigned uint128_t __attribute__((mode(TI)));
+#endif
+
+ZT_ALWAYS_INLINE void s_bmul64(const uint64_t x,const uint64_t y,uint64_t &r_high,uint64_t &r_low) noexcept
+{
+	static uint128_t m1 = (uint128_t)0x2108421084210842ULL << 64U | 0x1084210842108421ULL;
+	static uint128_t m2 = (uint128_t)0x4210842108421084ULL << 64U | 0x2108421084210842ULL;
+	static uint128_t m3 = (uint128_t)0x8421084210842108ULL << 64U | 0x4210842108421084ULL;
+	static uint128_t m4 = (uint128_t)0x0842108421084210ULL << 64U | 0x8421084210842108ULL;
+	static uint128_t m5 = (uint128_t)0x1084210842108421ULL << 64U | 0x0842108421084210ULL;
+	uint128_t x1 = x & m1;
+	uint128_t y1 = y & m1;
+	uint128_t x2 = x & m2;
+	uint128_t y2 = y & m2;
+	uint128_t x3 = x & m3;
+	uint128_t y3 = y & m3;
+	uint128_t x4 = x & m4;
+	uint128_t y4 = y & m4;
+	uint128_t x5 = x & m5;
+	uint128_t y5 = y & m5;
+	uint128_t z = (x1 * y1) ^ (x2 * y5) ^ (x3 * y4) ^ (x4 * y3) ^ (x5 * y2);
+	uint128_t r = z & m1;
+	z = (x1 * y2) ^ (x2 * y1) ^ (x3 * y5) ^ (x4 * y4) ^ (x5 * y3);
+	r |= z & m2;
+	z = (x1 * y3) ^ (x2 * y2) ^ (x3 * y1) ^ (x4 * y5) ^ (x5 * y4);
+	r |= z & m3;
+	z = (x1 * y4) ^ (x2 * y3) ^ (x3 * y2) ^ (x4 * y1) ^ (x5 * y5);
+	r |= z & m4;
+	z = (x1 * y5) ^ (x2 * y4) ^ (x3 * y3) ^ (x4 * y2) ^ (x5 * y1);
+	r |= z & m5;
+	r_high = (uint64_t)(r >> 64);
+	r_low = (uint64_t)r;
+}
+
+ZT_ALWAYS_INLINE void s_gfmul(const uint64_t h_high,const uint64_t h_low,uint64_t &y0, uint64_t &y1) noexcept
+{
+	uint64_t z2_low,z2_high,z0_low,z0_high,z1a_low,z1a_high;
+	uint64_t y_high = Utils::ntoh(y0);
+	uint64_t y_low = Utils::ntoh(y1);
+	s_bmul64(y_high,h_high,z2_high,z2_low);
+	s_bmul64(y_low,h_low,z0_high,z0_low);
+	s_bmul64(y_high ^ y_low,h_high ^ h_low,z1a_high,z1a_low);
+	z1a_high ^= z2_high ^ z0_high;
+	z1a_low ^= z2_low ^ z0_low;
+	uint128_t z_high = ((uint128_t)z2_high << 64U) | (z2_low ^ z1a_high);
+	uint128_t z_low = (((uint128_t)z0_high << 64U) | z0_low) ^ (((uint128_t)z1a_low) << 64U);
+	z_high = (z_high << 1U) | (z_low >> 127U);
+	z_low <<= 1U;
+	z_low ^= (z_low << 127U) ^ (z_low << 126U) ^ (z_low << 121U);
+	z_high ^= z_low ^ (z_low >> 1U) ^ (z_low >> 2U) ^ (z_low >> 7U);
+	y1 = Utils::hton((uint64_t)z_high);
+	y0 = Utils::hton((uint64_t)(z_high >> 64U));
+}
+
+#else
+
+ZT_ALWAYS_INLINE void s_bmul32(uint32_t x,uint32_t y,uint32_t &r_high,uint32_t &r_low) noexcept
+{
+	const uint32_t m1 = (uint32_t)0x11111111;
+	const uint32_t m2 = (uint32_t)0x22222222;
+	const uint32_t m4 = (uint32_t)0x44444444;
+	const uint32_t m8 = (uint32_t)0x88888888;
+	uint32_t x0 = x & m1;
+	uint32_t x1 = x & m2;
+	uint32_t x2 = x & m4;
+	uint32_t x3 = x & m8;
+	uint32_t y0 = y & m1;
+	uint32_t y1 = y & m2;
+	uint32_t y2 = y & m4;
+	uint32_t y3 = y & m8;
+	uint64_t z0 = ((uint64_t)x0 * y0) ^ ((uint64_t)x1 * y3) ^ ((uint64_t)x2 * y2) ^ ((uint64_t)x3 * y1);
+	uint64_t z1 = ((uint64_t)x0 * y1) ^ ((uint64_t)x1 * y0) ^ ((uint64_t)x2 * y3) ^ ((uint64_t)x3 * y2);
+	uint64_t z2 = ((uint64_t)x0 * y2) ^ ((uint64_t)x1 * y1) ^ ((uint64_t)x2 * y0) ^ ((uint64_t)x3 * y3);
+	uint64_t z3 = ((uint64_t)x0 * y3) ^ ((uint64_t)x1 * y2) ^ ((uint64_t)x2 * y1) ^ ((uint64_t)x3 * y0);
+	z0 &= ((uint64_t)m1 << 32) | m1;
+	z1 &= ((uint64_t)m2 << 32) | m2;
+	z2 &= ((uint64_t)m4 << 32) | m4;
+	z3 &= ((uint64_t)m8 << 32) | m8;
+	uint64_t z = z0 | z1 | z2 | z3;
+	r_high = (uint32_t)(z >> 32);
+	r_low = (uint32_t)z;
+}
+
+ZT_ALWAYS_INLINE void s_gfmul(const uint64_t h_high,const uint64_t h_low,uint64_t &y0,uint64_t &y1) noexcept
+{
+	uint32_t h_high_h = (uint32_t)(h_high >> 32);
+	uint32_t h_high_l = (uint32_t)h_high;
+	uint32_t h_low_h = (uint32_t)(h_low >> 32);
+	uint32_t h_low_l = (uint32_t)h_low;
+	uint32_t h_highXlow_h = h_high_h ^ h_low_h;
+	uint32_t h_highXlow_l = h_high_l ^ h_low_l;
+	uint64_t y_low = Utils::ntoh(y0);
+	uint64_t y_high = Utils::ntoh(y1);
+	uint32_t ci_low_h = (uint32_t)(y_high >> 32);
+	uint32_t ci_low_l = (uint32_t)y_high;
+	uint32_t ci_high_h = (uint32_t)(y_low >> 32);
+	uint32_t ci_high_l = (uint32_t)y_low;
+	uint32_t ci_highXlow_h = ci_high_h ^ ci_low_h;
+	uint32_t ci_highXlow_l = ci_high_l ^ ci_low_l;
+	uint32_t a_a_h,a_a_l,a_b_h,a_b_l,a_c_h,a_c_l;
+	s_bmul32(ci_high_h,h_high_h,a_a_h,a_a_l);
+	s_bmul32(ci_high_l,h_high_l,a_b_h,a_b_l);
+	s_bmul32(ci_high_h ^ ci_high_l,h_high_h ^ h_high_l,a_c_h,a_c_l);
+	a_c_h ^= a_a_h ^ a_b_h;
+	a_c_l ^= a_a_l ^ a_b_l;
+	a_a_l ^= a_c_h;
+	a_b_h ^= a_c_l;
+	uint32_t b_a_h,b_a_l,b_b_h,b_b_l,b_c_h,b_c_l;
+	s_bmul32(ci_low_h,h_low_h,b_a_h,b_a_l);
+	s_bmul32(ci_low_l,h_low_l,b_b_h,b_b_l);
+	s_bmul32(ci_low_h ^ ci_low_l,h_low_h ^ h_low_l,b_c_h,b_c_l);
+	b_c_h ^= b_a_h ^ b_b_h;
+	b_c_l ^= b_a_l ^ b_b_l;
+	b_a_l ^= b_c_h;
+	b_b_h ^= b_c_l;
+	uint32_t c_a_h,c_a_l,c_b_h,c_b_l,c_c_h,c_c_l;
+	s_bmul32(ci_highXlow_h,h_highXlow_h,c_a_h,c_a_l);
+	s_bmul32(ci_highXlow_l,h_highXlow_l,c_b_h,c_b_l);
+	s_bmul32(ci_highXlow_h ^ ci_highXlow_l, h_highXlow_h ^ h_highXlow_l,c_c_h,c_c_l);
+	c_c_h ^= c_a_h ^ c_b_h;
+	c_c_l ^= c_a_l ^ c_b_l;
+	c_a_l ^= c_c_h;
+	c_b_h ^= c_c_l;
+	c_a_h ^= b_a_h ^ a_a_h;
+	c_a_l ^= b_a_l ^ a_a_l;
+	c_b_h ^= b_b_h ^ a_b_h;
+	c_b_l ^= b_b_l ^ a_b_l;
+	uint64_t z_high_h = ((uint64_t)a_a_h << 32) | a_a_l;
+	uint64_t z_high_l = (((uint64_t)a_b_h << 32) | a_b_l) ^ (((uint64_t)c_a_h << 32) | c_a_l);
+	uint64_t z_low_h = (((uint64_t)b_a_h << 32) | b_a_l) ^ (((uint64_t)c_b_h << 32) | c_b_l);
+	uint64_t z_low_l = ((uint64_t)b_b_h << 32) | b_b_l;
+	z_high_h = z_high_h << 1 | z_high_l >> 63;
+	z_high_l = z_high_l << 1 | z_low_h >> 63;
+	z_low_h = z_low_h << 1 | z_low_l >> 63;
+	z_low_l <<= 1;
+	z_low_h ^= (z_low_l << 63) ^ (z_low_l << 62) ^ (z_low_l << 57);
+	z_high_h ^= z_low_h ^ (z_low_h >> 1) ^ (z_low_h >> 2) ^ (z_low_h >> 7);
+	z_high_l ^= z_low_l ^ (z_low_l >> 1) ^ (z_low_l >> 2) ^ (z_low_l >> 7) ^ (z_low_h << 63) ^ (z_low_h << 62) ^ (z_low_h << 57);
+	y0 = Utils::hton(z_high_h);
+	y1 = Utils::hton(z_high_l);
+}
+
+#endif
+
+} // anonymous namespace
+
 void AES::GMAC::update(const void *const data,unsigned int len) noexcept
 {
+	const uint8_t *in = reinterpret_cast<const uint8_t *>(data);
 	_len += len;
 
 #ifdef ZT_AES_AESNI
 	if (likely(Utils::CPUID.aes)) {
-		const uint8_t *d = reinterpret_cast<const uint8_t *>(data);
 		__m128i y = _mm_loadu_si128(reinterpret_cast<__m128i *>(_y));
 		const __m128i shuf = s_shuf;
+
+		// Handle anything left over from a previous run that wasn't a multiple of 16 bytes.
 		if (_rp) {
 			for(;;) {
 				if (!len)
 					return;
 				--len;
-				_r[_rp++] = *(d++);
+				_r[_rp++] = *(in++);
 				if (_rp == 16) {
 					y = _mult_block_aesni(shuf,_aes._k.ni.h,_mm_xor_si128(y,_mm_loadu_si128(reinterpret_cast<__m128i *>(_r))));
 					break;
@@ -42,12 +196,12 @@
 		}
 
 		while (len >= 64) {
-			__m128i d1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(d));
-			__m128i d2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(d + 16));
-			__m128i d3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(d + 32));
-			__m128i d4 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(d + 48));
+			__m128i d1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
+			__m128i d2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 16));
+			__m128i d3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 32));
+			__m128i d4 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 48));
 
-			d += 64;
+			in += 64;
 			len -= 64;
 
 			// This does 4X parallel mult_block via instruction level parallelism.
@@ -128,18 +282,56 @@ void AES::GMAC::update(const void *const data,unsigned int len) noexcept
 		}
 
 		while (len >= 16) {
-			y = _mult_block_aesni(shuf,_aes._k.ni.h,_mm_xor_si128(y,_mm_loadu_si128(reinterpret_cast<const __m128i *>(d))));
-			d += 16;
+			y = _mult_block_aesni(shuf,_aes._k.ni.h,_mm_xor_si128(y,_mm_loadu_si128(reinterpret_cast<const __m128i *>(in))));
+			in += 16;
 			len -= 16;
 		}
 
+		_mm_storeu_si128(reinterpret_cast<__m128i *>(_y),y);
+
+		// Any overflow is cached for a later run or finish().
 		for(unsigned int i=0;i<len;++i)
-			_r[_rp++] = d[i];
+			_r[_rp++] = in[i];
+
+		return;
 	}
 #endif
+
+	const uint64_t h0 = _aes._k.sw.h[0];
+	const uint64_t h1 = _aes._k.sw.h[1];
+	uint64_t y0 = _y[0];
+	uint64_t y1 = _y[1];
+
+	if (_rp) {
+		for(;;) {
+			if (!len)
+				return;
+			--len;
+			_r[_rp++] = *(in++);
+			if (_rp == 16) {
+				y0 ^= Utils::loadAsIsEndian<uint64_t>(_r);
+				y1 ^= Utils::loadAsIsEndian<uint64_t>(_r + 8);
+				s_gfmul(h0,h1,y0,y1);
+				break;
+			}
+		}
+	}
+
+	while (len >= 16) {
+		y0 ^= Utils::loadAsIsEndian<uint64_t>(in);
+		y1 ^= Utils::loadAsIsEndian<uint64_t>(in + 8);
+		s_gfmul(h0,h1,y0,y1);
+		in += 16;
+		len -= 16;
+	}
+
+	for(unsigned int i=0;i<len;++i)
+		_r[_rp++] = in[i];
+}
 
 void AES::GMAC::finish(uint8_t tag[16]) noexcept
 {
 #ifdef ZT_AES_AESNI
 	if (likely(Utils::CPUID.aes)) {
 		__m128i y = _mm_loadu_si128(reinterpret_cast<__m128i *>(_y));
+
+		// Handle any remaining bytes, padding the last block with zeroes.
 		if (_rp) {
 			while (_rp < 16)
 				_r[_rp++] = 0;
@@ -210,10 +403,39 @@
 		t4 = _mm_xor_si128(t4,t2);
 		t4 = _mm_xor_si128(t4,t3);
 		t4 = _mm_xor_si128(t4,t5);
+		_mm_storeu_si128(reinterpret_cast<__m128i *>(tag),_mm_xor_si128(_mm_shuffle_epi8(t4,s_shuf),encIV));
+		return;
 	}
 #endif
+
+	const uint64_t h0 = _aes._k.sw.h[0];
+	const uint64_t h1 = _aes._k.sw.h[1];
+	uint64_t y0 = _y[0];
+	uint64_t y1 = _y[1];
+
+	if (_rp) {
+		while (_rp < 16)
+			_r[_rp++] = 0;
+		y0 ^= Utils::loadAsIsEndian<uint64_t>(_r);
+		y1 ^= Utils::loadAsIsEndian<uint64_t>(_r + 8);
+		s_gfmul(h0,h1,y0,y1);
+	}
+
+	y0 ^= Utils::hton((uint64_t)_len << 3U);
+	s_gfmul(h0,h1,y0,y1);
+
+	uint64_t iv2[2];
+	for(unsigned int i=0;i<12;++i) ((uint8_t *)iv2)[i] = _iv[i];
+	((uint8_t *)iv2)[12] = 0;
+	((uint8_t *)iv2)[13] = 0;
+	((uint8_t *)iv2)[14] = 0;
+	((uint8_t *)iv2)[15] = 1;
+	_aes._encryptSW((const uint8_t *)iv2,(uint8_t *)iv2);
+
+	Utils::storeAsIsEndian<uint64_t>(tag,iv2[0] ^ y0);
+	Utils::storeAsIsEndian<uint64_t>(tag + 8,iv2[1] ^ y1);
 }
 
 // AES-CTR ------------------------------------------------------------------------------------------------------------
 
@@ -221,7 +443,7 @@ void AES::GMAC::finish(uint8_t tag[16]) noexcept
 void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 {
 	const uint8_t *in = reinterpret_cast<const uint8_t *>(input);
-	uint8_t *out = _out + _len;
+	uint8_t *out = _out;
 
 #ifdef ZT_AES_AESNI
 	if (likely(Utils::CPUID.aes)) {
@@ -230,19 +452,20 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 		_mm_prefetch(in + 128,_MM_HINT_T0);
 
 		uint64_t c0 = _ctr[0];
-		uint64_t c1 = _ctr[1];
+		uint64_t c1 = Utils::ntoh(_ctr[1]);
 
 		// Complete any unfinished blocks from previous calls to crypt().
-		if ((_len & 15U) != 0) {
+		unsigned int totalLen = _len;
+		if ((totalLen & 15U)) {
 			for (;;) {
 				if (!len) {
 					_ctr[0] = c0;
-					_ctr[1] = c1;
+					_ctr[1] = Utils::hton(c1);
 					return;
 				}
 				--len;
-				_out[_len++] = *(in++);
-				if ((_len & 15U) == 0) {
+				out[totalLen++] = *(in++);
+				if (!(totalLen & 15U)) {
 					__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
 					d0 = _mm_xor_si128(d0,_aes._k.ni.k[0]);
 					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[1]);
@@ -259,16 +482,16 @@
 					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[12]);
 					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[13]);
 					d0 = _mm_aesenclast_si128(d0,_aes._k.ni.k[14]);
-					__m128i *const outblk = reinterpret_cast<__m128i *>(_out - 16);
+					__m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16));
 					_mm_storeu_si128(outblk,_mm_xor_si128(_mm_loadu_si128(outblk),d0));
-					c0 += (uint64_t)((++c1) == 0ULL);
-					out += 16;
+					if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 					break;
 				}
 			}
 		}
-		_len += len;
+		out += totalLen;
+		_len = (totalLen + len);
 
 		// This is the largest chunk size that will fit in SSE registers with four
 		// registers left over for round key data and temporaries.
@@ -278,7 +501,7 @@
 			_mm_prefetch(in + 320,_MM_HINT_T0);
 
 			__m128i d0,d1,d2,d3,d4,d5,d6,d7,d8,d9,d10,d11;
-			if (likely((c1 + 12ULL) > c1)) {
+			if (likely(c1 < 0xfffffffffffffff4ULL)) {
				d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
 				d1 = _mm_set_epi64x((long long)Utils::hton(c1 + 1ULL),(long long)c0);
 				d2 = _mm_set_epi64x((long long)Utils::hton(c1 + 2ULL),(long long)c0);
@@ -294,73 +517,99 @@
 				c1 += 12;
 			} else {
 				d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d1 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d2 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d3 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d4 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d5 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d6 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d7 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d8 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d9 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d10 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d11 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 			}
 
 			{
-				__m128i k = _aes._k.ni.k[0];
-				d0 = _mm_xor_si128(d0,k);
-				d1 = _mm_xor_si128(d1,k);
-				d2 = _mm_xor_si128(d2,k);
-				d3 = _mm_xor_si128(d3,k);
-				d4 = _mm_xor_si128(d4,k);
-				d5 = _mm_xor_si128(d5,k);
-				d6 = _mm_xor_si128(d6,k);
-				d7 = _mm_xor_si128(d7,k);
-				d8 = _mm_xor_si128(d8,k);
-				d9 = _mm_xor_si128(d9,k);
-				d10 = _mm_xor_si128(d10,k);
-				d11 = _mm_xor_si128(d11,k);
-				for (int r = 1; r < 14; ++r) {
-					k = _aes._k.ni.k[r];
-					d0 = _mm_aesenc_si128(d0,k);
-					d1 = _mm_aesenc_si128(d1,k);
-					d2 = _mm_aesenc_si128(d2,k);
-					d3 = _mm_aesenc_si128(d3,k);
-					d4 = _mm_aesenc_si128(d4,k);
-					d5 = _mm_aesenc_si128(d5,k);
-					d6 = _mm_aesenc_si128(d6,k);
-					d7 = _mm_aesenc_si128(d7,k);
-					d8 = _mm_aesenc_si128(d8,k);
-					d9 = _mm_aesenc_si128(d9,k);
-					d10 = _mm_aesenc_si128(d10,k);
-					d11 = _mm_aesenc_si128(d11,k);
+				__m128i k0 = _aes._k.ni.k[0];
+				__m128i k1 = _aes._k.ni.k[1];
+				d0 = _mm_xor_si128(d0,k0);
+				d1 = _mm_xor_si128(d1,k0);
+				d2 = _mm_xor_si128(d2,k0);
+				d3 = _mm_xor_si128(d3,k0);
+				d4 = _mm_xor_si128(d4,k0);
+				d5 = _mm_xor_si128(d5,k0);
+				d6 = _mm_xor_si128(d6,k0);
+				d7 = _mm_xor_si128(d7,k0);
+				d8 = _mm_xor_si128(d8,k0);
+				d9 = _mm_xor_si128(d9,k0);
+				d10 = _mm_xor_si128(d10,k0);
+				d11 = _mm_xor_si128(d11,k0);
+				d0 = _mm_aesenc_si128(d0,k1);
+				d1 = _mm_aesenc_si128(d1,k1);
+				d2 = _mm_aesenc_si128(d2,k1);
+				d3 = _mm_aesenc_si128(d3,k1);
+				d4 = _mm_aesenc_si128(d4,k1);
+				d5 = _mm_aesenc_si128(d5,k1);
+				d6 = _mm_aesenc_si128(d6,k1);
+				d7 = _mm_aesenc_si128(d7,k1);
+				d8 = _mm_aesenc_si128(d8,k1);
+				d9 = _mm_aesenc_si128(d9,k1);
+				d10 = _mm_aesenc_si128(d10,k1);
+				d11 = _mm_aesenc_si128(d11,k1);
+				for (int r=2;r<14;r+=2) {
+					k0 = _aes._k.ni.k[r];
+					k1 = _aes._k.ni.k[r+1];
+					d0 = _mm_aesenc_si128(d0,k0);
+					d1 = _mm_aesenc_si128(d1,k0);
+					d2 = _mm_aesenc_si128(d2,k0);
+					d3 = _mm_aesenc_si128(d3,k0);
+					d4 = _mm_aesenc_si128(d4,k0);
+					d5 = _mm_aesenc_si128(d5,k0);
+					d6 = _mm_aesenc_si128(d6,k0);
+					d7 = _mm_aesenc_si128(d7,k0);
+					d8 = _mm_aesenc_si128(d8,k0);
+					d9 = _mm_aesenc_si128(d9,k0);
+					d10 = _mm_aesenc_si128(d10,k0);
+					d11 = _mm_aesenc_si128(d11,k0);
+					d0 = _mm_aesenc_si128(d0,k1);
+					d1 = _mm_aesenc_si128(d1,k1);
+					d2 = _mm_aesenc_si128(d2,k1);
+					d3 = _mm_aesenc_si128(d3,k1);
+					d4 = _mm_aesenc_si128(d4,k1);
+					d5 = _mm_aesenc_si128(d5,k1);
+					d6 = _mm_aesenc_si128(d6,k1);
+					d7 = _mm_aesenc_si128(d7,k1);
+					d8 = _mm_aesenc_si128(d8,k1);
+					d9 = _mm_aesenc_si128(d9,k1);
+					d10 = _mm_aesenc_si128(d10,k1);
+					d11 = _mm_aesenc_si128(d11,k1);
 				}
-				k = _aes._k.ni.k[14];
-				d0 = _mm_aesenclast_si128(d0,k);
-				d1 = _mm_aesenclast_si128(d1,k);
-				d2 = _mm_aesenclast_si128(d2,k);
-				d3 = _mm_aesenclast_si128(d3,k);
-				d4 = _mm_aesenclast_si128(d4,k);
-				d5 = _mm_aesenclast_si128(d5,k);
-				d6 = _mm_aesenclast_si128(d6,k);
-				d7 = _mm_aesenclast_si128(d7,k);
-				d8 = _mm_aesenclast_si128(d8,k);
-				d9 = _mm_aesenclast_si128(d9,k);
-				d10 = _mm_aesenclast_si128(d10,k);
-				d11 = _mm_aesenclast_si128(d11,k);
+				k0 = _aes._k.ni.k[14];
+				d0 = _mm_aesenclast_si128(d0,k0);
+				d1 = _mm_aesenclast_si128(d1,k0);
+				d2 = _mm_aesenclast_si128(d2,k0);
+				d3 = _mm_aesenclast_si128(d3,k0);
+				d4 = _mm_aesenclast_si128(d4,k0);
+				d5 = _mm_aesenclast_si128(d5,k0);
+				d6 = _mm_aesenclast_si128(d6,k0);
+				d7 = _mm_aesenclast_si128(d7,k0);
+				d8 = _mm_aesenclast_si128(d8,k0);
+				d9 = _mm_aesenclast_si128(d9,k0);
+				d10 = _mm_aesenclast_si128(d10,k0);
+				d11 = _mm_aesenclast_si128(d11,k0);
 			}
 
 			{
@@ -376,19 +625,18 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 			_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 16),p1);
 			_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 32),p2);
 			_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 48),p3);
-
 			p0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 64));
 			p1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 80));
 			p2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 96));
 			p3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 112));
-			d0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 128));
-			d1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 144));
-			d2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 160));
-			d3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 176));
 			p0 = _mm_xor_si128(d4,p0);
 			p1 = _mm_xor_si128(d5,p1);
 			p2 = _mm_xor_si128(d6,p2);
 			p3 = _mm_xor_si128(d7,p3);
+			d0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 128));
+			d1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 144));
+			d2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 160));
+			d3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 176));
 			d0 = _mm_xor_si128(d8,d0);
 			d1 = _mm_xor_si128(d9,d1);
 			d2 = _mm_xor_si128(d10,d2);
@@ -410,7 +658,7 @@
 		while (_len >= 64) {
 			__m128i d0,d1,d2,d3;
-			if (likely((c1 + 4ULL) > c1)) {
+			if (likely(c1 < 0xfffffffffffffffcULL)) {
 				d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
 				d1 = _mm_set_epi64x((long long)Utils::hton(c1 + 1ULL),(long long)c0);
 				d2 = _mm_set_epi64x((long long)Utils::hton(c1 + 2ULL),(long long)c0);
@@ -418,33 +666,43 @@
 				c1 += 4;
 			} else {
 				d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d1 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d2 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				d3 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-				c0 += (uint64_t)((++c1) == 0ULL);
+				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 			}
 
 			{
-				__m128i k = _aes._k.ni.k[0];
-				d0 = _mm_xor_si128(d0,k);
-				d1 = _mm_xor_si128(d1,k);
-				d2 = _mm_xor_si128(d2,k);
-				d3 = _mm_xor_si128(d3,k);
-				for (int r = 1; r < 14; ++r) {
-					k = _aes._k.ni.k[r];
-					d0 = _mm_aesenc_si128(d0,k);
-					d1 = _mm_aesenc_si128(d1,k);
-					d2 = _mm_aesenc_si128(d2,k);
-					d3 = _mm_aesenc_si128(d3,k);
+				__m128i k0 = _aes._k.ni.k[0];
+				__m128i k1 = _aes._k.ni.k[1];
+				d0 = _mm_xor_si128(d0,k0);
+				d1 = _mm_xor_si128(d1,k0);
+				d2 = _mm_xor_si128(d2,k0);
+				d3 = _mm_xor_si128(d3,k0);
+				d0 = _mm_aesenc_si128(d0,k1);
+				d1 = _mm_aesenc_si128(d1,k1);
+				d2 = _mm_aesenc_si128(d2,k1);
+				d3 = _mm_aesenc_si128(d3,k1);
+				for (int r=2;r<14;r+=2) {
+					k0 = _aes._k.ni.k[r];
+					k1 = _aes._k.ni.k[r+1];
+					d0 = _mm_aesenc_si128(d0,k0);
+					d1 = _mm_aesenc_si128(d1,k0);
+					d2 = _mm_aesenc_si128(d2,k0);
+					d3 = _mm_aesenc_si128(d3,k0);
+					d0 = _mm_aesenc_si128(d0,k1);
+					d1 = _mm_aesenc_si128(d1,k1);
+					d2 = _mm_aesenc_si128(d2,k1);
+					d3 = _mm_aesenc_si128(d3,k1);
 				}
-				k = _aes._k.ni.k[14];
-				d0 = _mm_aesenclast_si128(d0,k);
-				d1 = _mm_aesenclast_si128(d1,k);
-				d2 = _mm_aesenclast_si128(d2,k);
-				d3 = _mm_aesenclast_si128(d3,k);
+				k0 = _aes._k.ni.k[14];
+				d0 = _mm_aesenclast_si128(d0,k0);
+				d1 = _mm_aesenclast_si128(d1,k0);
+				d2 = _mm_aesenclast_si128(d2,k0);
+				d3 = _mm_aesenclast_si128(d3,k0);
 			}
 
 			__m128i p0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
@@ -467,8 +725,6 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 
 		while (len >= 16) {
 			__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-			c0 += (uint64_t)((++c1) == 0ULL);
-
 			d0 = _mm_xor_si128(d0,_aes._k.ni.k[0]);
 			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[1]);
 			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[2]);
@@ -492,6 +748,8 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 			in += 16;
 			len -= 16;
 			out += 16;
+
+			if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 		}
 
 		// Any remaining input is placed in _out. This will be picked up and crypted
@@ -503,26 +761,68 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 		}
 
 		_ctr[0] = c0;
-		_ctr[1] = c1;
+		_ctr[1] = Utils::hton(c1);
 		return;
 	}
 #endif
+
+	uint8_t keyStream[16];
+
+	unsigned int totalLen = _len;
+	if ((totalLen & 15U)) {
+		for (;;) {
+			if (!len)
+				return;
+			--len;
+			out[totalLen++] = *(in++);
+			if (!(totalLen & 15U)) {
+				_aes._encryptSW(reinterpret_cast<const uint8_t *>(_ctr),keyStream);
+				uint8_t *outblk = out + (totalLen - 16);
+				for(int i=0;i<16;++i)
+					outblk[i] ^= keyStream[i];
+				if (unlikely((_ctr[1] = Utils::hton(Utils::ntoh(_ctr[1]) + 1ULL)) == 0)) _ctr[0] = Utils::hton(Utils::ntoh(_ctr[0]) + 1ULL);
+				break;
+			}
+		}
+	}
+
+	out += totalLen;
+	_len = (totalLen + len);
+
+	while (len >= 16) {
+		_aes._encryptSW(reinterpret_cast<const uint8_t *>(_ctr),keyStream);
+		for(int i=0;i<16;++i)
+			out[i] = in[i] ^ keyStream[i];
+		out += 16;
+		len -= 16;
+		in += 16;
+		if (unlikely((_ctr[1] = Utils::hton(Utils::ntoh(_ctr[1]) + 1ULL)) == 0)) _ctr[0] = Utils::hton(Utils::ntoh(_ctr[0]) + 1ULL);
+	}
+
+	// Any remaining input is placed in _out. This will be picked up and crypted
+	// on subsequent calls to crypt() or finish() as it'll mean _len will not be
+	// an even multiple of 16.
+	while (len) {
+		--len;
+		*(out++) = *(in++);
+	}
 }
 
 void AES::CTR::finish() noexcept
 {
+	const unsigned int rem = _len & 15U;
+
 #ifdef ZT_AES_AESNI
 	if (likely(Utils::CPUID.aes)) {
 		// Encrypt any remaining bytes as indicated by _len not being an even multiple of 16.
-		const unsigned int rem = _len & 15U;
 		if (rem) {
 			uint8_t tmp[16];
-			for (unsigned int i = 0,j = _len - rem; i < rem; ++i) tmp[i] = _out[j];
-			for (unsigned int i = rem; i < 16; ++i) tmp[i] = 0;
+			for (unsigned int i = 0,j = _len - rem;i < rem;++i)
+				tmp[i] = _out[j];
+			for (unsigned int i = rem;i < 16;++i)
+				tmp[i] = 0;
 
-			__m128i d0 = _mm_set_epi64x((long long)Utils::hton(_ctr[1]),(long long)_ctr[0]);
+			__m128i d0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(_ctr));
 			d0 = _mm_xor_si128(d0,_aes._k.ni.k[0]);
 			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[1]);
 			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[2]);
@@ -540,12 +840,23 @@ void AES::CTR::finish() noexcept
 			d0 = _mm_aesenclast_si128(d0,_aes._k.ni.k[14]);
 			_mm_storeu_si128(reinterpret_cast<__m128i *>(tmp),_mm_xor_si128(_mm_loadu_si128(reinterpret_cast<__m128i *>(tmp)),d0));
 
-			for (unsigned int i = 0,j = _len - rem; i < rem; ++i)
+			for (unsigned int i = 0,j = _len - rem;i < rem;++i)
 				_out[j] = tmp[i];
 		}
 
 		return;
 	}
 #endif
+
+	if (rem) {
+		uint8_t tmp[16],keyStream[16];
+		for (unsigned int i = 0,j = _len - rem;i < rem;++i)
+			tmp[i] = _out[j];
+		for (unsigned int i = rem;i < 16;++i)
+			tmp[i] = 0;
+		_aes._encryptSW(reinterpret_cast<const uint8_t *>(_ctr),keyStream);
+		for (unsigned int i = 0,j = _len - rem;i < rem;++i)
+			_out[j] = tmp[i] ^ keyStream[i];
+	}
 }
 
 // Software AES and AES key expansion ---------------------------------------------------------------------------------
 
@@ -770,203 +1081,6 @@ void AES::_decryptSW(const uint8_t in[16],uint8_t out[16]) const noexcept
 	writeuint32_t(out + 12,(Td4[(t3 >> 24)] << 24) ^ (Td4[(t2 >> 16) & 0xff] << 16) ^ (Td4[(t1 >> 8) & 0xff] << 8) ^ (Td4[(t0) & 0xff]) ^ rk[3]);
 }
 
-#if (defined(__GNUC__) || defined(__clang)) && (defined(__amd64) || defined(__amd64__) || defined(__x86_64) || defined(__x86_64__) || defined(__AMD64) || defined(__AMD64__) || defined(_M_X64) || defined(__aarch64__))
-
-#if defined(__SIZEOF_INT128__)
-typedef unsigned __int128 uint128_t;
-#else
-typedef unsigned uint128_t __attribute__((mode(TI)));
-#endif
-
-static ZT_ALWAYS_INLINE void s_bmul64(const uint64_t x,const uint64_t y,uint64_t &r_high,uint64_t &r_low) noexcept
-{
-	static uint128_t m1 = (uint128_t)0x2108421084210842ULL << 64U | 0x1084210842108421ULL;
-	static uint128_t m2 = (uint128_t)0x4210842108421084ULL << 64U | 0x2108421084210842ULL;
-	static uint128_t m3 = (uint128_t)0x8421084210842108ULL << 64U | 0x4210842108421084ULL;
-	static uint128_t m4 = (uint128_t)0x0842108421084210ULL << 64U | 0x8421084210842108ULL;
-	static uint128_t m5 = (uint128_t)0x1084210842108421ULL << 64U | 0x0842108421084210ULL;
-	uint128_t x1 = x & m1;
-	uint128_t y1 = y & m1;
-	uint128_t x2 = x & m2;
-	uint128_t y2 = y & m2;
-	uint128_t x3 = x & m3;
-	uint128_t y3 = y & m3;
-	uint128_t x4 = x & m4;
-	uint128_t y4 = y & m4;
-	uint128_t x5 = x & m5;
-	uint128_t y5 = y & m5;
-	uint128_t z = (x1 * y1) ^ (x2 * y5) ^ (x3 * y4) ^ (x4 * y3) ^ (x5 * y2);
-	uint128_t r = z & m1;
-	z = (x1 * y2) ^ (x2 * y1) ^ (x3 * y5) ^ (x4 * y4) ^ (x5 * y3);
-	r |= z & m2;
-	z = (x1 * y3) ^ (x2 * y2) ^ (x3 * y1) ^ (x4 * y5) ^ (x5 * y4);
-	r |= z & m3;
-	z = (x1 * y4) ^ (x2 * y3) ^ (x3 * y2) ^ (x4 * y1) ^ (x5 * y5);
-	r |= z & m4;
-	z = (x1 * y5) ^ (x2 * y4) ^ (x3 * y3) ^ (x4 * y2) ^ (x5 * y1);
-	r |= z & m5;
-	r_high = (uint64_t)(r >> 64);
-	r_low = (uint64_t)r;
-}
-
-static ZT_ALWAYS_INLINE void s_gfmul(const uint64_t h_high,const uint64_t h_low,uint64_t &y0, uint64_t &y1) noexcept
-{
-	uint64_t z2_low,z2_high,z0_low,z0_high,z1a_low,z1a_high;
-	uint64_t y_high = Utils::ntoh(y0);
-	uint64_t y_low = Utils::ntoh(y1);
-	s_bmul64(y_high,h_high,z2_high,z2_low);
-	s_bmul64(y_low,h_low,z0_high,z0_low);
-	s_bmul64(y_high ^ y_low,h_high ^ h_low,z1a_high,z1a_low);
-	z1a_high ^= z2_high ^ z0_high;
-	z1a_low ^= z2_low ^ z0_low;
-	uint128_t z_high = ((uint128_t)z2_high << 64U) | (z2_low ^ z1a_high);
-	uint128_t z_low = (((uint128_t)z0_high << 64U) | z0_low) ^ (((uint128_t)z1a_low) << 64U);
-	z_high = (z_high << 1U) | (z_low >> 127U);
-	z_low <<= 1U;
-	z_low ^= (z_low << 127U) ^ (z_low << 126U) ^ (z_low << 121U);
-	z_high ^= z_low ^ (z_low >> 1U) ^ (z_low >> 2U) ^ (z_low >> 7U);
-	y1 = Utils::hton((uint64_t)z_high);
-	y0 = Utils::hton((uint64_t)(z_high >> 64U));
-}
-
-#else
-
-static ZT_ALWAYS_INLINE void s_bmul32(uint32_t x,uint32_t y,uint32_t &r_high,uint32_t &r_low) noexcept
-{
-	const uint32_t m1 = (uint32_t)0x11111111;
-	const uint32_t m2 = (uint32_t)0x22222222;
-	const uint32_t m4 = (uint32_t)0x44444444;
-	const uint32_t m8 = (uint32_t)0x88888888;
-	uint32_t x0 = x & m1;
-	uint32_t x1 = x & m2;
-	uint32_t x2 = x & m4;
-	uint32_t x3 = x & m8;
-	uint32_t y0 = y & m1;
-	uint32_t y1 = y & m2;
-	uint32_t y2 = y & m4;
-	uint32_t y3 = y & m8;
-	uint64_t z0 = ((uint64_t)x0 * y0) ^ ((uint64_t)x1 * y3) ^ ((uint64_t)x2 * y2) ^ ((uint64_t)x3 * y1);
-	uint64_t z1 = ((uint64_t)x0 * y1) ^ ((uint64_t)x1 * y0) ^ ((uint64_t)x2 * y3) ^ ((uint64_t)x3 * y2);
-	uint64_t z2 = ((uint64_t)x0 * y2) ^ ((uint64_t)x1 * y1) ^ ((uint64_t)x2 * y0) ^ ((uint64_t)x3 * y3);
-	uint64_t z3 = ((uint64_t)x0 * y3) ^ ((uint64_t)x1 * y2) ^ ((uint64_t)x2 * y1) ^ ((uint64_t)x3 * y0);
-	z0 &= ((uint64_t)m1 << 32) | m1;
-	z1 &= ((uint64_t)m2 << 32) | m2;
-	z2 &= ((uint64_t)m4 << 32) | m4;
-	z3 &= ((uint64_t)m8 << 32) | m8;
-	uint64_t z = z0 | z1 | z2 | z3;
-	r_high = (uint32_t)(z >> 32);
-	r_low = (uint32_t)z;
-}
-
-static ZT_ALWAYS_INLINE void s_gfmul(const uint64_t h_high,const uint64_t h_low,uint64_t &y0,uint64_t &y1) noexcept
-{
-	uint32_t h_high_h = (uint32_t)(h_high >> 32);
-	uint32_t h_high_l = (uint32_t)h_high;
-	uint32_t h_low_h = (uint32_t)(h_low >> 32);
-	uint32_t h_low_l = (uint32_t)h_low;
-	uint32_t h_highXlow_h = h_high_h ^ h_low_h;
-	uint32_t h_highXlow_l = h_high_l ^ h_low_l;
-	uint64_t y_low = Utils::ntoh(y0);
-	uint64_t y_high = Utils::ntoh(y1);
-	uint32_t ci_low_h = (uint32_t)(y_high >> 32);
-	uint32_t ci_low_l = (uint32_t)y_high;
-	uint32_t ci_high_h = (uint32_t)(y_low >> 32);
-	uint32_t ci_high_l = (uint32_t)y_low;
-	uint32_t ci_highXlow_h = ci_high_h ^ ci_low_h;
-	uint32_t ci_highXlow_l = ci_high_l ^ ci_low_l;
-	uint32_t a_a_h,a_a_l,a_b_h,a_b_l,a_c_h,a_c_l;
-	s_bmul32(ci_high_h,h_high_h,a_a_h,a_a_l);
-	s_bmul32(ci_high_l,h_high_l,a_b_h,a_b_l);
-	s_bmul32(ci_high_h ^ ci_high_l,h_high_h ^ h_high_l,a_c_h,a_c_l);
-	a_c_h ^= a_a_h ^ a_b_h;
-	a_c_l ^= a_a_l ^ a_b_l;
-	a_a_l ^= a_c_h;
-	a_b_h ^= a_c_l;
-	uint32_t b_a_h,b_a_l,b_b_h,b_b_l,b_c_h,b_c_l;
-	s_bmul32(ci_low_h,h_low_h,b_a_h,b_a_l);
-	s_bmul32(ci_low_l,h_low_l,b_b_h,b_b_l);
-	s_bmul32(ci_low_h ^ ci_low_l,h_low_h ^ h_low_l,b_c_h,b_c_l);
-	b_c_h ^= b_a_h ^ b_b_h;
-	b_c_l ^= b_a_l ^ b_b_l;
-	b_a_l ^= b_c_h;
-	b_b_h ^= b_c_l;
-	uint32_t c_a_h,c_a_l,c_b_h,c_b_l,c_c_h,c_c_l;
-	s_bmul32(ci_highXlow_h,h_highXlow_h,c_a_h,c_a_l);
-	s_bmul32(ci_highXlow_l,h_highXlow_l,c_b_h,c_b_l);
-	s_bmul32(ci_highXlow_h ^ ci_highXlow_l, h_highXlow_h ^ h_highXlow_l,c_c_h,c_c_l);
-	c_c_h ^= c_a_h ^ c_b_h;
-	c_c_l ^= c_a_l ^ c_b_l;
-	c_a_l ^= c_c_h;
-	c_b_h ^= c_c_l;
-	c_a_h ^= b_a_h ^ a_a_h;
-	c_a_l ^= b_a_l ^ a_a_l;
-	c_b_h ^= b_b_h ^ a_b_h;
-	c_b_l ^= b_b_l ^ a_b_l;
-	uint64_t z_high_h = ((uint64_t)a_a_h << 32) | a_a_l;
-	uint64_t z_high_l = (((uint64_t)a_b_h << 32) | a_b_l) ^ (((uint64_t)c_a_h << 32) | c_a_l);
-	uint64_t z_low_h = (((uint64_t)b_a_h << 32) | b_a_l) ^ (((uint64_t)c_b_h << 32) | c_b_l);
-	uint64_t z_low_l = ((uint64_t)b_b_h << 32) | b_b_l;
-	z_high_h = z_high_h << 1 | z_high_l >> 63;
-	z_high_l = z_high_l << 1 | z_low_h >> 63;
-	z_low_h = z_low_h << 1 | z_low_l >> 63;
-	z_low_l <<= 1;
-	z_low_h ^= (z_low_l << 63) ^ (z_low_l << 62) ^ (z_low_l << 57);
-	z_high_h ^= z_low_h ^ (z_low_h >> 1) ^ (z_low_h >> 2) ^ (z_low_h >> 7);
-	z_high_l ^= z_low_l ^ (z_low_l >> 1) ^ (z_low_l >> 2) ^ (z_low_l >> 7) ^ (z_low_h << 63) ^ (z_low_h << 62) ^ (z_low_h << 57);
-	y0 = Utils::hton(z_high_h);
-	y1 = Utils::hton(z_high_l);
-}
-
-#endif
-
-void AES::_gmacSW(const uint8_t iv[12],const uint8_t *in,unsigned int len,uint8_t out[16]) const noexcept
-{
-	const uint64_t h0 = _k.sw.h[0];
-	const uint64_t h1 = _k.sw.h[1];
-	const uint64_t lpad = Utils::hton((uint64_t)len << 3U);
-	uint64_t y0 = 0,y1 = 0;
-
-	while (len >= 16) {
-#ifdef ZT_NO_UNALIGNED_ACCESS
-		for(unsigned int i=0;i<8;++i) ((uint8_t *)&y0)[i] ^= *(in++);
-		for(unsigned int i=0;i<8;++i) ((uint8_t *)&y1)[i] ^= *(in++);
-#else
-		y0 ^= *((const uint64_t *)in);
-		y1 ^= *((const uint64_t *)(in + 8));
-		in += 16;
-#endif
-		s_gfmul(h0,h1,y0,y1);
-		len -= 16;
-	}
-
-	if (len) {
-		uint64_t last[2] = { 0,0 };
-		for(unsigned int i=0;i<len;++i)
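
The CTR changes in AES.cpp above keep _ctr[1] big-endian in memory and convert it to host order (c1) while working, so the low 64 bits of the 128-bit counter can be incremented arithmetically and any carry propagated into the big-endian high word (c0). The following is an illustrative aside, not part of the patch: a minimal standalone sketch of that carry step, assuming a little-endian host and a GCC/Clang compiler, with __builtin_bswap64 standing in for Utils::hton/Utils::ntoh.

#include <cstdint>
#include <cstdio>

int main()
{
	// 128-bit big-endian counter: ctr[0] = high 64 bits, ctr[1] = low 64 bits.
	uint64_t ctr[2] = { __builtin_bswap64(1ULL), __builtin_bswap64(0xffffffffffffffffULL) };

	uint64_t c0 = ctr[0];                      // left in big-endian form, as in the patch
	uint64_t c1 = __builtin_bswap64(ctr[1]);   // host order so ++ behaves numerically

	if (++c1 == 0ULL)                          // low word wrapped: carry into the high word
		c0 = __builtin_bswap64(__builtin_bswap64(c0) + 1ULL);

	ctr[0] = c0;
	ctr[1] = __builtin_bswap64(c1);
	std::printf("%016llx%016llx\n",
		(unsigned long long)__builtin_bswap64(ctr[0]),
		(unsigned long long)__builtin_bswap64(ctr[1]));
	// Prints 00000000000000020000000000000000, i.e. 0x1ffffffffffffffff + 1.
	return 0;
}
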
diff --git a/node/AES.hpp b/node/AES.hpp
--- a/node/AES.hpp
+++ b/node/AES.hpp
 #if (defined(__amd64) || defined(__amd64__) || defined(__x86_64) || defined(__x86_64__) || defined(__AMD64) || defined(__AMD64__) || defined(_M_X64))
-#include
 #include
 #include
 #include
+#include
 #define ZT_AES_AESNI 1
 #endif
@@ -33,12 +33,37 @@ namespace ZeroTier {
 
 /**
  * AES-256 and pals including GMAC, CTR, etc.
+ *
+ * This includes hardware acceleration for certain processors. The software
+ * mode is fallback and is significantly slower.
  */
 class AES
 {
 public:
+	/**
+	 * @return True if this system has hardware AES acceleration
+	 */
+	static ZT_ALWAYS_INLINE bool accelerated()
+	{
+#ifdef ZT_AES_AESNI
+		return Utils::CPUID.aes;
+#else
+		return false;
+#endif
+	}
+
+	/**
+	 * Create an un-initialized AES instance (must call init() before use)
+	 */
 	ZT_ALWAYS_INLINE AES() noexcept {}
+
+	/**
+	 * Create an AES instance with the given key
+	 *
+	 * @param key 256-bit key
+	 */
 	explicit ZT_ALWAYS_INLINE AES(const uint8_t key[32]) noexcept { this->init(key); }
+
 	ZT_ALWAYS_INLINE ~AES() { Utils::burn(&_k,sizeof(_k)); }
 
 	/**
@@ -102,14 +127,22 @@ public:
 	 */
 	ZT_ALWAYS_INLINE GMAC(const AES &aes) : _aes(aes) {}
 
+	/**
+	 * Reset and initialize for a new GMAC calculation
+	 *
+	 * @param iv 96-bit initialization vector (pad with zeroes if actual IV is shorter)
+	 */
 	ZT_ALWAYS_INLINE void init(const uint8_t iv[12]) noexcept
 	{
 		_rp = 0;
 		_len = 0;
+		// We fill the least significant 32 bits in the _iv field with 1 since in GCM mode
+		// this would hold the counter, but we're not doing GCM. The counter is therefore
+		// always 1.
 #ifdef ZT_AES_AESNI // also implies an x64 processor
 		*reinterpret_cast<uint64_t *>(_iv) = *reinterpret_cast<const uint64_t *>(iv);
 		*reinterpret_cast<uint64_t *>(_iv + 8) = *reinterpret_cast<const uint64_t *>(iv + 8);
-		*reinterpret_cast<uint32_t *>(_iv + 12) = 0x01000000; // 00000001 in big-endian byte order
+		*reinterpret_cast<uint32_t *>(_iv + 12) = 0x01000000; // 0x00000001 in big-endian byte order
 #else
 		for(int i=0;i<12;++i)
 			_iv[i] = iv[i];
@@ -122,8 +155,21 @@
 		_y[1] = 0;
 	}
 
+	/**
+	 * Process data through GMAC
+	 *
+	 * @param data Bytes to process
+	 * @param len Length of input
+	 */
 	void update(const void *data,unsigned int len) noexcept;
 
+	/**
+	 * Process any remaining cached bytes and generate tag
+	 *
+	 * Don't call finish() more than once or you'll get an invalid result.
+	 *
+	 * @param tag 128-bit GMAC tag (can be truncated)
+	 */
 	void finish(uint8_t tag[16]) noexcept;
 
 private:
@@ -149,16 +195,9 @@
 	 * @param iv Unique initialization vector
 	 * @param output Buffer to which to store output (MUST be large enough for total bytes processed!)
 	 */
-	ZT_ALWAYS_INLINE void init(const uint8_t iv[16],void *output) noexcept
+	ZT_ALWAYS_INLINE void init(const uint8_t iv[16],void *const output) noexcept
 	{
-#ifdef ZT_AES_AESNI // also implies an x64 processor
-		_ctr[0] = Utils::ntoh(*reinterpret_cast<const uint64_t *>(iv));
-		_ctr[1] = Utils::ntoh(*reinterpret_cast<const uint64_t *>(iv + 8));
-#else
 		memcpy(_ctr,iv,16);
-		_ctr[0] = Utils::ntoh(_ctr[0]);
-		_ctr[1] = Utils::ntoh(_ctr[1]);
-#endif
 		_out = reinterpret_cast<uint8_t *>(output);
 		_len = 0;
 	}
@@ -173,6 +212,8 @@
 	/**
 	 * Finish any remaining bytes if total bytes processed wasn't a multiple of 16
+	 *
+	 * Don't call more than once for a given stream or data may be corrupted.
 	 */
 	void finish() noexcept;
 
@@ -199,7 +240,6 @@ private:
 	void _initSW(const uint8_t key[32]) noexcept;
 	void _encryptSW(const uint8_t in[16],uint8_t out[16]) const noexcept;
 	void _decryptSW(const uint8_t in[16],uint8_t out[16]) const noexcept;
-	void _gmacSW(const uint8_t iv[12],const uint8_t *in,unsigned int len,uint8_t out[16]) const noexcept;
 
 	union {
 #ifdef ZT_AES_AESNI
@@ -216,7 +256,6 @@
 		} sw;
 	} _k;
-
 
 #ifdef ZT_AES_AESNI
 	static const __m128i s_shuf;
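
As an illustrative aside (not part of the patch), here is a rough usage sketch of the classes declared in AES.hpp above, based only on the signatures visible in this diff. The AES::CTR constructor taking an AES reference is an assumption made by analogy with AES::GMAC and may differ from the real header.

#include <cstdint>
#include "AES.hpp"

using namespace ZeroTier;

// Hypothetical helper: CTR-encrypt msg into ciphertext, then GMAC the ciphertext.
static void ctrEncryptAndMac(const uint8_t key[32],const uint8_t ctrIv[16],const uint8_t gmacIv[12],
	const uint8_t *msg,unsigned int len,uint8_t *ciphertext,uint8_t tag[16])
{
	AES aes(key);                   // expands the 256-bit key (hardware or software path)

	AES::CTR ctr(aes);              // assumed constructor
	ctr.init(ctrIv,ciphertext);     // output buffer must be large enough for all processed bytes
	ctr.crypt(msg,len);             // may be called repeatedly for streaming input
	ctr.finish();                   // flushes a trailing partial block; call only once

	AES::GMAC gmac(aes);
	gmac.init(gmacIv);              // 96-bit IV
	gmac.update(ciphertext,len);    // authenticate the ciphertext
	gmac.finish(tag);               // 128-bit tag; call only once
}
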
diff --git a/node/Utils.cpp b/node/Utils.cpp
index a4627abd0..39f205fed 100644
--- a/node/Utils.cpp
+++ b/node/Utils.cpp
@@ -54,7 +54,7 @@ CPUIDRegisters::CPUIDRegisters()
 	rdrand = ((ecx & (1U << 30U)) != 0);
 	aes = ( ((ecx & (1U << 25U)) != 0) && ((ecx & (1U << 19U)) != 0) && ((ecx & (1U << 1U)) != 0) ); // AES, PCLMUL, SSE4.1
 }
-CPUIDRegisters CPUID;
+const CPUIDRegisters CPUID;
 #endif
 
 const uint64_t ZERO256[4] = { 0,0,0,0 };
@@ -71,8 +71,24 @@ bool secureEq(const void *a,const void *b,unsigned int len) noexcept
 // Crazy hack to force memory to be securely zeroed in spite of the best efforts of optimizing compilers.
 static void _Utils_doBurn(volatile uint8_t *ptr,unsigned int len)
 {
-	volatile uint8_t *const end = ptr + len;
-	while (ptr != end) *(ptr++) = (uint8_t)0;
+#ifndef ZT_NO_UNALIGNED_ACCESS
+	const uint64_t z = 0;
+	while (len >= 32) {
+		*reinterpret_cast<volatile uint64_t *>(ptr) = z;
+		*reinterpret_cast<volatile uint64_t *>(ptr + 8) = z;
+		*reinterpret_cast<volatile uint64_t *>(ptr + 16) = z;
+		*reinterpret_cast<volatile uint64_t *>(ptr + 24) = z;
+		ptr += 32;
+		len -= 32;
+	}
+	while (len >= 8) {
+		*reinterpret_cast<volatile uint64_t *>(ptr) = z;
+		ptr += 8;
+		len -= 8;
+	}
+#endif
+	for(unsigned int i=0;i<len;++i) ptr[i] = (uint8_t)0;
 }
 
diff --git a/node/Utils.hpp b/node/Utils.hpp
--- a/node/Utils.hpp
+++ b/node/Utils.hpp
-template<typename I>
-static ZT_ALWAYS_INLINE I loadAsIsEndian(const void *const p) noexcept
-{
-#ifdef ZT_NO_UNALIGNED_ACCESS
-	I x = (I)0;
-	for(unsigned int k=0;k<sizeof(I);++k) reinterpret_cast<uint8_t *>(&x)[k] = reinterpret_cast<const uint8_t *>(p)[k];
-	return x;
-#else
-	return *reinterpret_cast<const I *>(p);
-#endif
-}
-
 /**
  * Save an integer in big-endian format
  *
@@ -478,6 +458,44 @@ static ZT_ALWAYS_INLINE void storeBigEndian(void *const p,const I i) noexcept
 #endif
 }
 
+/**
+ * Copy bits from memory into an integer type without modifying their order
+ *
+ * @tparam I Type to load
+ * @param p Byte stream, must be at least sizeof(I) in size
+ * @return Loaded raw integer
+ */
+template<typename I>
+static ZT_ALWAYS_INLINE I loadAsIsEndian(const void *const p) noexcept
+{
+#ifdef ZT_NO_UNALIGNED_ACCESS
+	I x = (I)0;
+	for(unsigned int k=0;k<sizeof(I);++k) reinterpret_cast<uint8_t *>(&x)[k] = reinterpret_cast<const uint8_t *>(p)[k];
+	return x;
+#else
+	return *reinterpret_cast<const I *>(p);
+#endif
+}
+
+/**
+ * Copy bits from an integer type into memory without modifying their order
+ *
+ * @tparam I Type to store
+ * @param p Byte array (must be at least sizeof(I))
+ * @param i Integer to store
+ */
+template<typename I>
+static ZT_ALWAYS_INLINE void storeAsIsEndian(void *const p,const I i) noexcept
+{
+#ifdef ZT_NO_UNALIGNED_ACCESS
+	for(unsigned int k=0;k<sizeof(I);++k) reinterpret_cast<uint8_t *>(p)[k] = reinterpret_cast<const uint8_t *>(&i)[k];
+#else
+	*reinterpret_cast<I *>(p) = i;
+#endif
+}
+
 } // namespace Utils
 
 } // namespace ZeroTier
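
Unlike loadBigEndian()/storeBigEndian(), the loadAsIsEndian()/storeAsIsEndian() templates added above copy bytes in memory order and leave byte-order conversion to explicit Utils::hton()/Utils::ntoh() calls at the call site, which is how the software GMAC path in AES.cpp uses them. As an illustrative aside (not part of the patch), a small sketch of the same semantics using memcpy, which also sidesteps alignment concerns; the names here are hypothetical stand-ins, not the real Utils API.

#include <cstdint>
#include <cstring>

template<typename I>
static inline I loadRaw(const void *p) noexcept
{
	I x;
	std::memcpy(&x,p,sizeof(I));   // raw copy, no byte swapping
	return x;
}

template<typename I>
static inline void storeRaw(void *p,const I i) noexcept
{
	std::memcpy(p,&i,sizeof(I));   // raw copy, no byte swapping
}

int main()
{
	const uint8_t bytes[8] = { 1,2,3,4,5,6,7,8 };
	uint8_t roundTrip[8];
	storeRaw<uint64_t>(roundTrip,loadRaw<uint64_t>(bytes));
	return std::memcmp(bytes,roundTrip,sizeof(bytes)) == 0 ? 0 : 1;   // always 0: byte order is preserved
}
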