Remove xorInput parameter from ChaCha SIMD functions

We can use the input pointer directly after checking KeystreamOperation
pull/730/head
Jeffrey Walton 2018-10-26 10:10:52 -04:00
parent 61a696f710
commit c0b273dac8
No known key found for this signature in database
GPG Key ID: B36AB348921B1838
2 changed files with 18 additions and 18 deletions

View File

@ -193,7 +193,7 @@ NAMESPACE_BEGIN(CryptoPP)
#if (CRYPTOPP_ARM_NEON_AVAILABLE) #if (CRYPTOPP_ARM_NEON_AVAILABLE)
void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput) void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds)
{ {
const uint32x4_t state0 = vld1q_u32(state + 0*4); const uint32x4_t state0 = vld1q_u32(state + 0*4);
const uint32x4_t state1 = vld1q_u32(state + 1*4); const uint32x4_t state1 = vld1q_u32(state + 1*4);
@ -408,7 +408,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *
r3_3 = vreinterpretq_u32_u64(vaddq_u64( r3_3 = vreinterpretq_u32_u64(vaddq_u64(
vreinterpretq_u64_u32(r3_3), CTRS[2])); vreinterpretq_u64_u32(r3_3), CTRS[2]));
if (xorInput) if (input)
{ {
r0_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 0*16)), r0_0); r0_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 0*16)), r0_0);
r0_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 1*16)), r0_1); r0_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 1*16)), r0_1);
@ -421,7 +421,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *
vst1q_u8(output + 2*16, vreinterpretq_u8_u32(r0_2)); vst1q_u8(output + 2*16, vreinterpretq_u8_u32(r0_2));
vst1q_u8(output + 3*16, vreinterpretq_u8_u32(r0_3)); vst1q_u8(output + 3*16, vreinterpretq_u8_u32(r0_3));
if (xorInput) if (input)
{ {
r1_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 4*16)), r1_0); r1_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 4*16)), r1_0);
r1_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 5*16)), r1_1); r1_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 5*16)), r1_1);
@ -434,7 +434,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *
vst1q_u8(output + 6*16, vreinterpretq_u8_u32(r1_2)); vst1q_u8(output + 6*16, vreinterpretq_u8_u32(r1_2));
vst1q_u8(output + 7*16, vreinterpretq_u8_u32(r1_3)); vst1q_u8(output + 7*16, vreinterpretq_u8_u32(r1_3));
if (xorInput) if (input)
{ {
r2_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 8*16)), r2_0); r2_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 8*16)), r2_0);
r2_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 9*16)), r2_1); r2_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 9*16)), r2_1);
@ -447,7 +447,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *
vst1q_u8(output + 10*16, vreinterpretq_u8_u32(r2_2)); vst1q_u8(output + 10*16, vreinterpretq_u8_u32(r2_2));
vst1q_u8(output + 11*16, vreinterpretq_u8_u32(r2_3)); vst1q_u8(output + 11*16, vreinterpretq_u8_u32(r2_3));
if (xorInput) if (input)
{ {
r3_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 12*16)), r3_0); r3_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 12*16)), r3_0);
r3_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 13*16)), r3_1); r3_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 13*16)), r3_1);
@ -467,7 +467,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *
#if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE) #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE)
void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput) void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds)
{ {
const __m128i* state_mm = reinterpret_cast<const __m128i*>(state); const __m128i* state_mm = reinterpret_cast<const __m128i*>(state);
const __m128i* input_mm = reinterpret_cast<const __m128i*>(input); const __m128i* input_mm = reinterpret_cast<const __m128i*>(input);
@ -676,7 +676,7 @@ void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *
r3_3 = _mm_add_epi32(r3_3, state3); r3_3 = _mm_add_epi32(r3_3, state3);
r3_3 = _mm_add_epi64(r3_3, _mm_set_epi32(0, 0, 0, 3)); r3_3 = _mm_add_epi64(r3_3, _mm_set_epi32(0, 0, 0, 3));
if (xorInput) if (input_mm)
{ {
r0_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 0), r0_0); r0_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 0), r0_0);
r0_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 1), r0_1); r0_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 1), r0_1);
@ -689,7 +689,7 @@ void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *
_mm_storeu_si128(output_mm + 2, r0_2); _mm_storeu_si128(output_mm + 2, r0_2);
_mm_storeu_si128(output_mm + 3, r0_3); _mm_storeu_si128(output_mm + 3, r0_3);
if (xorInput) if (input_mm)
{ {
r1_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 4), r1_0); r1_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 4), r1_0);
r1_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 5), r1_1); r1_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 5), r1_1);
@ -702,7 +702,7 @@ void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *
_mm_storeu_si128(output_mm + 6, r1_2); _mm_storeu_si128(output_mm + 6, r1_2);
_mm_storeu_si128(output_mm + 7, r1_3); _mm_storeu_si128(output_mm + 7, r1_3);
if (xorInput) if (input_mm)
{ {
r2_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 8), r2_0); r2_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 8), r2_0);
r2_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 9), r2_1); r2_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 9), r2_1);
@ -715,7 +715,7 @@ void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *
_mm_storeu_si128(output_mm + 10, r2_2); _mm_storeu_si128(output_mm + 10, r2_2);
_mm_storeu_si128(output_mm + 11, r2_3); _mm_storeu_si128(output_mm + 11, r2_3);
if (xorInput) if (input_mm)
{ {
r3_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 12), r3_0); r3_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 12), r3_0);
r3_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 13), r3_1); r3_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 13), r3_1);

View File

@ -12,11 +12,11 @@
NAMESPACE_BEGIN(CryptoPP) NAMESPACE_BEGIN(CryptoPP)
#if (CRYPTOPP_ARM_NEON_AVAILABLE) #if (CRYPTOPP_ARM_NEON_AVAILABLE)
extern void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput); extern void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds);
#endif #endif
#if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE) #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE)
extern void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput); extern void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds);
#endif #endif
#define CHACHA_QUARTER_ROUND(a,b,c,d) \ #define CHACHA_QUARTER_ROUND(a,b,c,d) \
@ -124,8 +124,8 @@ void ChaCha_Policy::OperateKeystream(KeystreamOperation operation,
{ {
while (iterationCount >= 4) while (iterationCount >= 4)
{ {
bool xorInput = (operation & INPUT_NULL) != INPUT_NULL; const bool xorInput = (operation & INPUT_NULL) != INPUT_NULL;
ChaCha_OperateKeystream_SSE2(m_state, input, output, m_rounds, xorInput); ChaCha_OperateKeystream_SSE2(m_state, xorInput ? input : NULLPTR, output, m_rounds);
m_state[12] += 4; m_state[12] += 4;
if (m_state[12] < 4) if (m_state[12] < 4)
@ -143,8 +143,8 @@ void ChaCha_Policy::OperateKeystream(KeystreamOperation operation,
{ {
while (iterationCount >= 4) while (iterationCount >= 4)
{ {
bool xorInput = (operation & INPUT_NULL) != INPUT_NULL; const bool xorInput = (operation & INPUT_NULL) != INPUT_NULL;
ChaCha_OperateKeystream_NEON(m_state, input, output, m_rounds, xorInput); ChaCha_OperateKeystream_NEON(m_state, xorInput ? input : NULLPTR, output, m_rounds);
m_state[12] += 4; m_state[12] += 4;
if (m_state[12] < 4) if (m_state[12] < 4)
@ -163,8 +163,8 @@ void ChaCha_Policy::OperateKeystream(KeystreamOperation operation,
x0 = m_state[0]; x1 = m_state[1]; x2 = m_state[2]; x3 = m_state[3]; x0 = m_state[0]; x1 = m_state[1]; x2 = m_state[2]; x3 = m_state[3];
x4 = m_state[4]; x5 = m_state[5]; x6 = m_state[6]; x7 = m_state[7]; x4 = m_state[4]; x5 = m_state[5]; x6 = m_state[6]; x7 = m_state[7];
x8 = m_state[8]; x9 = m_state[9]; x10 = m_state[10]; x11 = m_state[11]; x8 = m_state[8]; x9 = m_state[9]; x10 = m_state[10]; x11 = m_state[11];
x12 = m_state[12]; x13 = m_state[13]; x14 = m_state[14]; x15 = m_state[15]; x12 = m_state[12]; x13 = m_state[13]; x14 = m_state[14]; x15 = m_state[15];
for (int i = static_cast<int>(m_rounds); i > 0; i -= 2) for (int i = static_cast<int>(m_rounds); i > 0; i -= 2)
{ {