diff --git a/include/axolotl/message.hh b/include/axolotl/message.hh index cfbb715..2b9bc99 100644 --- a/include/axolotl/message.hh +++ b/include/axolotl/message.hh @@ -37,6 +37,7 @@ struct MessageWriter { struct MessageReader { std::uint8_t version; + bool has_counter; std::uint32_t counter; std::uint8_t const * input; std::size_t input_length; std::uint8_t const * ratchet_key; std::size_t ratchet_key_length; @@ -46,9 +47,8 @@ struct MessageReader { /** * Writes the message headers into the output buffer. - * Returns a writer struct populated with pointers into the output buffer. + * Populates the writer struct with pointers into the output buffer. */ - void encode_message( MessageWriter & writer, std::uint8_t version, @@ -62,13 +62,67 @@ void encode_message( /** * Reads the message headers from the input buffer. * Populates the reader struct with pointers into the input buffer. - * On failure returns std::size_t(-1). */ -std::size_t decode_message( +void decode_message( MessageReader & reader, std::uint8_t const * input, std::size_t input_length, std::size_t mac_length ); +struct PreKeyMessageWriter { + std::uint8_t * identity_key; + std::uint8_t * base_key; + std::uint8_t * message; +}; + + +struct PreKeyMessageReader { + std::uint8_t version; + bool has_registration_id; + bool has_one_time_key_id; + std::uint32_t registration_id; + std::uint32_t one_time_key_id; + std::uint8_t const * identity_key; std::size_t identity_key_length; + std::uint8_t const * base_key; std::size_t base_key_length; + std::uint8_t const * message; std::size_t message_length; +}; + +/** + * The length of the buffer needed to hold a message. + */ +std::size_t encode_one_time_key_message_length( + std::uint32_t registration_id, + std::uint32_t one_time_key_id, + std::size_t identity_key_length, + std::size_t base_key_length, + std::size_t message_length +); + +/** + * Writes the message headers into the output buffer. + * Populates the writer struct with pointers into the output buffer. + */ +void encode_one_time_key_message( + PreKeyMessageWriter & writer, + std::uint8_t version, + std::uint32_t registration_id, + std::uint32_t one_time_key_id, + std::size_t identity_key_length, + std::size_t base_key_length, + std::size_t message_length, + std::uint8_t * output +); + + +/** + * Reads the message headers from the input buffer. + * Populates the reader struct with pointers into the input buffer. + */ +void decode_one_time_key_message( + PreKeyMessageReader & reader, + std::uint8_t const * input, std::size_t input_length +); + + } // namespace axolotl diff --git a/src/message.cpp b/src/message.cpp index 46cadd8..fcedd07 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -81,6 +81,82 @@ static std::uint8_t const RATCHET_KEY_TAG = 012; static std::uint8_t const COUNTER_TAG = 020; static std::uint8_t const CIPHERTEXT_TAG = 042; +std::uint8_t * encode( + std::uint8_t * pos, + std::uint8_t tag, + std::uint32_t value +) { + *(pos++) = tag; + return varint_encode(pos, value); +} + +std::uint8_t * encode( + std::uint8_t * pos, + std::uint8_t tag, + std::uint8_t * & value, std::size_t value_length +) { + *(pos++) = tag; + pos = varint_encode(pos, value_length); + value = pos; + return pos + value_length; +} + +std::uint8_t const * decode( + std::uint8_t const * pos, std::uint8_t const * end, + std::uint8_t tag, + std::uint32_t & value, bool & has_value +) { + if (pos != end && *pos == tag) { + ++pos; + std::uint8_t const * value_start = pos; + pos = varint_skip(pos, end); + value = varint_decode(value_start, pos); + has_value = true; + } + return pos; +} + + +std::uint8_t const * decode( + std::uint8_t const * pos, std::uint8_t const * end, + std::uint8_t tag, + std::uint8_t const * & value, std::size_t & value_length +) { + if (pos != end && *pos == tag) { + ++pos; + std::uint8_t const * len_start = pos; + pos = varint_skip(pos, end); + std::size_t len = varint_decode(len_start, pos); + if (len > end - pos) return end; + value = pos; + value_length = len; + pos += len; + } + return pos; +} + +std::uint8_t const * skip_unknown( + std::uint8_t const * pos, std::uint8_t const * end +) { + if (pos != end) { + uint8_t tag = *pos; + if (tag & 0x7 == 0) { + pos = varint_skip(pos, end); + pos = varint_skip(pos, end); + } else if (tag & 0x7 == 2) { + pos = varint_skip(pos, end); + std::uint8_t const * len_start = pos; + pos = varint_skip(pos, end); + std::size_t len = varint_decode(len_start, pos); + if (len > end - pos) return end; + pos += len; + } else { + return end; + } + } + return pos; +} + } // namespace @@ -109,75 +185,138 @@ void axolotl::encode_message( ) { std::uint8_t * pos = output; *(pos++) = version; - *(pos++) = COUNTER_TAG; - pos = varint_encode(pos, counter); - *(pos++) = RATCHET_KEY_TAG; - pos = varint_encode(pos, ratchet_key_length); - writer.ratchet_key = pos; - pos += ratchet_key_length; - *(pos++) = CIPHERTEXT_TAG; - pos = varint_encode(pos, ciphertext_length); - writer.ciphertext = pos; - pos += ciphertext_length; + pos = encode(pos, RATCHET_KEY_TAG, writer.ratchet_key, ratchet_key_length); + pos = encode(pos, COUNTER_TAG, counter); + pos = encode(pos, CIPHERTEXT_TAG, writer.ciphertext, ciphertext_length); } -std::size_t axolotl::decode_message( +void axolotl::decode_message( axolotl::MessageReader & reader, std::uint8_t const * input, std::size_t input_length, std::size_t mac_length ) { std::uint8_t const * pos = input; std::uint8_t const * end = input + input_length - mac_length; - std::uint8_t flags = 0; - std::size_t result = std::size_t(-1); - if (pos == end) return result; + std::uint8_t const * unknown = NULL; + + if (pos == end) return; reader.version = *(pos++); + reader.input = input; + reader.input_length = input_length; + reader.has_counter = false; + reader.ratchet_key = NULL; + reader.ciphertext = NULL; + while (pos != end) { - uint8_t tag = *(pos); - if (tag == COUNTER_TAG) { - ++pos; - std::uint8_t const * counter_start = pos; - pos = varint_skip(pos, end); - reader.counter = varint_decode(counter_start, pos); - flags |= 1; - } else if (tag == RATCHET_KEY_TAG) { - ++pos; - std::uint8_t const * len_start = pos; - pos = varint_skip(pos, end); - std::size_t len = varint_decode(len_start, pos); - if (len > end - pos) return result; - reader.ratchet_key_length = len; - reader.ratchet_key = pos; - pos += len; - flags |= 2; - } else if (tag == CIPHERTEXT_TAG) { - ++pos; - std::uint8_t const * len_start = pos; - pos = varint_skip(pos, end); - std::size_t len = varint_decode(len_start, pos); - if (len > end - pos) return result; - reader.ciphertext_length = len; - reader.ciphertext = pos; - pos += len; - flags |= 4; - } else if (tag & 0x7 == 0) { - pos = varint_skip(pos, end); - pos = varint_skip(pos, end); - } else if (tag & 0x7 == 2) { - std::uint8_t const * len_start = pos; - pos = varint_skip(pos, end); - std::size_t len = varint_decode(len_start, pos); - if (len > end - pos) return result; - pos += len; - } else { - return std::size_t(-1); + pos = decode( + pos, end, RATCHET_KEY_TAG, + reader.ratchet_key, reader.ratchet_key_length + ); + pos = decode( + pos, end, COUNTER_TAG, + reader.counter, reader.has_counter + ); + pos = decode( + pos, end, CIPHERTEXT_TAG, + reader.ciphertext, reader.ciphertext_length + ); + if (unknown == pos) { + pos == skip_unknown(pos, end); } + unknown = pos; + } +} + + +namespace { + +static std::uint8_t const REGISTRATION_ID_TAG = 050; +static std::uint8_t const ONE_TIME_KEY_ID_TAG = 010; +static std::uint8_t const BASE_KEY_TAG = 022; +static std::uint8_t const IDENTITY_KEY_TAG = 032; +static std::uint8_t const MESSAGE_TAG = 042; + +} // namespace + + +std::size_t axolotl::encode_one_time_key_message_length( + std::uint32_t registration_id, + std::uint32_t one_time_key_id, + std::size_t identity_key_length, + std::size_t base_key_length, + std::size_t message_length +) { + std::size_t length = VERSION_LENGTH; + length += 1 + varint_length(registration_id); + length += 1 + varint_length(one_time_key_id); + length += 1 + varstring_length(identity_key_length); + length += 1 + varstring_length(base_key_length); + length += 1 + varstring_length(message_length); + return length; +} + + +void axolotl::encode_one_time_key_message( + axolotl::PreKeyMessageWriter & writer, + std::uint8_t version, + std::uint32_t registration_id, + std::uint32_t one_time_key_id, + std::size_t identity_key_length, + std::size_t base_key_length, + std::size_t message_length, + std::uint8_t * output +) { + std::uint8_t * pos = output; + *(pos++) = version; + pos = encode(pos, REGISTRATION_ID_TAG, registration_id); + pos = encode(pos, ONE_TIME_KEY_ID_TAG, one_time_key_id); + pos = encode(pos, BASE_KEY_TAG, writer.base_key, base_key_length); + pos = encode(pos, IDENTITY_KEY_TAG, writer.identity_key, identity_key_length); + pos = encode(pos, MESSAGE_TAG, writer.message, message_length); +} + + +void axolotl::decode_one_time_key_message( + PreKeyMessageReader & reader, + std::uint8_t const * input, std::size_t input_length +) { + std::uint8_t const * pos = input; + std::uint8_t const * end = input + input_length; + std::uint8_t const * unknown = NULL; + + if (pos == end) return; + reader.version = *(pos++); + reader.has_registration_id = false; + reader.has_one_time_key_id = false; + reader.identity_key = NULL; + reader.base_key = NULL; + reader.message = NULL; + + while (pos != end) { + pos = decode( + pos, end, REGISTRATION_ID_TAG, + reader.registration_id, reader.has_registration_id + ); + pos = decode( + pos, end, ONE_TIME_KEY_ID_TAG, + reader.one_time_key_id, reader.has_one_time_key_id + ); + pos = decode( + pos, end, BASE_KEY_TAG, + reader.base_key, reader.base_key_length + ); + pos = decode( + pos, end, IDENTITY_KEY_TAG, + reader.identity_key, reader.identity_key_length + ); + pos = decode( + pos, end, MESSAGE_TAG, + reader.message, reader.message_length + ); + if (unknown == pos) { + pos == skip_unknown(pos, end); + } + unknown = pos; } - if (flags == 0x7) { - reader.input = input; - reader.input_length = input_length; - return std::size_t(pos - input); - } - return result; } diff --git a/src/ratchet.cpp b/src/ratchet.cpp index cd4f8f7..cbafa83 100644 --- a/src/ratchet.cpp +++ b/src/ratchet.cpp @@ -444,7 +444,7 @@ std::size_t axolotl::Session::decrypt( } axolotl::MessageReader reader; - std::size_t body_length = axolotl::decode_message( + axolotl::decode_message( reader, input, input_length, ratchet_cipher.mac_length() ); @@ -453,7 +453,12 @@ std::size_t axolotl::Session::decrypt( return std::size_t(-1); } - if (body_length == size_t(-1) || reader.ratchet_key_length != KEY_LENGTH) { + if (!reader.has_counter || !reader.ratchet_key || !reader.ciphertext) { + last_error = axolotl::ErrorCode::BAD_MESSAGE_FORMAT; + return std::size_t(-1); + } + + if (reader.ratchet_key_length != KEY_LENGTH) { last_error = axolotl::ErrorCode::BAD_MESSAGE_FORMAT; return std::size_t(-1); } diff --git a/tests/test_message.cpp b/tests/test_message.cpp index 9c0ab4a..b46b15a 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -17,8 +17,8 @@ int main() { -std::uint8_t message1[36] = "\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2"; -std::uint8_t message2[36] = "\x03\x10\x01\n\nratchetkey\"\nciphertexthmacsha2"; +std::uint8_t message1[36] = "\x03\x10\x01\n\nratchetkey\"\nciphertexthmacsha2"; +std::uint8_t message2[36] = "\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2"; std::uint8_t ratchetkey[11] = "ratchetkey"; std::uint8_t ciphertext[11] = "ciphertext"; std::uint8_t hmacsha2[9] = "hmacsha2"; @@ -31,6 +31,7 @@ axolotl::MessageReader reader; axolotl::decode_message(reader, message1, 35, 8); assert_equals(std::uint8_t(3), reader.version); +assert_equals(true, reader.has_counter); assert_equals(std::uint32_t(1), reader.counter); assert_equals(std::size_t(10), reader.ratchet_key_length); assert_equals(std::size_t(10), reader.ciphertext_length);