diff --git a/chacha.cpp b/chacha.cpp index 9e0e3d37..e16a5008 100644 --- a/chacha.cpp +++ b/chacha.cpp @@ -345,9 +345,13 @@ void ChaCha_Policy::CipherSetKey(const NameValuePairs ¶ms, const byte *key, CRYPTOPP_ASSERT(key); CRYPTOPP_ASSERT(length == 16 || length == 32); CRYPTOPP_UNUSED(key); CRYPTOPP_UNUSED(length); - m_rounds = params.GetIntValueWithDefault(Name::Rounds(), 20); - if (m_rounds != 20 && m_rounds != 12 && m_rounds != 8) - throw InvalidRounds(ChaCha::StaticAlgorithmName(), m_rounds); + // Use previous rounds as the default value + int rounds = params.GetIntValueWithDefault(Name::Rounds(), m_rounds); + if (rounds != 20 && rounds != 12 && rounds != 8) + throw InvalidRounds(ChaCha::StaticAlgorithmName(), rounds); + + // Latch a good value + m_rounds = rounds; // "expand 16-byte k" or "expand 32-byte k" m_state[0] = 0x61707865; @@ -425,9 +429,9 @@ void ChaChaTLS_Policy::CipherSetKey(const NameValuePairs ¶ms, const byte *ke // the function, so we have to use the heavier-weight SetKey to change it. word64 block; if (params.GetValue("InitialBlock", block)) - m_state[CTR] = static_cast(block); + m_counter = static_cast(block); else - m_state[CTR] = 0; + m_counter = 0; // State words are defined in RFC 8439, Section 2.3. Key is 32-bytes. GetBlock get(key); @@ -449,7 +453,7 @@ void ChaChaTLS_Policy::CipherResynchronize(byte *keystreamBuffer, const byte *IV // State words are defined in RFC 8439, Section 2.3 GetBlock get(IV); - m_state[12] = m_state[CTR]; + m_state[12] = m_counter; get(m_state[13])(m_state[14])(m_state[15]); } @@ -506,16 +510,19 @@ void XChaCha20_Policy::CipherSetKey(const NameValuePairs ¶ms, const byte *ke { CRYPTOPP_ASSERT(key); CRYPTOPP_ASSERT(length == 32); - // XChaCha20 is always 20 rounds. Fetch Rounds() to avoid a spurious failure. - int rounds = params.GetIntValueWithDefault(Name::Rounds(), ROUNDS); - if (rounds != 20) - throw InvalidRounds(XChaCha20::StaticAlgorithmName(), rounds); + // Use previous rounds as the default value + int rounds = params.GetIntValueWithDefault(Name::Rounds(), m_rounds); + if (rounds != 20 && rounds != 12) + throw InvalidRounds(ChaCha::StaticAlgorithmName(), rounds); + + // Latch a good value + m_rounds = rounds; word64 block; if (params.GetValue("InitialBlock", block)) - m_state[CTR] = static_cast(block); + m_counter = static_cast(block); else - m_state[CTR] = 1; + m_counter = 1; // Stash key away for use in CipherResynchronize GetBlock get(key); @@ -548,7 +555,7 @@ void XChaCha20_Policy::CipherResynchronize(byte *keystreamBuffer, const byte *iv m_state[2] = 0x79622d32; m_state[3] = 0x6b206574; // Setup new IV - m_state[12] = m_state[CTR]; + m_state[12] = m_counter; m_state[13] = 0; m_state[14] = GetWord(false, LITTLE_ENDIAN_ORDER, iv+16); m_state[15] = GetWord(false, LITTLE_ENDIAN_ORDER, iv+20); @@ -575,7 +582,7 @@ void XChaCha20_Policy::OperateKeystream(KeystreamOperation operation, byte *output, const byte *input, size_t iterationCount) { ChaCha_OperateKeystream(operation, m_state, m_state[12], m_state[13], - ROUNDS, output, input, iterationCount); + m_rounds, output, input, iterationCount); } NAMESPACE_END diff --git a/chacha.h b/chacha.h index 76aeb01b..0fab929a 100644 --- a/chacha.h +++ b/chacha.h @@ -58,7 +58,7 @@ class CRYPTOPP_NO_VTABLE ChaCha_Policy : public AdditiveCipherConcretePolicy m_state; unsigned int m_rounds; }; @@ -114,7 +115,7 @@ class CRYPTOPP_NO_VTABLE ChaChaTLS_Policy : public AdditiveCipherConcretePolicy< { public: virtual ~ChaChaTLS_Policy() {} - ChaChaTLS_Policy() {} + ChaChaTLS_Policy() : m_counter(0) {} protected: void CipherSetKey(const NameValuePairs ¶ms, const byte *key, size_t length); @@ -128,7 +129,8 @@ protected: std::string AlgorithmName() const; std::string AlgorithmProvider() const; - FixedSizeAlignedSecBlock m_state; + FixedSizeAlignedSecBlock m_state; + unsigned int m_counter; CRYPTOPP_CONSTANT(ROUNDS = ChaChaTLS_Info::ROUNDS) CRYPTOPP_CONSTANT(KEY = 16) // Index into m_state CRYPTOPP_CONSTANT(CTR = 24) // Index into m_state @@ -161,7 +163,7 @@ struct ChaChaTLS : public ChaChaTLS_Info, public SymmetricCipherDocumentation /// \brief XChaCha stream cipher information /// \since Crypto++ 8.1 -struct XChaCha20_Info : public FixedKeyLength<32, SimpleKeyingInterface::UNIQUE_IV, 24>, FixedRounds<20> +struct XChaCha20_Info : public FixedKeyLength<32, SimpleKeyingInterface::UNIQUE_IV, 24> { /// \brief The algorithm name /// \returns the algorithm name @@ -179,7 +181,7 @@ class CRYPTOPP_NO_VTABLE XChaCha20_Policy : public AdditiveCipherConcretePolicy< { public: virtual ~XChaCha20_Policy() {} - XChaCha20_Policy() {} + XChaCha20_Policy() : m_counter(0), m_rounds(ROUNDS) {} protected: void CipherSetKey(const NameValuePairs ¶ms, const byte *key, size_t length); @@ -193,10 +195,10 @@ protected: std::string AlgorithmName() const; std::string AlgorithmProvider() const; - FixedSizeAlignedSecBlock m_state; - CRYPTOPP_CONSTANT(ROUNDS = XChaCha20_Info::ROUNDS) + FixedSizeAlignedSecBlock m_state; + unsigned int m_counter, m_rounds; + CRYPTOPP_CONSTANT(ROUNDS = 20) // Default rounds CRYPTOPP_CONSTANT(KEY = 16) // Index into m_state - CRYPTOPP_CONSTANT(CTR = 24) // Index into m_state }; /// \brief XChaCha stream cipher diff --git a/salsa.cpp b/salsa.cpp index c5330a9a..93f6293b 100644 --- a/salsa.cpp +++ b/salsa.cpp @@ -112,8 +112,7 @@ std::string Salsa20_Policy::AlgorithmProvider() const void Salsa20_Policy::CipherSetKey(const NameValuePairs ¶ms, const byte *key, size_t length) { - m_rounds = params.GetIntValueWithDefault(Name::Rounds(), 20); - + m_rounds = params.GetIntValueWithDefault(Name::Rounds(), m_rounds); if (!(m_rounds == 8 || m_rounds == 12 || m_rounds == 20)) throw InvalidRounds(Salsa20::StaticAlgorithmName(), m_rounds); @@ -692,8 +691,7 @@ Salsa20_OperateKeystream ENDP void XSalsa20_Policy::CipherSetKey(const NameValuePairs ¶ms, const byte *key, size_t length) { - m_rounds = params.GetIntValueWithDefault(Name::Rounds(), 20); - + m_rounds = params.GetIntValueWithDefault(Name::Rounds(), m_rounds); if (!(m_rounds == 8 || m_rounds == 12 || m_rounds == 20)) throw InvalidRounds(XSalsa20::StaticAlgorithmName(), m_rounds); diff --git a/salsa.h b/salsa.h index a42d684b..31bb9a67 100644 --- a/salsa.h +++ b/salsa.h @@ -36,6 +36,7 @@ struct Salsa20_Info : public VariableKeyLength<32, 16, 32, 16, SimpleKeyingInter class CRYPTOPP_NO_VTABLE Salsa20_Policy : public AdditiveCipherConcretePolicy { protected: + Salsa20_Policy() : m_rounds(ROUNDS) {} void CipherSetKey(const NameValuePairs ¶ms, const byte *key, size_t length); void OperateKeystream(KeystreamOperation operation, byte *output, const byte *input, size_t iterationCount); void CipherResynchronize(byte *keystreamBuffer, const byte *IV, size_t length); @@ -49,6 +50,7 @@ protected: std::string AlgorithmProvider() const; + CRYPTOPP_CONSTANT(ROUNDS = 20) // Default rounds FixedSizeAlignedSecBlock m_state; int m_rounds; };