From 18dcbdf514298d7097934d9a1d2e9032f14b54b7 Mon Sep 17 00:00:00 2001 From: Jeffrey Walton Date: Wed, 24 Oct 2018 11:00:35 -0400 Subject: [PATCH] Move input xor to ChaCha_OperateKeystream_SSE2 This picks up about 0.2 cpb in ChaCha::OperateKeystream. It may not sound like much but it puts SSE2 intrinsics version on par with the ASM version of Salsa20. Salsa20 leads ChaCha by 0.1 to 0.15 cpb, which equates to about 50 MB/s. --- GNUmakefile | 11 +++++-- chacha-simd.cpp | 77 +++++++++++++++++++++++++++++++++++-------------- chacha.cpp | 24 +++++++++------ 3 files changed, 78 insertions(+), 34 deletions(-) diff --git a/GNUmakefile b/GNUmakefile index 592b59dc..60cd8a23 100755 --- a/GNUmakefile +++ b/GNUmakefile @@ -238,9 +238,8 @@ endif # CXXFLAGS # SSE2 is a core feature of x86_64 ifeq ($(findstring -DCRYPTOPP_DISABLE_ASM,$(CXXFLAGS)),) - ifeq ($(IS_X86),1) - SSE_FLAG = -msse2 - endif + SSE_FLAG = -msse2 + CHACHA_FLAG = -msse2 endif ifeq ($(findstring -DCRYPTOPP_DISABLE_SSSE3,$(CXXFLAGS)),) HAVE_SSSE3 = $(shell $(CXX) $(CXXFLAGS) -DADHOC_MAIN -mssse3 -dM -E adhoc.cpp 2>&1 | $(GREP) -i -c __SSSE3__) @@ -379,6 +378,7 @@ ifeq ($(IS_NEON),1) CRC_FLAG = -march=armv7-a -mfloat-abi=$(FP_ABI) -mfpu=neon GCM_FLAG = -march=armv7-a -mfloat-abi=$(FP_ABI) -mfpu=neon BLAKE2_FLAG = -march=armv7-a -mfloat-abi=$(FP_ABI) -mfpu=neon + CHACHA_FLAG = -march=armv7-a -mfloat-abi=$(FP_ABI) -mfpu=neon CHAM_FLAG = -march=armv7-a -mfloat-abi=$(FP_ABI) -mfpu=neon LEA_FLAG = -march=armv7-a -mfloat-abi=$(FP_ABI) -mfpu=neon SHA_FLAG = -march=armv7-a -mfloat-abi=$(FP_ABI) -mfpu=neon @@ -396,6 +396,7 @@ ifeq ($(IS_ARMV8),1) ifeq ($(HAVE_NEON),1) ARIA_FLAG = -march=armv8-a BLAKE2_FLAG = -march=armv8-a + CHACHA_FLAG = -march=armv8-a CHAM_FLAG = -march=armv8-a LEA_FLAG = -march=armv8-a NEON_FLAG = -march=armv8-a @@ -1176,6 +1177,10 @@ aria-simd.o : aria-simd.cpp blake2-simd.o : blake2-simd.cpp $(CXX) $(strip $(CXXFLAGS) $(BLAKE2_FLAG) -c) $< +# SSE2 or NEON available +chacha-simd.o : chacha-simd.cpp + $(CXX) $(strip $(CXXFLAGS) $(CHACHA_FLAG) -c) $< + # SSSE3 available cham-simd.o : cham-simd.cpp $(CXX) $(strip $(CXXFLAGS) $(CHAM_FLAG) -c) $< diff --git a/chacha-simd.cpp b/chacha-simd.cpp index 58d2b786..c0ad8245 100644 --- a/chacha-simd.cpp +++ b/chacha-simd.cpp @@ -22,7 +22,7 @@ # include #endif -#if (CRYPTOPP_ARM_NEON_AVAILABLE) && 0 +#if (CRYPTOPP_ARM_NEON_AVAILABLE) # include #endif @@ -46,7 +46,7 @@ inline __m128i RotateLeft(const __m128i val) return _mm_or_si128(_mm_slli_epi32(val, R), _mm_srli_epi32(val, 32-R)); } -#endif // (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE) +#endif // CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE ANONYMOUS_NAMESPACE_END @@ -54,12 +54,13 @@ NAMESPACE_BEGIN(CryptoPP) #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE) -void ChaCha_OperateKeystream_SSE2(const word32 *state, byte *message, unsigned int rounds) +void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput) { const __m128i* state_mm = reinterpret_cast(state); - __m128i* message_mm = reinterpret_cast<__m128i*>(message); + const __m128i* input_mm = reinterpret_cast(input); + __m128i* output_mm = reinterpret_cast<__m128i*>(output); - const __m128i state0 = _mm_load_si128(state_mm); + const __m128i state0 = _mm_load_si128(state_mm + 0); const __m128i state1 = _mm_load_si128(state_mm + 1); const __m128i state2 = _mm_load_si128(state_mm + 2); const __m128i state3 = _mm_load_si128(state_mm + 3); @@ -262,27 +263,59 @@ void ChaCha_OperateKeystream_SSE2(const word32 *state, byte *message, unsigned i r3_3 = _mm_add_epi32(r3_3, state3); r3_3 = _mm_add_epi64(r3_3, _mm_set_epi32(0, 0, 0, 3)); - _mm_storeu_si128(message_mm + 0, r0_0); - _mm_storeu_si128(message_mm + 1, r0_1); - _mm_storeu_si128(message_mm + 2, r0_2); - _mm_storeu_si128(message_mm + 3, r0_3); + if (xorInput) + { + r0_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 0), r0_0); + r0_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 1), r0_1); + r0_2 = _mm_xor_si128(_mm_loadu_si128(input_mm + 2), r0_2); + r0_3 = _mm_xor_si128(_mm_loadu_si128(input_mm + 3), r0_3); + } - _mm_storeu_si128(message_mm + 4, r1_0); - _mm_storeu_si128(message_mm + 5, r1_1); - _mm_storeu_si128(message_mm + 6, r1_2); - _mm_storeu_si128(message_mm + 7, r1_3); + _mm_storeu_si128(output_mm + 0, r0_0); + _mm_storeu_si128(output_mm + 1, r0_1); + _mm_storeu_si128(output_mm + 2, r0_2); + _mm_storeu_si128(output_mm + 3, r0_3); - _mm_storeu_si128(message_mm + 8, r2_0); - _mm_storeu_si128(message_mm + 9, r2_1); - _mm_storeu_si128(message_mm + 10, r2_2); - _mm_storeu_si128(message_mm + 11, r2_3); + if (xorInput) + { + r1_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 4), r1_0); + r1_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 5), r1_1); + r1_2 = _mm_xor_si128(_mm_loadu_si128(input_mm + 6), r1_2); + r1_3 = _mm_xor_si128(_mm_loadu_si128(input_mm + 7), r1_3); + } - _mm_storeu_si128(message_mm + 12, r3_0); - _mm_storeu_si128(message_mm + 13, r3_1); - _mm_storeu_si128(message_mm + 14, r3_2); - _mm_storeu_si128(message_mm + 15, r3_3); + _mm_storeu_si128(output_mm + 4, r1_0); + _mm_storeu_si128(output_mm + 5, r1_1); + _mm_storeu_si128(output_mm + 6, r1_2); + _mm_storeu_si128(output_mm + 7, r1_3); + + if (xorInput) + { + r2_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 8), r2_0); + r2_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 9), r2_1); + r2_2 = _mm_xor_si128(_mm_loadu_si128(input_mm + 10), r2_2); + r2_3 = _mm_xor_si128(_mm_loadu_si128(input_mm + 11), r2_3); + } + + _mm_storeu_si128(output_mm + 8, r2_0); + _mm_storeu_si128(output_mm + 9, r2_1); + _mm_storeu_si128(output_mm + 10, r2_2); + _mm_storeu_si128(output_mm + 11, r2_3); + + if (xorInput) + { + r3_0 = _mm_xor_si128(_mm_loadu_si128(input_mm + 12), r3_0); + r3_1 = _mm_xor_si128(_mm_loadu_si128(input_mm + 13), r3_1); + r3_2 = _mm_xor_si128(_mm_loadu_si128(input_mm + 14), r3_2); + r3_3 = _mm_xor_si128(_mm_loadu_si128(input_mm + 15), r3_3); + } + + _mm_storeu_si128(output_mm + 12, r3_0); + _mm_storeu_si128(output_mm + 13, r3_1); + _mm_storeu_si128(output_mm + 14, r3_2); + _mm_storeu_si128(output_mm + 15, r3_3); } -#endif // (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE) +#endif // CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE NAMESPACE_END diff --git a/chacha.cpp b/chacha.cpp index 7293c2a9..080ccbfd 100644 --- a/chacha.cpp +++ b/chacha.cpp @@ -12,11 +12,7 @@ NAMESPACE_BEGIN(CryptoPP) #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE) -extern void ChaCha_OperateKeystream_SSE2(const word32 *state, byte *message, unsigned int rounds); -#endif - -#if (CRYPTOPP_ARM_NEON_AVAILABLE) -extern void ChaCha_OperateKeystream_NEON(const word32 *state, byte *message, unsigned int rounds); +extern void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput); #endif #define CHACHA_QUARTER_ROUND(a,b,c,d) \ @@ -37,6 +33,10 @@ std::string ChaCha_Policy::AlgorithmProvider() const #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE) if (HasSSE2()) return "SSE2"; +#endif +#if (CRYPTOPP_ARM_NEON_AVAILABLE) + if (HasNEON()) + return "NEON"; #endif return "C++"; } @@ -95,10 +95,18 @@ unsigned int ChaCha_Policy::GetOptimalBlockSize() const if (HasSSE2()) return 4*BYTES_PER_ITERATION; else +#endif +#if (CRYPTOPP_ARM_NEON_AVAILABLE) + if (HasNEON()) + return 4*BYTES_PER_ITERATION; + else #endif return BYTES_PER_ITERATION; } +// OperateKeystream always produces a key stream. The key stream is written +// to output. Optionally a message may be supplied to xor with the key stream. +// The message is input, and output = output ^ input. void ChaCha_Policy::OperateKeystream(KeystreamOperation operation, byte *output, const byte *input, size_t iterationCount) { @@ -107,10 +115,8 @@ void ChaCha_Policy::OperateKeystream(KeystreamOperation operation, { while (iterationCount >= 4) { - ChaCha_OperateKeystream_SSE2(m_state, output, m_rounds); - - if ((operation & INPUT_NULL) != INPUT_NULL) - xorbuf(output, input, 4*BYTES_PER_ITERATION); + bool xorInput = (operation & INPUT_NULL) != INPUT_NULL; + ChaCha_OperateKeystream_SSE2(m_state, input, output, m_rounds, xorInput); m_state[12] += 4; if (m_state[12] < 4)