fix bug when AuthenticatedDecryptionFilter::MAC_AT_BEGIN is not specified

pull/2/head
weidai 2009-03-05 08:53:50 +00:00
parent e4295fda97
commit da24db2a8b
2 changed files with 10 additions and 2 deletions

View File

@ -397,7 +397,8 @@ void TestAuthenticatedSymmetricCipher(TestData &v, const NameValuePairs &overrid
std::string encrypted, decrypted; std::string encrypted, decrypted;
AuthenticatedEncryptionFilter ef(*asc1, new StringSink(encrypted)); AuthenticatedEncryptionFilter ef(*asc1, new StringSink(encrypted));
AuthenticatedDecryptionFilter df(*asc2, new StringSink(decrypted), AuthenticatedDecryptionFilter::MAC_AT_BEGIN); bool macAtBegin = !GlobalRNG().GenerateBit(); // test both ways randomly
AuthenticatedDecryptionFilter df(*asc2, new StringSink(decrypted), macAtBegin ? AuthenticatedDecryptionFilter::MAC_AT_BEGIN : 0);
if (asc1->NeedsPrespecifiedDataLengths()) if (asc1->NeedsPrespecifiedDataLengths())
{ {
@ -407,10 +408,13 @@ void TestAuthenticatedSymmetricCipher(TestData &v, const NameValuePairs &overrid
StringStore sh(header), sp(plaintext), sc(ciphertext), sf(footer), sm(mac); StringStore sh(header), sp(plaintext), sc(ciphertext), sf(footer), sm(mac);
sm.TransferTo(df); if (macAtBegin)
sm.TransferTo(df);
sh.CopyTo(df, LWORD_MAX, "AAD"); sh.CopyTo(df, LWORD_MAX, "AAD");
sc.TransferTo(df); sc.TransferTo(df);
sf.CopyTo(df, LWORD_MAX, "AAD"); sf.CopyTo(df, LWORD_MAX, "AAD");
if (!macAtBegin)
sm.TransferTo(df);
df.MessageEnd(); df.MessageEnd();
sh.TransferTo(ef, sh.MaxRetrievable()/2+1, "AAD"); sh.TransferTo(ef, sh.MaxRetrievable()/2+1, "AAD");

View File

@ -885,7 +885,11 @@ byte * AuthenticatedDecryptionFilter::ChannelCreatePutSpace(const std::string &c
size_t AuthenticatedDecryptionFilter::ChannelPut2(const std::string &channel, const byte *begin, size_t length, int messageEnd, bool blocking) size_t AuthenticatedDecryptionFilter::ChannelPut2(const std::string &channel, const byte *begin, size_t length, int messageEnd, bool blocking)
{ {
if (channel.empty()) if (channel.empty())
{
if (m_lastSize > 0)
m_hashVerifier.ForceNextPut();
return FilterWithBufferedInput::Put2(begin, length, messageEnd, blocking); return FilterWithBufferedInput::Put2(begin, length, messageEnd, blocking);
}
if (channel == "AAD") if (channel == "AAD")
return m_hashVerifier.Put2(begin, length, 0, blocking); return m_hashVerifier.Put2(begin, length, 0, blocking);