Use 6x blocks for ARMv8 AES rather than 4x

We gain 0.1 to 0.3 cpb, depending on the mode
pull/507/head
Jeffrey Walton 2017-09-14 20:32:06 -04:00
parent 51752cb91a
commit 25efb7a140
GPG Key ID: B36AB348921B1838
1 changed file with 111 additions and 87 deletions
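Why six independent blocks help: the output of an AESE/AESMC pair feeds that same block's next round, so a single stream is bound by instruction latency, while separate blocks carry no such dependency and their rounds can overlap in the pipeline. Below is a minimal sketch of the pattern, not the library code, assuming <arm_neon.h> and a compiler targeting the ARMv8 Crypto extensions (e.g. -march=armv8-a+crypto); round6 is a hypothetical helper name:

#include <arm_neon.h>
#include <stddef.h>

// One AES round over six independent blocks. Each block depends only on
// its own previous value, so the six AESE/AESMC pairs can issue back to
// back instead of stalling on a single dependency chain.
static inline void round6(uint8x16_t b[6], uint8x16_t rk)
{
    for (size_t i = 0; i < 6; ++i)
        b[i] = vaesmcq_u8(vaeseq_u8(b[i], rk));
}

The width is a latency/register-pressure trade-off; on the hardware measured here, six streams beat four by the 0.1 to 0.3 cpb quoted above.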


@@ -168,167 +168,173 @@ const word32 s_one[] = {0, 0, 0, 1}; // uint32x4_t
 inline void ARMV8_Enc_Block(uint8x16_t &block, const word32 *subkeys, unsigned int rounds)
 {
     CRYPTOPP_ASSERT(subkeys);
-    CRYPTOPP_ASSERT(rounds >= 9);
     const byte *keys = reinterpret_cast<const byte*>(subkeys);
 
-    // Unroll the loop, profit 0.3 to 0.5 cpb.
-    block = vaeseq_u8(block, vld1q_u8(keys+0));
-    block = vaesmcq_u8(block);
-    block = vaeseq_u8(block, vld1q_u8(keys+16));
-    block = vaesmcq_u8(block);
-    block = vaeseq_u8(block, vld1q_u8(keys+32));
-    block = vaesmcq_u8(block);
-    block = vaeseq_u8(block, vld1q_u8(keys+48));
-    block = vaesmcq_u8(block);
-    block = vaeseq_u8(block, vld1q_u8(keys+64));
-    block = vaesmcq_u8(block);
-    block = vaeseq_u8(block, vld1q_u8(keys+80));
-    block = vaesmcq_u8(block);
-    block = vaeseq_u8(block, vld1q_u8(keys+96));
-    block = vaesmcq_u8(block);
-    block = vaeseq_u8(block, vld1q_u8(keys+112));
-    block = vaesmcq_u8(block);
-    block = vaeseq_u8(block, vld1q_u8(keys+128));
+    // AES single round encryption
+    block = vaeseq_u8(block, vld1q_u8(keys+0*16));
+    // AES mix columns
     block = vaesmcq_u8(block);
 
-    unsigned int i=9;
-    for ( ; i<rounds-1; ++i)
+    for (unsigned int i=1; i<rounds-1; i+=2)
     {
         // AES single round encryption
         block = vaeseq_u8(block, vld1q_u8(keys+i*16));
         // AES mix columns
         block = vaesmcq_u8(block);
+        // AES single round encryption
+        block = vaeseq_u8(block, vld1q_u8(keys+(i+1)*16));
+        // AES mix columns
+        block = vaesmcq_u8(block);
     }
 
     // AES single round encryption
-    block = vaeseq_u8(block, vld1q_u8(keys+i*16));
+    block = vaeseq_u8(block, vld1q_u8(keys+(rounds-1)*16));
     // Final Add (bitwise Xor)
-    block = veorq_u8(block, vld1q_u8(keys+(i+1)*16));
+    block = veorq_u8(block, vld1q_u8(keys+rounds*16));
 }
 
-inline void ARMV8_Enc_4_Blocks(uint8x16_t &block0, uint8x16_t &block1, uint8x16_t &block2,
-            uint8x16_t &block3, const word32 *subkeys, unsigned int rounds)
+inline void ARMV8_Enc_6_Blocks(uint8x16_t &block0, uint8x16_t &block1, uint8x16_t &block2,
+            uint8x16_t &block3, uint8x16_t &block4, uint8x16_t &block5,
+            const word32 *subkeys, unsigned int rounds)
 {
     CRYPTOPP_ASSERT(subkeys);
     const byte *keys = reinterpret_cast<const byte*>(subkeys);
 
-    unsigned int i=0;
-    for ( ; i<rounds-1; ++i)
+    uint8x16_t key;
+    for (unsigned int i=0; i<rounds-1; ++i)
     {
+        key = vld1q_u8(keys+i*16);
         // AES single round encryption
-        block0 = vaeseq_u8(block0, vld1q_u8(keys+i*16));
+        block0 = vaeseq_u8(block0, key);
         // AES mix columns
         block0 = vaesmcq_u8(block0);
         // AES single round encryption
-        block1 = vaeseq_u8(block1, vld1q_u8(keys+i*16));
+        block1 = vaeseq_u8(block1, key);
         // AES mix columns
         block1 = vaesmcq_u8(block1);
         // AES single round encryption
-        block2 = vaeseq_u8(block2, vld1q_u8(keys+i*16));
+        block2 = vaeseq_u8(block2, key);
         // AES mix columns
         block2 = vaesmcq_u8(block2);
         // AES single round encryption
-        block3 = vaeseq_u8(block3, vld1q_u8(keys+i*16));
+        block3 = vaeseq_u8(block3, key);
         // AES mix columns
         block3 = vaesmcq_u8(block3);
+        // AES single round encryption
+        block4 = vaeseq_u8(block4, key);
+        // AES mix columns
+        block4 = vaesmcq_u8(block4);
+        // AES single round encryption
+        block5 = vaeseq_u8(block5, key);
+        // AES mix columns
+        block5 = vaesmcq_u8(block5);
     }
 
     // AES single round encryption
-    block0 = vaeseq_u8(block0, vld1q_u8(keys+i*16));
-    block1 = vaeseq_u8(block1, vld1q_u8(keys+i*16));
-    block2 = vaeseq_u8(block2, vld1q_u8(keys+i*16));
-    block3 = vaeseq_u8(block3, vld1q_u8(keys+i*16));
+    key = vld1q_u8(keys+(rounds-1)*16);
+    block0 = vaeseq_u8(block0, key);
+    block1 = vaeseq_u8(block1, key);
+    block2 = vaeseq_u8(block2, key);
+    block3 = vaeseq_u8(block3, key);
+    block4 = vaeseq_u8(block4, key);
+    block5 = vaeseq_u8(block5, key);
 
     // Final Add (bitwise Xor)
-    block0 = veorq_u8(block0, vld1q_u8(keys+(i+1)*16));
-    block1 = veorq_u8(block1, vld1q_u8(keys+(i+1)*16));
-    block2 = veorq_u8(block2, vld1q_u8(keys+(i+1)*16));
-    block3 = veorq_u8(block3, vld1q_u8(keys+(i+1)*16));
+    key = vld1q_u8(keys+rounds*16);
+    block0 = veorq_u8(block0, key);
+    block1 = veorq_u8(block1, key);
+    block2 = veorq_u8(block2, key);
+    block3 = veorq_u8(block3, key);
+    block4 = veorq_u8(block4, key);
+    block5 = veorq_u8(block5, key);
 }
 
 inline void ARMV8_Dec_Block(uint8x16_t &block, const word32 *subkeys, unsigned int rounds)
 {
     CRYPTOPP_ASSERT(subkeys);
-    CRYPTOPP_ASSERT(rounds >= 9);
     const byte *keys = reinterpret_cast<const byte*>(subkeys);
 
-    // Unroll the loop, profit 0.3 to 0.5 cpb.
-    block = vaesdq_u8(block, vld1q_u8(keys+0));
-    block = vaesimcq_u8(block);
-    block = vaesdq_u8(block, vld1q_u8(keys+16));
-    block = vaesimcq_u8(block);
-    block = vaesdq_u8(block, vld1q_u8(keys+32));
-    block = vaesimcq_u8(block);
-    block = vaesdq_u8(block, vld1q_u8(keys+48));
-    block = vaesimcq_u8(block);
-    block = vaesdq_u8(block, vld1q_u8(keys+64));
-    block = vaesimcq_u8(block);
-    block = vaesdq_u8(block, vld1q_u8(keys+80));
-    block = vaesimcq_u8(block);
-    block = vaesdq_u8(block, vld1q_u8(keys+96));
-    block = vaesimcq_u8(block);
-    block = vaesdq_u8(block, vld1q_u8(keys+112));
-    block = vaesimcq_u8(block);
-    block = vaesdq_u8(block, vld1q_u8(keys+128));
+    // AES single round decryption
+    block = vaesdq_u8(block, vld1q_u8(keys+0*16));
+    // AES inverse mix columns
     block = vaesimcq_u8(block);
 
-    unsigned int i=9;
-    for ( ; i<rounds-1; ++i)
+    for (unsigned int i=1; i<rounds-1; i+=2)
     {
         // AES single round decryption
         block = vaesdq_u8(block, vld1q_u8(keys+i*16));
         // AES inverse mix columns
         block = vaesimcq_u8(block);
+        // AES single round decryption
+        block = vaesdq_u8(block, vld1q_u8(keys+(i+1)*16));
+        // AES inverse mix columns
+        block = vaesimcq_u8(block);
     }
 
     // AES single round decryption
-    block = vaesdq_u8(block, vld1q_u8(keys+i*16));
+    block = vaesdq_u8(block, vld1q_u8(keys+(rounds-1)*16));
     // Final Add (bitwise Xor)
-    block = veorq_u8(block, vld1q_u8(keys+(i+1)*16));
+    block = veorq_u8(block, vld1q_u8(keys+rounds*16));
 }
 
-inline void ARMV8_Dec_4_Blocks(uint8x16_t &block0, uint8x16_t &block1, uint8x16_t &block2,
-            uint8x16_t &block3, const word32 *subkeys, unsigned int rounds)
+inline void ARMV8_Dec_6_Blocks(uint8x16_t &block0, uint8x16_t &block1, uint8x16_t &block2,
+            uint8x16_t &block3, uint8x16_t &block4, uint8x16_t &block5,
+            const word32 *subkeys, unsigned int rounds)
 {
     CRYPTOPP_ASSERT(subkeys);
     const byte *keys = reinterpret_cast<const byte*>(subkeys);
 
-    unsigned int i=0;
-    for ( ; i<rounds-1; ++i)
+    uint8x16_t key;
+    for (unsigned int i=0; i<rounds-1; ++i)
     {
+        key = vld1q_u8(keys+i*16);
         // AES single round decryption
-        block0 = vaesdq_u8(block0, vld1q_u8(keys+i*16));
+        block0 = vaesdq_u8(block0, key);
         // AES inverse mix columns
         block0 = vaesimcq_u8(block0);
         // AES single round decryption
-        block1 = vaesdq_u8(block1, vld1q_u8(keys+i*16));
+        block1 = vaesdq_u8(block1, key);
         // AES inverse mix columns
         block1 = vaesimcq_u8(block1);
         // AES single round decryption
-        block2 = vaesdq_u8(block2, vld1q_u8(keys+i*16));
+        block2 = vaesdq_u8(block2, key);
         // AES inverse mix columns
         block2 = vaesimcq_u8(block2);
         // AES single round decryption
-        block3 = vaesdq_u8(block3, vld1q_u8(keys+i*16));
+        block3 = vaesdq_u8(block3, key);
         // AES inverse mix columns
         block3 = vaesimcq_u8(block3);
+        // AES single round decryption
+        block4 = vaesdq_u8(block4, key);
+        // AES inverse mix columns
+        block4 = vaesimcq_u8(block4);
+        // AES single round decryption
+        block5 = vaesdq_u8(block5, key);
+        // AES inverse mix columns
+        block5 = vaesimcq_u8(block5);
     }
 
     // AES single round decryption
-    block0 = vaesdq_u8(block0, vld1q_u8(keys+i*16));
-    block1 = vaesdq_u8(block1, vld1q_u8(keys+i*16));
-    block2 = vaesdq_u8(block2, vld1q_u8(keys+i*16));
-    block3 = vaesdq_u8(block3, vld1q_u8(keys+i*16));
+    key = vld1q_u8(keys+(rounds-1)*16);
+    block0 = vaesdq_u8(block0, key);
+    block1 = vaesdq_u8(block1, key);
+    block2 = vaesdq_u8(block2, key);
+    block3 = vaesdq_u8(block3, key);
+    block4 = vaesdq_u8(block4, key);
+    block5 = vaesdq_u8(block5, key);
 
     // Final Add (bitwise Xor)
-    block0 = veorq_u8(block0, vld1q_u8(keys+(i+1)*16));
-    block1 = veorq_u8(block1, vld1q_u8(keys+(i+1)*16));
-    block2 = veorq_u8(block2, vld1q_u8(keys+(i+1)*16));
-    block3 = veorq_u8(block3, vld1q_u8(keys+(i+1)*16));
+    key = vld1q_u8(keys+rounds*16);
+    block0 = veorq_u8(block0, key);
+    block1 = veorq_u8(block1, key);
+    block2 = veorq_u8(block2, key);
+    block3 = veorq_u8(block3, key);
+    block4 = veorq_u8(block4, key);
+    block5 = veorq_u8(block5, key);
 }
 
-template <typename F1, typename F4>
-size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *subKeys, size_t rounds,
+template <typename F1, typename F6>
+size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F6 func6, const word32 *subKeys, size_t rounds,
     const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags)
 {
     CRYPTOPP_ASSERT(subKeys);
@@ -353,9 +359,9 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
     if (flags & BlockTransformation::BT_AllowParallel)
     {
-        while (length >= 4*blockSize)
+        while (length >= 6*blockSize)
         {
-            uint8x16_t block0, block1, block2, block3, temp;
+            uint8x16_t block0, block1, block2, block3, block4, block5, temp;
 
             block0 = vld1q_u8(inBlocks);
             if (flags & BlockTransformation::BT_InBlockIsCounter)
@@ -364,7 +370,9 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
                 block1 = vaddq_u8(block0, vreinterpretq_u8_u32(be));
                 block2 = vaddq_u8(block1, vreinterpretq_u8_u32(be));
                 block3 = vaddq_u8(block2, vreinterpretq_u8_u32(be));
-                temp = vaddq_u8(block3, vreinterpretq_u8_u32(be));
+                block4 = vaddq_u8(block3, vreinterpretq_u8_u32(be));
+                block5 = vaddq_u8(block4, vreinterpretq_u8_u32(be));
+                temp = vaddq_u8(block5, vreinterpretq_u8_u32(be));
                 vst1q_u8(const_cast<byte*>(inBlocks), temp);
             }
             else
@@ -376,6 +384,10 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
                 inBlocks += inIncrement;
                 block3 = vld1q_u8(inBlocks);
                 inBlocks += inIncrement;
+                block4 = vld1q_u8(inBlocks);
+                inBlocks += inIncrement;
+                block5 = vld1q_u8(inBlocks);
+                inBlocks += inIncrement;
             }
 
             if (flags & BlockTransformation::BT_XorInput)
@@ -388,9 +400,13 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
                 xorBlocks += xorIncrement;
                 block3 = veorq_u8(block3, vld1q_u8(xorBlocks));
                 xorBlocks += xorIncrement;
+                block4 = veorq_u8(block4, vld1q_u8(xorBlocks));
+                xorBlocks += xorIncrement;
+                block5 = veorq_u8(block5, vld1q_u8(xorBlocks));
+                xorBlocks += xorIncrement;
             }
 
-            func4(block0, block1, block2, block3, subKeys, rounds);
+            func6(block0, block1, block2, block3, block4, block5, subKeys, rounds);
 
             if (xorBlocks && !(flags & BlockTransformation::BT_XorInput))
             {
@@ -402,6 +418,10 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
                 xorBlocks += xorIncrement;
                 block3 = veorq_u8(block3, vld1q_u8(xorBlocks));
                 xorBlocks += xorIncrement;
+                block4 = veorq_u8(block4, vld1q_u8(xorBlocks));
+                xorBlocks += xorIncrement;
+                block5 = veorq_u8(block5, vld1q_u8(xorBlocks));
+                xorBlocks += xorIncrement;
             }
 
             vst1q_u8(outBlocks, block0);
@@ -412,8 +432,12 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
             outBlocks += outIncrement;
             vst1q_u8(outBlocks, block3);
             outBlocks += outIncrement;
+            vst1q_u8(outBlocks, block4);
+            outBlocks += outIncrement;
+            vst1q_u8(outBlocks, block5);
+            outBlocks += outIncrement;
 
-            length -= 4*blockSize;
+            length -= 6*blockSize;
         }
     }
@@ -446,14 +470,14 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
 size_t Rijndael_Enc_AdvancedProcessBlocks_ARMV8(const word32 *subKeys, size_t rounds,
     const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags)
 {
-    return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Enc_Block, ARMV8_Enc_4_Blocks,
+    return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Enc_Block, ARMV8_Enc_6_Blocks,
         subKeys, rounds, inBlocks, xorBlocks, outBlocks, length, flags);
 }
 
 size_t Rijndael_Dec_AdvancedProcessBlocks_ARMV8(const word32 *subKeys, size_t rounds,
     const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags)
 {
-    return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Dec_Block, ARMV8_Dec_4_Blocks,
+    return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Dec_Block, ARMV8_Dec_6_Blocks,
         subKeys, rounds, inBlocks, xorBlocks, outBlocks, length, flags);
 }
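For reference, the restructured single-block round loop covers every subkey exactly once for the even round counts AES uses (10, 12, or 14). The following sanity-check harness is hypothetical and not part of the change; it only traces the subkey indexing implied by the new ARMV8_Enc_Block:

#include <cstdio>

int main()
{
    const unsigned int rounds = 10;  // AES-128; 12 and 14 work the same way

    // Round 0 is peeled off before the loop (AESE+AESMC with subkey 0).
    std::printf("subkey 0: AESE+AESMC\n");

    // The 2x-unrolled loop consumes subkeys 1 .. rounds-2 in pairs.
    for (unsigned int i = 1; i < rounds - 1; i += 2)
        std::printf("subkeys %u, %u: AESE+AESMC each\n", i, i + 1);

    // The last round omits MixColumns; the final subkey is a plain XOR.
    std::printf("subkey %u: AESE only\nsubkey %u: XOR\n", rounds - 1, rounds);
    return 0;
}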