From 39ad75314b9e28053f568ed6a4109f5d3a9468fe Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 18 May 2016 17:23:09 +0100 Subject: [PATCH] Implement decrypting inbound group messages Includes creation of inbound sessions, etc --- include/olm/error.h | 3 + include/olm/inbound_group_session.h | 153 +++++++++++++++++++++ include/olm/message.h | 24 ++++ include/olm/olm.h | 1 + src/inbound_group_session.c | 199 ++++++++++++++++++++++++++++ src/message.cpp | 42 ++++++ tests/test_group_session.cpp | 42 +++++- tests/test_message.cpp | 22 +++ 8 files changed, 480 insertions(+), 6 deletions(-) create mode 100644 include/olm/inbound_group_session.h create mode 100644 src/inbound_group_session.c diff --git a/include/olm/error.h b/include/olm/error.h index 87e019a..3f74992 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -32,6 +32,9 @@ enum OlmErrorCode { OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */ OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */ + OLM_BAD_RATCHET_KEY = 11, + OLM_BAD_CHAIN_INDEX = 12, + /* remember to update the list of string constants in error.c when updating * this list. */ }; diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h new file mode 100644 index 0000000..4cf4ac4 --- /dev/null +++ b/include/olm/inbound_group_session.h @@ -0,0 +1,153 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef OLM_INBOUND_GROUP_SESSION_H_ +#define OLM_INBOUND_GROUP_SESSION_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct OlmInboundGroupSession OlmInboundGroupSession; + +/** get the size of an inbound group session, in bytes. */ +size_t olm_inbound_group_session_size(); + +/** + * Initialise an inbound group session object using the supplied memory + * The supplied memory should be at least olm_inbound_group_session_size() + * bytes. + */ +OlmInboundGroupSession * olm_inbound_group_session( + void *memory +); + +/** + * A null terminated string describing the most recent error to happen to a + * group session */ +const char *olm_inbound_group_session_last_error( + const OlmInboundGroupSession *session +); + +/** Clears the memory used to back this group session */ +size_t olm_clear_inbound_group_session( + OlmInboundGroupSession *session +); + +/** Returns the number of bytes needed to store an inbound group session */ +size_t olm_pickle_inbound_group_session_length( + const OlmInboundGroupSession *session +); + +/** + * Stores a group session as a base64 string. Encrypts the session using the + * supplied key. Returns the length of the session on success. + * + * Returns olm_error() on failure. If the pickle output buffer + * is smaller than olm_pickle_inbound_group_session_length() then + * olm_inbound_group_session_last_error() will be "OUTPUT_BUFFER_TOO_SMALL" + */ +size_t olm_pickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + +/** + * Loads a group session from a pickled base64 string. Decrypts the session + * using the supplied key. + * + * Returns olm_error() on failure. If the key doesn't match the one used to + * encrypt the account then olm_inbound_group_session_last_error() will be + * "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then + * olm_inbound_group_session_last_error() will be "INVALID_BASE64". The input + * pickled buffer is destroyed + */ +size_t olm_unpickle_inbound_group_session( + OlmInboundGroupSession *session, + void const * key, size_t key_length, + void * pickled, size_t pickled_length +); + + +/** + * Start a new inbound group session, based on the parameters supplied. + * + * Returns olm_error() on failure. On failure last_error will be set with an + * error code. The last_error will be: + * + * * OLM_INVALID_BASE64 if the session_key is not valid base64 + * * OLM_BAD_RATCHET_KEY if the session_key is invalid + */ +size_t olm_init_inbound_group_session( + OlmInboundGroupSession *session, + uint32_t message_index, + + /* base64-encoded key */ + uint8_t const * session_key, size_t session_key_length +); + +/** + * Get an upper bound on the number of bytes of plain-text the decrypt method + * will write for a given input message length. The actual size could be + * different due to padding. + * + * The input message buffer is destroyed. + * + * Returns olm_error() on failure. + */ +size_t olm_group_decrypt_max_plaintext_length( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length +); + +/** + * Decrypt a message. + * + * The input message buffer is destroyed. + * + * Returns the length of the decrypted plain-text, or olm_error() on failure. + * + * On failure last_error will be set with an error code. The last_error will + * be: + * * OLM_OUTPUT_BUFFER_TOO_SMALL if the plain-text buffer is too small + * * OLM_INVALID_BASE64 if the message is not valid base-64 + * * OLM_BAD_MESSAGE_VERSION if the message was encrypted with an unsupported + * version of the protocol + * * OLM_BAD_MESSAGE_FORMAT if the message headers could not be decoded + * * OLM_BAD_MESSAGE_MAC if the message could not be verified + * * OLM_BAD_CHAIN_INDEX if we do not have a ratchet key corresponding to the + * message's index (ie, it was sent before the ratchet key was shared with + * us) + */ +size_t olm_group_decrypt( + OlmInboundGroupSession *session, + + /* input; note that it will be overwritten with the base64-decoded + message. */ + uint8_t * message, size_t message_length, + + /* output */ + uint8_t * plaintext, size_t max_plaintext_length +); + + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* OLM_INBOUND_GROUP_SESSION_H_ */ diff --git a/include/olm/message.h b/include/olm/message.h index 05fb56c..bd7aec3 100644 --- a/include/olm/message.h +++ b/include/olm/message.h @@ -65,6 +65,30 @@ void _olm_encode_group_message( ); +struct _OlmDecodeGroupMessageResults { + uint8_t version; + const uint8_t *session_id; + size_t session_id_length; + uint32_t chain_index; + int has_chain_index; + const uint8_t *ciphertext; + size_t ciphertext_length; +}; + + +/** + * Reads the message headers from the input buffer. + */ +void _olm_decode_group_message( + const uint8_t *input, size_t input_length, + size_t mac_length, + + /* output structure: updated with results */ + struct _OlmDecodeGroupMessageResults *results +); + + + #ifdef __cplusplus } // extern "C" #endif diff --git a/include/olm/olm.h b/include/olm/olm.h index 00e1f63..dbaf71e 100644 --- a/include/olm/olm.h +++ b/include/olm/olm.h @@ -19,6 +19,7 @@ #include #include +#include "olm/inbound_group_session.h" #include "olm/outbound_group_session.h" #ifdef __cplusplus diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c new file mode 100644 index 0000000..4796414 --- /dev/null +++ b/src/inbound_group_session.c @@ -0,0 +1,199 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "olm/inbound_group_session.h" + +#include + +#include "olm/base64.h" +#include "olm/cipher.h" +#include "olm/error.h" +#include "olm/megolm.h" +#include "olm/message.h" + +#define OLM_PROTOCOL_VERSION 3 + +struct OlmInboundGroupSession { + /** our earliest known ratchet value */ + Megolm initial_ratchet; + + /** The most recent ratchet value */ + Megolm latest_ratchet; + + enum OlmErrorCode last_error; +}; + +size_t olm_inbound_group_session_size() { + return sizeof(OlmInboundGroupSession); +} + +OlmInboundGroupSession * olm_inbound_group_session( + void *memory +) { + OlmInboundGroupSession *session = memory; + olm_clear_inbound_group_session(session); + return session; +} + +const char *olm_inbound_group_session_last_error( + const OlmInboundGroupSession *session +) { + return _olm_error_to_string(session->last_error); +} + +size_t olm_clear_inbound_group_session( + OlmInboundGroupSession *session +) { + memset(session, 0, sizeof(OlmInboundGroupSession)); + return sizeof(OlmInboundGroupSession); +} + +size_t olm_init_inbound_group_session( + OlmInboundGroupSession *session, + uint32_t message_index, + const uint8_t * session_key, size_t session_key_length +) { + uint8_t key_buf[MEGOLM_RATCHET_LENGTH]; + size_t raw_length = _olm_decode_base64_length(session_key_length); + + if (raw_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + if (raw_length != MEGOLM_RATCHET_LENGTH) { + session->last_error = OLM_BAD_RATCHET_KEY; + return (size_t)-1; + } + + _olm_decode_base64(session_key, session_key_length, key_buf); + megolm_init(&session->initial_ratchet, key_buf, message_index); + megolm_init(&session->latest_ratchet, key_buf, message_index); + memset(key_buf, 0, MEGOLM_RATCHET_LENGTH); + + return 0; +} + +size_t olm_group_decrypt_max_plaintext_length( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length +) { + size_t r; + const struct _olm_cipher *cipher = megolm_cipher(); + struct _OlmDecodeGroupMessageResults decoded_results; + + r = _olm_decode_base64(message, message_length, message); + if (r == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return r; + } + + _olm_decode_group_message( + message, message_length, + cipher->ops->mac_length(cipher), + &decoded_results); + + if (decoded_results.version != OLM_PROTOCOL_VERSION) { + session->last_error = OLM_BAD_MESSAGE_VERSION; + return (size_t)-1; + } + + if (!decoded_results.ciphertext) { + session->last_error = OLM_BAD_MESSAGE_FORMAT; + return (size_t)-1; + } + + return cipher->ops->decrypt_max_plaintext_length( + cipher, decoded_results.ciphertext_length); +} + + +size_t olm_group_decrypt( + OlmInboundGroupSession *session, + uint8_t * message, size_t message_length, + uint8_t * plaintext, size_t max_plaintext_length +) { + struct _OlmDecodeGroupMessageResults decoded_results; + const struct _olm_cipher *cipher = megolm_cipher(); + size_t max_length, raw_message_length, r; + Megolm *megolm; + Megolm tmp_megolm; + + raw_message_length = _olm_decode_base64(message, message_length, message); + if (raw_message_length == (size_t)-1) { + session->last_error = OLM_INVALID_BASE64; + return (size_t)-1; + } + + _olm_decode_group_message( + message, raw_message_length, + cipher->ops->mac_length(cipher), + &decoded_results); + + if (decoded_results.version != OLM_PROTOCOL_VERSION) { + session->last_error = OLM_BAD_MESSAGE_VERSION; + return (size_t)-1; + } + + if (!decoded_results.has_chain_index || !decoded_results.session_id + || !decoded_results.ciphertext + ) { + session->last_error = OLM_BAD_MESSAGE_FORMAT; + return (size_t)-1; + } + + max_length = cipher->ops->decrypt_max_plaintext_length( + cipher, + decoded_results.ciphertext_length + ); + if (max_plaintext_length < max_length) { + session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; + return (size_t)-1; + } + + /* pick a megolm instance to use. If we're at or beyond the latest ratchet + * value, use that */ + if ((int32_t)(decoded_results.chain_index - session->latest_ratchet.counter) >= 0) { + megolm = &session->latest_ratchet; + } else if ((int32_t)(decoded_results.chain_index - session->initial_ratchet.counter) < 0) { + /* the counter is before our intial ratchet - we can't decode this. */ + session->last_error = OLM_BAD_CHAIN_INDEX; + return (size_t)-1; + } else { + /* otherwise, start from the initial megolm. Take a copy so that we + * don't overwrite the initial megolm */ + tmp_megolm = session->initial_ratchet; + megolm = &tmp_megolm; + } + + megolm_advance_to(megolm, decoded_results.chain_index); + + /* now try checking the mac, and decrypting */ + r = cipher->ops->decrypt( + cipher, + megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH, + message, raw_message_length, + decoded_results.ciphertext, decoded_results.ciphertext_length, + plaintext, max_plaintext_length + ); + + memset(&tmp_megolm, 0, sizeof(tmp_megolm)); + if (r == (size_t)-1) { + session->last_error = OLM_BAD_MESSAGE_MAC; + return r; + } + + return r; +} diff --git a/src/message.cpp b/src/message.cpp index df0c7bb..ec44262 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -363,3 +363,45 @@ void _olm_encode_group_message( pos = encode(pos, COUNTER_TAG, chain_index); pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length); } + +void _olm_decode_group_message( + const uint8_t *input, size_t input_length, + size_t mac_length, + struct _OlmDecodeGroupMessageResults *results +) { + std::uint8_t const * pos = input; + std::uint8_t const * end = input + input_length - mac_length; + std::uint8_t const * unknown = nullptr; + + results->session_id = nullptr; + results->session_id_length = 0; + bool has_chain_index = false; + results->chain_index = 0; + results->ciphertext = nullptr; + results->ciphertext_length = 0; + + if (pos == end) return; + if (input_length < mac_length) return; + results->version = *(pos++); + + while (pos != end) { + pos = decode( + pos, end, GROUP_SESSION_ID_TAG, + results->session_id, results->session_id_length + ); + pos = decode( + pos, end, COUNTER_TAG, + results->chain_index, has_chain_index + ); + pos = decode( + pos, end, CIPHERTEXT_TAG, + results->ciphertext, results->ciphertext_length + ); + if (unknown == pos) { + pos = skip_unknown(pos, end); + } + unknown = pos; + } + + results->has_chain_index = (int)has_chain_index; +} diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index b9fe1ef..5bbdc9d 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -12,6 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "olm/inbound_group_session.h" #include "olm/outbound_group_session.h" #include "unittest.hh" @@ -19,11 +20,10 @@ int main() { { - TestCase test_case("Pickle outbound group"); size_t size = olm_outbound_group_session_size(); - void *memory = alloca(size); + uint8_t memory[size]; OlmOutboundGroupSession *session = olm_outbound_group_session(memory); size_t pickle_length = olm_pickle_outbound_group_session_length(session); @@ -61,9 +61,9 @@ int main() { "0123456789ABDEF0123456789ABCDEF"; - + /* build the outbound session */ size_t size = olm_outbound_group_session_size(); - void *memory = alloca(size); + uint8_t memory[size]; OlmOutboundGroupSession *session = olm_outbound_group_session(memory); assert_equals((size_t)132, @@ -73,18 +73,48 @@ int main() { session, random_bytes, sizeof(random_bytes)); assert_equals((size_t)0, res); + assert_equals(0U, olm_outbound_group_session_message_index(session)); + size_t session_key_len = olm_outbound_group_session_key_length(session); + uint8_t session_key[session_key_len]; + olm_outbound_group_session_key(session, session_key, session_key_len); + + + /* encode the message */ uint8_t plaintext[] = "Message"; size_t plaintext_length = sizeof(plaintext) - 1; size_t msglen = olm_group_encrypt_message_length( session, plaintext_length); - uint8_t *msg = (uint8_t *)alloca(msglen); + uint8_t msg[msglen]; res = olm_group_encrypt(session, plaintext, plaintext_length, msg, msglen); assert_equals(msglen, res); + assert_equals(1U, olm_outbound_group_session_message_index(session)); - // TODO: decode the message + + /* build the inbound session */ + size = olm_inbound_group_session_size(); + uint8_t inbound_session_memory[size]; + OlmInboundGroupSession *inbound_session = + olm_inbound_group_session(inbound_session_memory); + + res = olm_init_inbound_group_session( + inbound_session, 0U, session_key, session_key_len); + assert_equals((size_t)0, res); + + /* decode the message */ + + /* olm_group_decrypt_max_plaintext_length destroys the input so we have to + copy it. */ + uint8_t msgcopy[msglen]; + memcpy(msgcopy, msg, msglen); + size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen); + uint8_t plaintext_buf[size]; + res = olm_group_decrypt(inbound_session, msg, msglen, + plaintext_buf, size); + assert_equals(plaintext_length, res); + assert_equals(plaintext, plaintext_buf, res); } } diff --git a/tests/test_message.cpp b/tests/test_message.cpp index e2385ea..5fec9e0 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -97,4 +97,26 @@ assert_equals(message2, output, 35); assert_equals(output+sizeof(expected)-1, ciphertext_ptr); } /* group message encode test */ +{ + TestCase test_case("Group message decode test"); + + struct _OlmDecodeGroupMessageResults results; + std::uint8_t message[] = + "\x03" + "\x2A\x09sessionid" + "\x10\xc8\x01" + "\x22\x0A" "ciphertext" + "hmacsha2"; + + const uint8_t expected_session_id[] = "sessionid"; + + _olm_decode_group_message(message, sizeof(message)-1, 8, &results); + assert_equals(std::uint8_t(3), results.version); + assert_equals(std::size_t(9), results.session_id_length); + assert_equals(expected_session_id, results.session_id, 9); + assert_equals(1, results.has_chain_index); + assert_equals(std::uint32_t(200), results.chain_index); + assert_equals(std::size_t(10), results.ciphertext_length); + assert_equals(ciphertext, results.ciphertext, 10); +} /* group message decode test */ }