diff --git a/chacha-simd.cpp b/chacha-simd.cpp index e4d24178..6ec01834 100644 --- a/chacha-simd.cpp +++ b/chacha-simd.cpp @@ -193,7 +193,7 @@ NAMESPACE_BEGIN(CryptoPP) #if (CRYPTOPP_ARM_NEON_AVAILABLE) -void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput) +void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds) { const uint32x4_t state0 = vld1q_u32(state + 0*4); const uint32x4_t state1 = vld1q_u32(state + 1*4); @@ -408,7 +408,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte * r3_3 = vreinterpretq_u32_u64(vaddq_u64( vreinterpretq_u64_u32(r3_3), CTRS[2])); - if (xorInput) + if (input) { r0_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 0*16)), r0_0); r0_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 1*16)), r0_1); @@ -421,7 +421,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte * vst1q_u8(output + 2*16, vreinterpretq_u8_u32(r0_2)); vst1q_u8(output + 3*16, vreinterpretq_u8_u32(r0_3)); - if (xorInput) + if (input) { r1_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 4*16)), r1_0); r1_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 5*16)), r1_1); @@ -434,7 +434,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte * vst1q_u8(output + 6*16, vreinterpretq_u8_u32(r1_2)); vst1q_u8(output + 7*16, vreinterpretq_u8_u32(r1_3)); - if (xorInput) + if (input) { r2_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 8*16)), r2_0); r2_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 9*16)), r2_1); @@ -447,7 +447,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte * vst1q_u8(output + 10*16, vreinterpretq_u8_u32(r2_2)); vst1q_u8(output + 11*16, vreinterpretq_u8_u32(r2_3)); - if (xorInput) + if (input) { r3_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 12*16)), r3_0); r3_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 13*16)), r3_1); @@ -467,7 +467,7 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte * #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE) -void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput) +void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds) { const __m128i* state_mm = reinterpret_cast(state); const __m128i* input_mm = reinterpret_cast(input); @@ -676,7 +676,7 @@ void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte * r3_3 = _mm_add_epi32(r3_3, state3); r3_3 = _mm_add_epi64(r3_3, _mm_set_epi32(0, 0, 0, 3)); - if (xorInput) + if (input_mm) { r0_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 0), r0_0); r0_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 1), r0_1); @@ -689,7 +689,7 @@ void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte * _mm_storeu_si128(output_mm + 2, r0_2); _mm_storeu_si128(output_mm + 3, r0_3); - if (xorInput) + if (input_mm) { r1_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 4), r1_0); r1_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 5), r1_1); @@ -702,7 +702,7 @@ void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte * _mm_storeu_si128(output_mm + 6, r1_2); _mm_storeu_si128(output_mm + 7, r1_3); - if (xorInput) + if (input_mm) { r2_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 8), r2_0); r2_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 9), r2_1); @@ -715,7 +715,7 @@ void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte * _mm_storeu_si128(output_mm + 10, r2_2); _mm_storeu_si128(output_mm + 11, r2_3); - if (xorInput) + if (input_mm) { r3_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 12), r3_0); r3_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 13), r3_1); diff --git a/chacha.cpp b/chacha.cpp index 723f37f2..63b24845 100644 --- a/chacha.cpp +++ b/chacha.cpp @@ -12,11 +12,11 @@ NAMESPACE_BEGIN(CryptoPP) #if (CRYPTOPP_ARM_NEON_AVAILABLE) -extern void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput); +extern void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds); #endif #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE) -extern void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput); +extern void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds); #endif #define CHACHA_QUARTER_ROUND(a,b,c,d) \ @@ -124,8 +124,8 @@ void ChaCha_Policy::OperateKeystream(KeystreamOperation operation, { while (iterationCount >= 4) { - bool xorInput = (operation & INPUT_NULL) != INPUT_NULL; - ChaCha_OperateKeystream_SSE2(m_state, input, output, m_rounds, xorInput); + const bool xorInput = (operation & INPUT_NULL) != INPUT_NULL; + ChaCha_OperateKeystream_SSE2(m_state, xorInput ? input : NULLPTR, output, m_rounds); m_state[12] += 4; if (m_state[12] < 4) @@ -143,8 +143,8 @@ void ChaCha_Policy::OperateKeystream(KeystreamOperation operation, { while (iterationCount >= 4) { - bool xorInput = (operation & INPUT_NULL) != INPUT_NULL; - ChaCha_OperateKeystream_NEON(m_state, input, output, m_rounds, xorInput); + const bool xorInput = (operation & INPUT_NULL) != INPUT_NULL; + ChaCha_OperateKeystream_NEON(m_state, xorInput ? input : NULLPTR, output, m_rounds); m_state[12] += 4; if (m_state[12] < 4) @@ -163,8 +163,8 @@ void ChaCha_Policy::OperateKeystream(KeystreamOperation operation, x0 = m_state[0]; x1 = m_state[1]; x2 = m_state[2]; x3 = m_state[3]; x4 = m_state[4]; x5 = m_state[5]; x6 = m_state[6]; x7 = m_state[7]; - x8 = m_state[8]; x9 = m_state[9]; x10 = m_state[10]; x11 = m_state[11]; - x12 = m_state[12]; x13 = m_state[13]; x14 = m_state[14]; x15 = m_state[15]; + x8 = m_state[8]; x9 = m_state[9]; x10 = m_state[10]; x11 = m_state[11]; + x12 = m_state[12]; x13 = m_state[13]; x14 = m_state[14]; x15 = m_state[15]; for (int i = static_cast(m_rounds); i > 0; i -= 2) {