diff --git a/tests/include/unittest.hh b/tests/include/unittest.hh index fb1c90d..437ea77 100644 --- a/tests/include/unittest.hh +++ b/tests/include/unittest.hh @@ -36,11 +36,17 @@ char const * TEST_CASE; template void assert_equals( + const char *file, + unsigned line, + const char *expected_expr, + const char *actual_expr, T const & expected, T const & actual ) { if (expected != actual) { std::cout << "FAILED: " << TEST_CASE << std::endl; + std::cout << file << ":" << line << std::endl; + std::cout << expected_expr << " == " << actual_expr << std::endl; std::cout << "Expected: " << expected << std::endl; std::cout << "Actual: " << actual << std::endl; std::exit(1); @@ -49,18 +55,27 @@ void assert_equals( void assert_equals( + const char *file, + unsigned line, + const char *expected_expr, + const char *actual_expr, std::uint8_t const * expected, std::uint8_t const * actual, std::size_t length ) { if (std::memcmp(expected, actual, length)) { std::cout << "FAILED: " << TEST_CASE << std::endl; + std::cout << file << ":" << line << std::endl; + std::cout << expected_expr << " == " << actual_expr << std::endl; print_hex(std::cout << "Expected: ", expected, length) << std::endl; print_hex(std::cout << "Actual: ", actual, length) << std::endl; std::exit(1); } } +#define assert_equals(expected, actual, ...) assert_equals( \ + __FILE__, __LINE__, #expected, #actual, expected, actual, ##__VA_ARGS__ \ +) class TestCase { public: diff --git a/tests/test_axolotl.cpp b/tests/test_axolotl.cpp index 08de1cf..f268f2a 100644 --- a/tests/test_axolotl.cpp +++ b/tests/test_axolotl.cpp @@ -18,9 +18,6 @@ int main() { -{ /* Loopback test case */ -TestCase test_case("Axolotl Loopback 1"); - std::uint8_t root_info[] = "Axolotl"; std::uint8_t ratchet_info[] = "AxolotlRatchet"; std::uint8_t message_info[] = "AxolotlMessageKeys"; @@ -31,31 +28,35 @@ axolotl::KdfInfo kdf_info = { message_info, sizeof(ratchet_info - 1) }; -axolotl::Session alice(kdf_info); -axolotl::Session bob(kdf_info); - std::uint8_t random_bytes[] = "0123456789ABDEF0123456789ABCDEF"; axolotl::Curve25519KeyPair bob_key; axolotl::generate_key(random_bytes, bob_key); std::uint8_t shared_secret[] = "A secret"; +{ /* Send/Receive test case */ +TestCase test_case("Axolotl Send/Receive"); + +axolotl::Session alice(kdf_info); +axolotl::Session bob(kdf_info); + alice.initialise_as_bob(shared_secret, sizeof(shared_secret) - 1, bob_key); bob.initialise_as_alice(shared_secret, sizeof(shared_secret) - 1, bob_key); std::uint8_t plaintext[] = "Message"; std::size_t plaintext_length = sizeof(plaintext) - 1; -std::size_t message_length, random_length, actual_length; +std::size_t message_length, random_length, output_length; std::size_t encrypt_length, decrypt_length; - -message_length = bob.encrypt_max_output_length(plaintext_length); -random_length = bob.encrypt_random_length(); -assert_equals(std::size_t(0), random_length); -actual_length = alice.decrypt_max_plaintext_length(message_length); { + /* Bob sends Alice a message */ + message_length = bob.encrypt_max_output_length(plaintext_length); + random_length = bob.encrypt_random_length(); + assert_equals(std::size_t(0), random_length); + output_length = alice.decrypt_max_plaintext_length(message_length); + std::uint8_t message[message_length]; - std::uint8_t actual[actual_length]; + std::uint8_t output[output_length]; encrypt_length = bob.encrypt( plaintext, plaintext_length, @@ -66,12 +67,108 @@ actual_length = alice.decrypt_max_plaintext_length(message_length); decrypt_length = alice.decrypt( message, message_length, - actual, actual_length + output, output_length ); assert_equals(plaintext_length, decrypt_length); - assert_equals(plaintext, actual, decrypt_length); + assert_equals(plaintext, output, decrypt_length); } -} /* Loopback test case */ + +{ + /* Alice sends Bob a message */ + message_length = alice.encrypt_max_output_length(plaintext_length); + random_length = alice.encrypt_random_length(); + assert_equals(std::size_t(32), random_length); + output_length = bob.decrypt_max_plaintext_length(message_length); + + std::uint8_t message[message_length]; + std::uint8_t output[output_length]; + std::uint8_t random[] = "This is a random 32 byte string."; + + encrypt_length = alice.encrypt( + plaintext, plaintext_length, + random, 32, + message, message_length + ); + assert_equals(message_length, encrypt_length); + + decrypt_length = bob.decrypt( + message, message_length, + output, output_length + ); + assert_equals(plaintext_length, decrypt_length); + assert_equals(plaintext, output, decrypt_length); +} + +} /* Send/receive message test case */ + +{ /* Out of order test case */ + +TestCase test_case("Axolotl Out of Order"); + +axolotl::Session alice(kdf_info); +axolotl::Session bob(kdf_info); + +alice.initialise_as_bob(shared_secret, sizeof(shared_secret) - 1, bob_key); +bob.initialise_as_alice(shared_secret, sizeof(shared_secret) - 1, bob_key); + +std::uint8_t plaintext_1[] = "First Message"; +std::size_t plaintext_1_length = sizeof(plaintext_1) - 1; + +std::uint8_t plaintext_2[] = "Second Messsage. A bit longer than the first."; +std::size_t plaintext_2_length = sizeof(plaintext_2) - 1; + +std::size_t message_1_length, message_2_length, random_length, output_length; +std::size_t encrypt_length, decrypt_length; + +{ + /* Alice sends Bob two messages and they arrive out of order */ + message_1_length = alice.encrypt_max_output_length(plaintext_1_length); + random_length = alice.encrypt_random_length(); + assert_equals(std::size_t(32), random_length); + + std::uint8_t message_1[message_1_length]; + std::uint8_t random[] = "This is a random 32 byte string."; + encrypt_length = alice.encrypt( + plaintext_1, plaintext_1_length, + random, 32, + message_1, message_1_length + ); + assert_equals(message_1_length, encrypt_length); + + message_2_length = alice.encrypt_max_output_length(plaintext_2_length); + random_length = alice.encrypt_random_length(); + assert_equals(std::size_t(0), random_length); + + std::uint8_t message_2[message_2_length]; + encrypt_length = alice.encrypt( + plaintext_2, plaintext_2_length, + NULL, 0, + message_2, message_2_length + ); + assert_equals(message_2_length, encrypt_length); + + output_length = bob.decrypt_max_plaintext_length(message_2_length); + std::uint8_t output_1[output_length]; + decrypt_length = bob.decrypt( + message_2, message_2_length, + output_1, output_length + ); + assert_equals(plaintext_2_length, decrypt_length); + assert_equals(plaintext_2, output_1, decrypt_length); + + output_length = bob.decrypt_max_plaintext_length(message_1_length); + std::uint8_t output_2[output_length]; + decrypt_length = bob.decrypt( + message_1, message_1_length, + output_2, output_length + ); + + assert_equals(plaintext_1_length, decrypt_length); + assert_equals(plaintext_1, output_2, decrypt_length); +} + +} /* Out of order test case */ + }