Use aligned buffer for datatest.cpp

pull/489/head
Jeffrey Walton 2017-09-04 20:07:47 -04:00
parent a2223356b0
commit efe88c043b
No known key found for this signature in database
GPG Key ID: B36AB348921B1838
1 changed files with 53 additions and 23 deletions

View File

@ -30,11 +30,25 @@
#endif #endif
NAMESPACE_BEGIN(CryptoPP) NAMESPACE_BEGIN(CryptoPP)
typedef std::basic_string<char, std::char_traits<char>, AllocatorWithCleanup<char, true> > aligned_string;
typedef StringSinkTemplate<aligned_string> AlignedStringSink;
NAMESPACE_BEGIN(Test) NAMESPACE_BEGIN(Test)
typedef std::map<std::string, std::string> TestData; typedef std::map<std::string, std::string> TestData;
static bool s_thorough = false; static bool s_thorough = false;
bool operator ==(const std::string& a, const aligned_string& b)
{
return a.length() == b.length() && 0 == std::memcmp(a.data(), b.data(), a.size());
}
bool operator !=(const std::string& a, const aligned_string& b)
{
return !(a == b);
}
class TestFailure : public Exception class TestFailure : public Exception
{ {
public: public:
@ -87,7 +101,7 @@ void RandomizedTransfer(BufferedTransformation &source, BufferedTransformation &
{ {
while (source.MaxRetrievable() > (finish ? 0 : 4096)) while (source.MaxRetrievable() > (finish ? 0 : 4096))
{ {
byte buf[4096+64]; CRYPTOPP_ALIGN_DATA(16) byte buf[4096+64];
size_t start = Test::GlobalRNG().GenerateWord32(0, 63); size_t start = Test::GlobalRNG().GenerateWord32(0, 63);
size_t len = Test::GlobalRNG().GenerateWord32(1, UnsignedMin(4096U, 3*source.MaxRetrievable()/2)); size_t len = Test::GlobalRNG().GenerateWord32(1, UnsignedMin(4096U, 3*source.MaxRetrievable()/2));
len = source.Get(buf+start, len); len = source.Get(buf+start, len);
@ -181,6 +195,13 @@ std::string GetDecodedDatum(const TestData &data, const char *name)
return s; return s;
} }
aligned_string GetAlignedDecodedDatum(const TestData &data, const char *name)
{
aligned_string s;
PutDecodedDatumInto(data, name, AlignedStringSink(s).Ref());
return s;
}
std::string GetOptionalDecodedDatum(const TestData &data, const char *name) std::string GetOptionalDecodedDatum(const TestData &data, const char *name)
{ {
std::string s; std::string s;
@ -189,6 +210,14 @@ std::string GetOptionalDecodedDatum(const TestData &data, const char *name)
return s; return s;
} }
aligned_string GetOptionalAlignedDecodedDatum(const TestData &data, const char *name)
{
aligned_string s;
if (DataExists(data, name))
PutDecodedDatumInto(data, name, AlignedStringSink(s).Ref());
return s;
}
class TestDataNameValuePairs : public NameValuePairs class TestDataNameValuePairs : public NameValuePairs
{ {
public: public:
@ -384,11 +413,11 @@ void TestAsymmetricCipher(TestData &v)
void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters) void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
{ {
std::string name = GetRequiredDatum(v, "Name"); const std::string name = GetRequiredDatum(v, "Name");
std::string test = GetRequiredDatum(v, "Test"); const std::string test = GetRequiredDatum(v, "Test");
std::string key = GetDecodedDatum(v, "Key"); const aligned_string key = GetAlignedDecodedDatum(v, "Key");
std::string plaintext = GetDecodedDatum(v, "Plaintext"); const aligned_string plaintext = GetAlignedDecodedDatum(v, "Plaintext");
TestDataNameValuePairs testDataPairs(v); TestDataNameValuePairs testDataPairs(v);
CombinedNameValuePairs pairs(overrideParameters, testDataPairs); CombinedNameValuePairs pairs(overrideParameters, testDataPairs);
@ -446,16 +475,17 @@ void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
// If overrideParameters are specified, the caller is responsible for managing the parameter. // If overrideParameters are specified, the caller is responsible for managing the parameter.
v.erase("Tweak"); v.erase("BlockSize"); v.erase("BlockPaddingScheme"); v.erase("Tweak"); v.erase("BlockSize"); v.erase("BlockPaddingScheme");
std::string encrypted, xorDigest, ciphertext, ciphertextXorDigest; // std::string encrypted, xorDigest, ciphertext, ciphertextXorDigest;
aligned_string encrypted, xorDigest, ciphertext, ciphertextXorDigest;
if (test == "EncryptionMCT" || test == "DecryptionMCT") if (test == "EncryptionMCT" || test == "DecryptionMCT")
{ {
SymmetricCipher *cipher = encryptor.get(); SymmetricCipher *cipher = encryptor.get();
SecByteBlock buf((byte *)plaintext.data(), plaintext.size()), keybuf((byte *)key.data(), key.size()); AlignedSecByteBlock buf((byte *)plaintext.data(), plaintext.size()), keybuf((byte *)key.data(), key.size());
if (test == "DecryptionMCT") if (test == "DecryptionMCT")
{ {
cipher = decryptor.get(); cipher = decryptor.get();
ciphertext = GetDecodedDatum(v, "Ciphertext"); ciphertext = GetAlignedDecodedDatum(v, "Ciphertext");
buf.Assign((byte *)ciphertext.data(), ciphertext.size()); buf.Assign((byte *)ciphertext.data(), ciphertext.size());
} }
@ -473,11 +503,11 @@ void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
cipher->SetKey(keybuf, keybuf.size()); cipher->SetKey(keybuf, keybuf.size());
} }
encrypted.assign((char *)buf.begin(), buf.size()); encrypted.assign((char *)buf.begin(), buf.size());
ciphertext = GetDecodedDatum(v, test == "EncryptionMCT" ? "Ciphertext" : "Plaintext"); ciphertext = GetAlignedDecodedDatum(v, test == "EncryptionMCT" ? "Ciphertext" : "Plaintext");
if (encrypted != ciphertext) if (encrypted != ciphertext)
{ {
std::cout << "\nincorrectly encrypted: "; std::cout << "\nincorrectly encrypted: ";
StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout))); StringSource xx(reinterpret_cast<const byte*>(encrypted.data()), encrypted.size(), false, new HexEncoder(new FileSink(std::cout)));
xx.Pump(256); xx.Flush(false); xx.Pump(256); xx.Flush(false);
std::cout << "\n"; std::cout << "\n";
SignalTestFailure(); SignalTestFailure();
@ -485,7 +515,7 @@ void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
return; return;
} }
StreamTransformationFilter encFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter encFilter(*encryptor, new AlignedStringSink(encrypted),
static_cast<BlockPaddingSchemeDef::BlockPaddingScheme>(paddingScheme)); static_cast<BlockPaddingSchemeDef::BlockPaddingScheme>(paddingScheme));
RandomizedTransfer(StringStore(plaintext).Ref(), encFilter, true); RandomizedTransfer(StringStore(plaintext).Ref(), encFilter, true);
encFilter.MessageEnd(); encFilter.MessageEnd();
@ -500,10 +530,10 @@ void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
CRYPTOPP_ASSERT(encrypted[i] == z[i]); CRYPTOPP_ASSERT(encrypted[i] == z[i]);
}*/ }*/
if (test != "EncryptXorDigest") if (test != "EncryptXorDigest")
ciphertext = GetDecodedDatum(v, "Ciphertext"); ciphertext = GetAlignedDecodedDatum(v, "Ciphertext");
else else
{ {
ciphertextXorDigest = GetDecodedDatum(v, "CiphertextXorDigest"); ciphertextXorDigest = GetAlignedDecodedDatum(v, "CiphertextXorDigest");
xorDigest.append(encrypted, 0, 64); xorDigest.append(encrypted, 0, 64);
for (size_t i=64; i<encrypted.size(); i++) for (size_t i=64; i<encrypted.size(); i++)
xorDigest[i%64] ^= encrypted[i]; xorDigest[i%64] ^= encrypted[i];
@ -511,20 +541,20 @@ void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
if (test != "EncryptXorDigest" ? encrypted != ciphertext : xorDigest != ciphertextXorDigest) if (test != "EncryptXorDigest" ? encrypted != ciphertext : xorDigest != ciphertextXorDigest)
{ {
std::cout << "\nincorrectly encrypted: "; std::cout << "\nincorrectly encrypted: ";
StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout))); StringSource xx(reinterpret_cast<const byte*>(encrypted.data()), encrypted.size(), false, new HexEncoder(new FileSink(std::cout)));
xx.Pump(2048); xx.Flush(false); xx.Pump(2048); xx.Flush(false);
std::cout << "\n"; std::cout << "\n";
SignalTestFailure(); SignalTestFailure();
} }
std::string decrypted; aligned_string decrypted;
StreamTransformationFilter decFilter(*decryptor, new StringSink(decrypted), StreamTransformationFilter decFilter(*decryptor, new AlignedStringSink(decrypted),
static_cast<BlockPaddingSchemeDef::BlockPaddingScheme>(paddingScheme)); static_cast<BlockPaddingSchemeDef::BlockPaddingScheme>(paddingScheme));
RandomizedTransfer(StringStore(encrypted).Ref(), decFilter, true); RandomizedTransfer(StringStore(encrypted).Ref(), decFilter, true);
decFilter.MessageEnd(); decFilter.MessageEnd();
if (decrypted != plaintext) if (decrypted != plaintext)
{ {
std::cout << "\nincorrectly decrypted: "; std::cout << "\nincorrectly decrypted: ";
StringSource xx(decrypted, false, new HexEncoder(new FileSink(std::cout))); StringSource xx(reinterpret_cast<const byte*>(decrypted.data()), decrypted.size(), false, new HexEncoder(new FileSink(std::cout)));
xx.Pump(256); xx.Flush(false); xx.Pump(256); xx.Flush(false);
std::cout << "\n"; std::cout << "\n";
SignalTestFailure(); SignalTestFailure();
@ -542,13 +572,13 @@ void TestAuthenticatedSymmetricCipher(TestData &v, const NameValuePairs &overrid
std::string type = GetRequiredDatum(v, "AlgorithmType"); std::string type = GetRequiredDatum(v, "AlgorithmType");
std::string name = GetRequiredDatum(v, "Name"); std::string name = GetRequiredDatum(v, "Name");
std::string test = GetRequiredDatum(v, "Test"); std::string test = GetRequiredDatum(v, "Test");
std::string key = GetDecodedDatum(v, "Key"); aligned_string key = GetAlignedDecodedDatum(v, "Key");
std::string plaintext = GetOptionalDecodedDatum(v, "Plaintext"); aligned_string plaintext = GetOptionalAlignedDecodedDatum(v, "Plaintext");
std::string ciphertext = GetOptionalDecodedDatum(v, "Ciphertext"); aligned_string ciphertext = GetOptionalAlignedDecodedDatum(v, "Ciphertext");
std::string header = GetOptionalDecodedDatum(v, "Header"); aligned_string header = GetOptionalAlignedDecodedDatum(v, "Header");
std::string footer = GetOptionalDecodedDatum(v, "Footer"); aligned_string footer = GetOptionalAlignedDecodedDatum(v, "Footer");
std::string mac = GetOptionalDecodedDatum(v, "MAC"); aligned_string mac = GetOptionalAlignedDecodedDatum(v, "MAC");
TestDataNameValuePairs testDataPairs(v); TestDataNameValuePairs testDataPairs(v);
CombinedNameValuePairs pairs(overrideParameters, testDataPairs); CombinedNameValuePairs pairs(overrideParameters, testDataPairs);