diff --git a/chacha-simd.cpp b/chacha-simd.cpp
index 7b58b8a3..ecc1c1d8 100644
--- a/chacha-simd.cpp
+++ b/chacha-simd.cpp
@@ -46,6 +46,72 @@ extern const char CHACHA_SIMD_FNAME[] = __FILE__;
 
 ANONYMOUS_NAMESPACE_BEGIN
 
+#if (CRYPTOPP_ARM_NEON_AVAILABLE)
+
+template <unsigned int R>
+inline uint32x4_t RotateLeft(const uint32x4_t& val)
+{
+    const uint32x4_t a(vshlq_n_u32(val, R));
+    const uint32x4_t b(vshrq_n_u32(val, 32 - R));
+    return vorrq_u32(a, b);
+}
+
+template <unsigned int R>
+inline uint32x4_t RotateRight(const uint32x4_t& val)
+{
+    const uint32x4_t a(vshlq_n_u32(val, 32 - R));
+    const uint32x4_t b(vshrq_n_u32(val, R));
+    return vorrq_u32(a, b);
+}
+
+#if defined(__aarch32__) || defined(__aarch64__)
+template <>
+inline uint32x4_t RotateLeft<8>(const uint32x4_t& val)
+{
+    const uint8_t maskb[16] = { 3,0,1,2, 7,4,5,6, 11,8,9,10, 15,12,13,14 };
+    const uint8x16_t mask = vld1q_u8(maskb);
+
+    return vreinterpretq_u32_u8(
+        vqtbl1q_u8(vreinterpretq_u8_u32(val), mask));
+}
+
+template <>
+inline uint32x4_t RotateLeft<16>(const uint32x4_t& val)
+{
+    return vreinterpretq_u32_u16(
+        vrev32q_u16(vreinterpretq_u16_u32(val)));
+}
+
+template <>
+inline uint32x4_t RotateRight<16>(const uint32x4_t& val)
+{
+    return vreinterpretq_u32_u16(
+        vrev32q_u16(vreinterpretq_u16_u32(val)));
+}
+
+template <>
+inline uint32x4_t RotateRight<8>(const uint32x4_t& val)
+{
+    const uint8_t maskb[16] = { 1,2,3,0, 5,6,7,4, 9,10,11,8, 13,14,15,12 };
+    const uint8x16_t mask = vld1q_u8(maskb);
+
+    return vreinterpretq_u32_u8(
+        vqtbl1q_u8(vreinterpretq_u8_u32(val), mask));
+}
+#endif  // Aarch32 or Aarch64
+
+// ChaCha's use of shuffle is really a 4, 8, or 12 byte rotation:
+//   * [3,2,1,0] => [0,3,2,1] is Shuffle<1>(x)
+//   * [3,2,1,0] => [1,0,3,2] is Shuffle<2>(x)
+//   * [3,2,1,0] => [2,1,0,3] is Shuffle<3>(x)
+template <unsigned int S>
+inline uint32x4_t Shuffle(const uint32x4_t& val)
+{
+    return vextq_u32(val, val, S);
+}
+
+#endif  // CRYPTOPP_ARM_NEON_AVAILABLE
+
 #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE)
 
 template <unsigned int R>
@@ -88,6 +154,278 @@ ANONYMOUS_NAMESPACE_END
 
 NAMESPACE_BEGIN(CryptoPP)
 
+#if (CRYPTOPP_ARM_NEON_AVAILABLE)
+
+void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput)
+{
+    const uint32x4_t state0 = vld1q_u32(state + 0*4);
+    const uint32x4_t state1 = vld1q_u32(state + 1*4);
+    const uint32x4_t state2 = vld1q_u32(state + 2*4);
+    const uint32x4_t state3 = vld1q_u32(state + 3*4);
+
+    const uint64x2_t CTRS[3] = {
+        {1, 0}, {2, 0}, {3, 0}
+    };
+
+    uint32x4_t r0_0 = state0;
+    uint32x4_t r0_1 = state1;
+    uint32x4_t r0_2 = state2;
+    uint32x4_t r0_3 = state3;
+
+    uint32x4_t r1_0 = state0;
+    uint32x4_t r1_1 = state1;
+    uint32x4_t r1_2 = state2;
+    uint32x4_t r1_3 = vreinterpretq_u32_u64(vaddq_u64(
+        vreinterpretq_u64_u32(r0_3), CTRS[0]));
+
+    uint32x4_t r2_0 = state0;
+    uint32x4_t r2_1 = state1;
+    uint32x4_t r2_2 = state2;
+    uint32x4_t r2_3 = vreinterpretq_u32_u64(vaddq_u64(
+        vreinterpretq_u64_u32(r0_3), CTRS[1]));
+
+    uint32x4_t r3_0 = state0;
+    uint32x4_t r3_1 = state1;
+    uint32x4_t r3_2 = state2;
+    uint32x4_t r3_3 = vreinterpretq_u32_u64(vaddq_u64(
+        vreinterpretq_u64_u32(r0_3), CTRS[2]));
+
+    for (int i = static_cast<int>(rounds); i > 0; i -= 2)
+    {
+        r0_0 = vaddq_u32(r0_0, r0_1);
+        r1_0 = vaddq_u32(r1_0, r1_1);
+        r2_0 = vaddq_u32(r2_0, r2_1);
+        r3_0 = vaddq_u32(r3_0, r3_1);
+
+        r0_3 = veorq_u32(r0_3, r0_0);
+        r1_3 = veorq_u32(r1_3, r1_0);
+        r2_3 = veorq_u32(r2_3, r2_0);
+        r3_3 = veorq_u32(r3_3, r3_0);
+
+        r0_3 = RotateLeft<16>(r0_3);
+        r1_3 = RotateLeft<16>(r1_3);
+        r2_3 = RotateLeft<16>(r2_3);
+        r3_3 = RotateLeft<16>(r3_3);
+
+        r0_2 = vaddq_u32(r0_2, r0_3);
+        r1_2 = vaddq_u32(r1_2, r1_3);
+        r2_2 = vaddq_u32(r2_2, r2_3);
+        r3_2 = vaddq_u32(r3_2, r3_3);
+
+        r0_1 = veorq_u32(r0_1, r0_2);
+        r1_1 = veorq_u32(r1_1, r1_2);
+        r2_1 = veorq_u32(r2_1, r2_2);
+        r3_1 = veorq_u32(r3_1, r3_2);
+
+        r0_1 = RotateLeft<12>(r0_1);
+        r1_1 = RotateLeft<12>(r1_1);
+        r2_1 = RotateLeft<12>(r2_1);
+        r3_1 = RotateLeft<12>(r3_1);
+
+        r0_0 = vaddq_u32(r0_0, r0_1);
+        r1_0 = vaddq_u32(r1_0, r1_1);
+        r2_0 = vaddq_u32(r2_0, r2_1);
+        r3_0 = vaddq_u32(r3_0, r3_1);
+
+        r0_3 = veorq_u32(r0_3, r0_0);
+        r1_3 = veorq_u32(r1_3, r1_0);
+        r2_3 = veorq_u32(r2_3, r2_0);
+        r3_3 = veorq_u32(r3_3, r3_0);
+
+        r0_3 = RotateLeft<8>(r0_3);
+        r1_3 = RotateLeft<8>(r1_3);
+        r2_3 = RotateLeft<8>(r2_3);
+        r3_3 = RotateLeft<8>(r3_3);
+
+        r0_2 = vaddq_u32(r0_2, r0_3);
+        r1_2 = vaddq_u32(r1_2, r1_3);
+        r2_2 = vaddq_u32(r2_2, r2_3);
+        r3_2 = vaddq_u32(r3_2, r3_3);
+
+        r0_1 = veorq_u32(r0_1, r0_2);
+        r1_1 = veorq_u32(r1_1, r1_2);
+        r2_1 = veorq_u32(r2_1, r2_2);
+        r3_1 = veorq_u32(r3_1, r3_2);
+
+        r0_1 = RotateLeft<7>(r0_1);
+        r1_1 = RotateLeft<7>(r1_1);
+        r2_1 = RotateLeft<7>(r2_1);
+        r3_1 = RotateLeft<7>(r3_1);
+
+        r0_1 = Shuffle<1>(r0_1);
+        r0_2 = Shuffle<2>(r0_2);
+        r0_3 = Shuffle<3>(r0_3);
+
+        r1_1 = Shuffle<1>(r1_1);
+        r1_2 = Shuffle<2>(r1_2);
+        r1_3 = Shuffle<3>(r1_3);
+
+        r2_1 = Shuffle<1>(r2_1);
+        r2_2 = Shuffle<2>(r2_2);
+        r2_3 = Shuffle<3>(r2_3);
+
+        r3_1 = Shuffle<1>(r3_1);
+        r3_2 = Shuffle<2>(r3_2);
+        r3_3 = Shuffle<3>(r3_3);
+
+        r0_0 = vaddq_u32(r0_0, r0_1);
+        r1_0 = vaddq_u32(r1_0, r1_1);
+        r2_0 = vaddq_u32(r2_0, r2_1);
+        r3_0 = vaddq_u32(r3_0, r3_1);
+
+        r0_3 = veorq_u32(r0_3, r0_0);
+        r1_3 = veorq_u32(r1_3, r1_0);
+        r2_3 = veorq_u32(r2_3, r2_0);
+        r3_3 = veorq_u32(r3_3, r3_0);
+
+        r0_3 = RotateLeft<16>(r0_3);
+        r1_3 = RotateLeft<16>(r1_3);
+        r2_3 = RotateLeft<16>(r2_3);
+        r3_3 = RotateLeft<16>(r3_3);
+
+        r0_2 = vaddq_u32(r0_2, r0_3);
+        r1_2 = vaddq_u32(r1_2, r1_3);
+        r2_2 = vaddq_u32(r2_2, r2_3);
+        r3_2 = vaddq_u32(r3_2, r3_3);
+
+        r0_1 = veorq_u32(r0_1, r0_2);
+        r1_1 = veorq_u32(r1_1, r1_2);
+        r2_1 = veorq_u32(r2_1, r2_2);
+        r3_1 = veorq_u32(r3_1, r3_2);
+
+        r0_1 = RotateLeft<12>(r0_1);
+        r1_1 = RotateLeft<12>(r1_1);
+        r2_1 = RotateLeft<12>(r2_1);
+        r3_1 = RotateLeft<12>(r3_1);
+
+        r0_0 = vaddq_u32(r0_0, r0_1);
+        r1_0 = vaddq_u32(r1_0, r1_1);
+        r2_0 = vaddq_u32(r2_0, r2_1);
+        r3_0 = vaddq_u32(r3_0, r3_1);
+
+        r0_3 = veorq_u32(r0_3, r0_0);
+        r1_3 = veorq_u32(r1_3, r1_0);
+        r2_3 = veorq_u32(r2_3, r2_0);
+        r3_3 = veorq_u32(r3_3, r3_0);
+
+        r0_3 = RotateLeft<8>(r0_3);
+        r1_3 = RotateLeft<8>(r1_3);
+        r2_3 = RotateLeft<8>(r2_3);
+        r3_3 = RotateLeft<8>(r3_3);
+
+        r0_2 = vaddq_u32(r0_2, r0_3);
+        r1_2 = vaddq_u32(r1_2, r1_3);
+        r2_2 = vaddq_u32(r2_2, r2_3);
+        r3_2 = vaddq_u32(r3_2, r3_3);
+
+        r0_1 = veorq_u32(r0_1, r0_2);
+        r1_1 = veorq_u32(r1_1, r1_2);
+        r2_1 = veorq_u32(r2_1, r2_2);
+        r3_1 = veorq_u32(r3_1, r3_2);
+
+        r0_1 = RotateLeft<7>(r0_1);
+        r1_1 = RotateLeft<7>(r1_1);
+        r2_1 = RotateLeft<7>(r2_1);
+        r3_1 = RotateLeft<7>(r3_1);
+
+        r0_1 = Shuffle<3>(r0_1);
+        r0_2 = Shuffle<2>(r0_2);
+        r0_3 = Shuffle<1>(r0_3);
+
+        r1_1 = Shuffle<3>(r1_1);
+        r1_2 = Shuffle<2>(r1_2);
+        r1_3 = Shuffle<1>(r1_3);
+
+        r2_1 = Shuffle<3>(r2_1);
+        r2_2 = Shuffle<2>(r2_2);
+        r2_3 = Shuffle<1>(r2_3);
+
+        r3_1 = Shuffle<3>(r3_1);
+        r3_2 = Shuffle<2>(r3_2);
+        r3_3 = Shuffle<1>(r3_3);
+    }
+
+    r0_0 = vaddq_u32(r0_0, state0);
+    r0_1 = vaddq_u32(r0_1, state1);
+    r0_2 = vaddq_u32(r0_2, state2);
+    r0_3 = vaddq_u32(r0_3, state3);
+
+    r1_0 = vaddq_u32(r1_0, state0);
+    r1_1 = vaddq_u32(r1_1, state1);
+    r1_2 = vaddq_u32(r1_2, state2);
+    r1_3 = vaddq_u32(r1_3, state3);
+    r1_3 = vreinterpretq_u32_u64(vaddq_u64(
+        vreinterpretq_u64_u32(r1_3), CTRS[0]));
+
+    r2_0 = vaddq_u32(r2_0, state0);
+    r2_1 = vaddq_u32(r2_1, state1);
+    r2_2 = vaddq_u32(r2_2, state2);
+    r2_3 = vaddq_u32(r2_3, state3);
+    r2_3 = vreinterpretq_u32_u64(vaddq_u64(
+        vreinterpretq_u64_u32(r2_3), CTRS[1]));
+
+    r3_0 = vaddq_u32(r3_0, state0);
+    r3_1 = vaddq_u32(r3_1, state1);
+    r3_2 = vaddq_u32(r3_2, state2);
+    r3_3 = vaddq_u32(r3_3, state3);
+    r3_3 = vreinterpretq_u32_u64(vaddq_u64(
+        vreinterpretq_u64_u32(r3_3), CTRS[2]));
+
+    if (xorInput)
+    {
+        r0_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 0*16)), r0_0);
+        r0_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 1*16)), r0_1);
+        r0_2 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 2*16)), r0_2);
+        r0_3 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 3*16)), r0_3);
+    }
+
+    vst1q_u8(output + 0*16, vreinterpretq_u8_u32(r0_0));
+    vst1q_u8(output + 1*16, vreinterpretq_u8_u32(r0_1));
+    vst1q_u8(output + 2*16, vreinterpretq_u8_u32(r0_2));
+    vst1q_u8(output + 3*16, vreinterpretq_u8_u32(r0_3));
+
+    if (xorInput)
+    {
+        r1_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 4*16)), r1_0);
+        r1_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 5*16)), r1_1);
+        r1_2 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 6*16)), r1_2);
+        r1_3 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 7*16)), r1_3);
+    }
+
+    vst1q_u8(output + 4*16, vreinterpretq_u8_u32(r1_0));
+    vst1q_u8(output + 5*16, vreinterpretq_u8_u32(r1_1));
+    vst1q_u8(output + 6*16, vreinterpretq_u8_u32(r1_2));
+    vst1q_u8(output + 7*16, vreinterpretq_u8_u32(r1_3));
+
+    if (xorInput)
+    {
+        r2_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 8*16)), r2_0);
+        r2_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 9*16)), r2_1);
+        r2_2 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 10*16)), r2_2);
+        r2_3 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 11*16)), r2_3);
+    }
+
+    vst1q_u8(output + 8*16, vreinterpretq_u8_u32(r2_0));
+    vst1q_u8(output + 9*16, vreinterpretq_u8_u32(r2_1));
+    vst1q_u8(output + 10*16, vreinterpretq_u8_u32(r2_2));
+    vst1q_u8(output + 11*16, vreinterpretq_u8_u32(r2_3));
+
+    if (xorInput)
+    {
+        r3_0 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 12*16)), r3_0);
+        r3_1 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 13*16)), r3_1);
+        r3_2 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 14*16)), r3_2);
+        r3_3 = veorq_u32(vreinterpretq_u32_u8(vld1q_u8(input + 15*16)), r3_3);
+    }
+
+    vst1q_u8(output + 12*16, vreinterpretq_u8_u32(r3_0));
+    vst1q_u8(output + 13*16, vreinterpretq_u8_u32(r3_1));
+    vst1q_u8(output + 14*16, vreinterpretq_u8_u32(r3_2));
+    vst1q_u8(output + 15*16, vreinterpretq_u8_u32(r3_3));
+}
+
+#endif  // CRYPTOPP_ARM_NEON_AVAILABLE
+
 #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE)
 
 void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput)
diff --git a/chacha.cpp b/chacha.cpp
index 23fd3be8..aa2e9268 100644
--- a/chacha.cpp
+++ b/chacha.cpp
@@ -11,6 +11,10 @@
 
 NAMESPACE_BEGIN(CryptoPP)
 
+#if (CRYPTOPP_ARM_NEON_AVAILABLE)
+extern void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput);
+#endif
+
 #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE)
 extern void ChaCha_OperateKeystream_SSE2(const word32 *state, const byte* input, byte *output, unsigned int rounds, bool xorInput);
 #endif
@@ -33,6 +37,10 @@ std::string ChaCha_Policy::AlgorithmProvider() const
 #if (CRYPTOPP_SSE2_INTRIN_AVAILABLE || CRYPTOPP_SSE2_ASM_AVAILABLE)
     if (HasSSE2())
         return "SSE2";
+#endif
+#if (CRYPTOPP_ARM_NEON_AVAILABLE)
+    if (HasNEON())
+        return "NEON";
 #endif
     return "C++";
 }
@@ -91,6 +99,11 @@ unsigned int ChaCha_Policy::GetOptimalBlockSize() const
     if (HasSSE2())
         return 4*BYTES_PER_ITERATION;
     else
+#endif
+#if (CRYPTOPP_ARM_NEON_AVAILABLE)
+    if (HasNEON())
+        return 4*BYTES_PER_ITERATION;
+    else
 #endif
         return BYTES_PER_ITERATION;
 }
@@ -113,7 +126,26 @@ void ChaCha_Policy::OperateKeystream(KeystreamOperation operation,
             if (m_state[12] < 4)
                 m_state[13]++;
 
-            input += !!xorInput*4*BYTES_PER_ITERATION;
+            input += (!!xorInput)*4*BYTES_PER_ITERATION;
+            output += 4*BYTES_PER_ITERATION;
+            iterationCount -= 4;
+        }
+    }
+#endif
+
+#if (CRYPTOPP_ARM_NEON_AVAILABLE)
+    if (HasNEON())
+    {
+        while (iterationCount >= 4)
+        {
+            bool xorInput = (operation & INPUT_NULL) != INPUT_NULL;
+            ChaCha_OperateKeystream_NEON(m_state, input, output, m_rounds, xorInput);
+
+            m_state[12] += 4;
+            if (m_state[12] < 4)
+                m_state[13]++;
+
+            input += (!!xorInput)*4*BYTES_PER_ITERATION;
             output += 4*BYTES_PER_ITERATION;
             iterationCount -= 4;
         }
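
Reviewer note (not part of the patch): the NEON path leans on three identities — vrev32q_u16 rotates each 32-bit lane left by 16, the vqtbl1q_u8 byte tables rotate each lane by 8, and vextq_u32(x, x, S) rotates the four lanes of a row for the diagonalization step. Below is a minimal standalone sketch that sanity-checks the first two against the generic shift-and-or rotation; the file name, RotGeneric, SameQ, and the test values are illustrative only, not part of the library.

// rotate_check.cpp -- standalone sanity check for the NEON rotate/shuffle
// identities used above. Compile on AArch64 with: g++ -O2 rotate_check.cpp
#include <arm_neon.h>
#include <cstdio>
#include <cstring>

// Reference rotation: the same shift-and-or form as the primary template.
template <unsigned int R>
uint32x4_t RotGeneric(const uint32x4_t v)
{
    return vorrq_u32(vshlq_n_u32(v, R), vshrq_n_u32(v, 32 - R));
}

static bool SameQ(const uint32x4_t x, const uint32x4_t y)
{
    uint32_t a[4], b[4];
    vst1q_u32(a, x); vst1q_u32(b, y);
    return std::memcmp(a, b, sizeof(a)) == 0;
}

int main()
{
    const uint32_t in[4] = { 0x01234567, 0x89abcdef, 0xdeadbeef, 0x00c0ffee };
    const uint32x4_t v = vld1q_u32(in);

    // RotateLeft<16> specialization: swap the 16-bit halves of each lane.
    const uint32x4_t r16 = vreinterpretq_u32_u16(
        vrev32q_u16(vreinterpretq_u16_u32(v)));

    // RotateLeft<8> specialization: the table moves source byte 3 to byte 0,
    // byte 0 to byte 1, and so on within each 32-bit lane.
    const uint8_t maskb[16] = { 3,0,1,2, 7,4,5,6, 11,8,9,10, 15,12,13,14 };
    const uint32x4_t r8 = vreinterpretq_u32_u8(
        vqtbl1q_u8(vreinterpretq_u8_u32(v), vld1q_u8(maskb)));

    std::printf("rol16: %s\n", SameQ(r16, RotGeneric<16>(v)) ? "ok" : "FAIL");
    std::printf("rol8:  %s\n", SameQ(r8,  RotGeneric<8>(v))  ? "ok" : "FAIL");

    // Shuffle<1> is vextq_u32(v, v, 1): lanes [a,b,c,d] become [b,c,d,a],
    // the 4-byte row rotation used to diagonalize the ChaCha state.
    uint32_t s[4];
    vst1q_u32(s, vextq_u32(v, v, 1));
    std::printf("shuffle1: %08x %08x %08x %08x (expect %08x %08x %08x %08x)\n",
                s[0], s[1], s[2], s[3], in[1], in[2], in[3], in[0]);
    return 0;
}

The vqtbl1q_u8 checks only apply where the patch enables the specializations (the __aarch32__/__aarch64__ guard); elsewhere RotateLeft<8> and RotateLeft<16> resolve to the primary shift-and-or template, which the sketch uses as its reference.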