diff --git a/include/olm/error.h b/include/olm/error.h index 1c44de8..9d44a94 100644 --- a/include/olm/error.h +++ b/include/olm/error.h @@ -46,6 +46,11 @@ enum OlmErrorCode { */ OLM_BAD_LEGACY_ACCOUNT_PICKLE = 13, + /** + * Received message had a bad signature + */ + OLM_BAD_SIGNATURE = 14, + /* remember to update the list of string constants in error.c when updating * this list. */ }; diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h index e24f377..49992b2 100644 --- a/include/olm/inbound_group_session.h +++ b/include/olm/inbound_group_session.h @@ -97,7 +97,7 @@ size_t olm_init_inbound_group_session( OlmInboundGroupSession *session, uint32_t message_index, - /* base64-encoded key */ + /* base64-encoded keys */ uint8_t const * session_key, size_t session_key_length ); diff --git a/include/olm/message.h b/include/olm/message.h index 5eb504d..61012c9 100644 --- a/include/olm/message.h +++ b/include/olm/message.h @@ -37,7 +37,8 @@ extern "C" { size_t _olm_encode_group_message_length( uint32_t chain_index, size_t ciphertext_length, - size_t mac_length + size_t mac_length, + size_t signature_length ); /** @@ -49,7 +50,8 @@ size_t _olm_encode_group_message_length( * output: where to write the output. Should be at least * olm_encode_group_message_length() bytes long. * ciphertext_ptr: returns the address that the ciphertext - * should be written to, followed by the MAC. + * should be written to, followed by the MAC and the + * signature. * * Returns the size of the message, up to the MAC. */ @@ -76,7 +78,7 @@ struct _OlmDecodeGroupMessageResults { */ void _olm_decode_group_message( const uint8_t *input, size_t input_length, - size_t mac_length, + size_t mac_length, size_t signature_length, /* output structure: updated with results */ struct _OlmDecodeGroupMessageResults *results diff --git a/src/error.c b/src/error.c index b742197..f541a93 100644 --- a/src/error.c +++ b/src/error.c @@ -30,6 +30,7 @@ static const char * ERRORS[] = { "BAD_SESSION_KEY", "UNKNOWN_MESSAGE_INDEX", "BAD_LEGACY_ACCOUNT_PICKLE", + "BAD_SIGNATURE", }; const char * _olm_error_to_string(enum OlmErrorCode error) diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index ce26033..11e3dbe 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -19,6 +19,7 @@ #include "olm/base64.h" #include "olm/cipher.h" +#include "olm/crypto.h" #include "olm/error.h" #include "olm/megolm.h" #include "olm/memory.h" @@ -29,6 +30,7 @@ #define OLM_PROTOCOL_VERSION 3 #define PICKLE_VERSION 1 +#define SESSION_KEY_VERSION 1 struct OlmInboundGroupSession { /** our earliest known ratchet value */ @@ -37,6 +39,9 @@ struct OlmInboundGroupSession { /** The most recent ratchet value */ Megolm latest_ratchet; + /** The ed25519 signing key */ + struct _olm_ed25519_public_key signing_key; + enum OlmErrorCode last_error; }; @@ -65,30 +70,56 @@ size_t olm_clear_inbound_group_session( return sizeof(OlmInboundGroupSession); } +#define SESSION_KEY_RAW_LENGTH \ + (1 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH) + +/** init the session keys from the un-base64-ed session keys */ +static size_t _init_group_session_keys( + OlmInboundGroupSession *session, + uint32_t message_index, + const uint8_t *key_buf +) { + const uint8_t *ptr = key_buf; + size_t version = *ptr++; + + if (version != SESSION_KEY_VERSION) { + session->last_error = OLM_BAD_SESSION_KEY; + return (size_t)-1; + } + + megolm_init(&session->initial_ratchet, ptr, message_index); + megolm_init(&session->latest_ratchet, ptr, message_index); + ptr += MEGOLM_RATCHET_LENGTH; + memcpy( + session->signing_key.public_key, ptr, ED25519_PUBLIC_KEY_LENGTH + ); + ptr += ED25519_PUBLIC_KEY_LENGTH; + return 0; +} + size_t olm_init_inbound_group_session( OlmInboundGroupSession *session, uint32_t message_index, const uint8_t * session_key, size_t session_key_length ) { - uint8_t key_buf[MEGOLM_RATCHET_LENGTH]; + uint8_t key_buf[SESSION_KEY_RAW_LENGTH]; size_t raw_length = _olm_decode_base64_length(session_key_length); + size_t result; if (raw_length == (size_t)-1) { session->last_error = OLM_INVALID_BASE64; return (size_t)-1; } - if (raw_length != MEGOLM_RATCHET_LENGTH) { + if (raw_length != SESSION_KEY_RAW_LENGTH) { session->last_error = OLM_BAD_SESSION_KEY; return (size_t)-1; } _olm_decode_base64(session_key, session_key_length, key_buf); - megolm_init(&session->initial_ratchet, key_buf, message_index); - megolm_init(&session->latest_ratchet, key_buf, message_index); - _olm_unset(key_buf, MEGOLM_RATCHET_LENGTH); - - return 0; + result = _init_group_session_keys(session, message_index, key_buf); + _olm_unset(key_buf, SESSION_KEY_RAW_LENGTH); + return result; } static size_t raw_pickle_length( @@ -98,6 +129,7 @@ static size_t raw_pickle_length( length += _olm_pickle_uint32_length(PICKLE_VERSION); length += megolm_pickle_length(&session->initial_ratchet); length += megolm_pickle_length(&session->latest_ratchet); + length += _olm_pickle_ed25519_public_key_length(&session->signing_key); return length; } @@ -124,6 +156,7 @@ size_t olm_pickle_inbound_group_session( pos = _olm_pickle_uint32(pos, PICKLE_VERSION); pos = megolm_pickle(&session->initial_ratchet, pos); pos = megolm_pickle(&session->latest_ratchet, pos); + pos = _olm_pickle_ed25519_public_key(pos, &session->signing_key); return _olm_enc_output(key, key_length, pickled, raw_length); } @@ -153,6 +186,7 @@ size_t olm_unpickle_inbound_group_session( } pos = megolm_unpickle(&session->initial_ratchet, pos, end); pos = megolm_unpickle(&session->latest_ratchet, pos, end); + pos = _olm_unpickle_ed25519_public_key(pos, end, &session->signing_key); if (end != pos) { /* We had the wrong number of bytes in the input. */ @@ -175,6 +209,7 @@ static size_t _decrypt_max_plaintext_length( _olm_decode_group_message( message, message_length, megolm_cipher->ops->mac_length(megolm_cipher), + ED25519_SIGNATURE_LENGTH, &decoded_results); if (decoded_results.version != OLM_PROTOCOL_VERSION) { @@ -224,6 +259,7 @@ static size_t _decrypt( _olm_decode_group_message( message, message_length, megolm_cipher->ops->mac_length(megolm_cipher), + ED25519_SIGNATURE_LENGTH, &decoded_results); if (decoded_results.version != OLM_PROTOCOL_VERSION) { @@ -231,11 +267,28 @@ static size_t _decrypt( return (size_t)-1; } - if (!decoded_results.has_message_index || !decoded_results.ciphertext ) { + if (!decoded_results.has_message_index || !decoded_results.ciphertext) { session->last_error = OLM_BAD_MESSAGE_FORMAT; return (size_t)-1; } + /* verify the signature. We could do this before decoding the message, but + * we allow for the possibility of future protocol versions which use a + * different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION" + * than "BAD_SIGNATURE" in this case. + */ + message_length -= ED25519_SIGNATURE_LENGTH; + r = _olm_crypto_ed25519_verify( + &session->signing_key, + message, message_length, + message + message_length + ); + if (!r) { + session->last_error = OLM_BAD_SIGNATURE; + return (size_t)-1; + } + + max_length = megolm_cipher->ops->decrypt_max_plaintext_length( megolm_cipher, decoded_results.ciphertext_length diff --git a/src/message.cpp b/src/message.cpp index ad26cb9..05fe2c7 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -334,12 +334,14 @@ static const std::uint8_t GROUP_CIPHERTEXT_TAG = 022; size_t _olm_encode_group_message_length( uint32_t message_index, size_t ciphertext_length, - size_t mac_length + size_t mac_length, + size_t signature_length ) { size_t length = VERSION_LENGTH; length += 1 + varint_length(message_index); length += 1 + varstring_length(ciphertext_length); length += mac_length; + length += signature_length; return length; } @@ -361,11 +363,12 @@ size_t _olm_encode_group_message( void _olm_decode_group_message( const uint8_t *input, size_t input_length, - size_t mac_length, + size_t mac_length, size_t signature_length, struct _OlmDecodeGroupMessageResults *results ) { std::uint8_t const * pos = input; - std::uint8_t const * end = input + input_length - mac_length; + std::size_t trailer_length = mac_length + signature_length; + std::uint8_t const * end = input + input_length - trailer_length; std::uint8_t const * unknown = nullptr; bool has_message_index = false; @@ -373,8 +376,7 @@ void _olm_decode_group_message( results->ciphertext = nullptr; results->ciphertext_length = 0; - if (pos == end) return; - if (input_length < mac_length) return; + if (input_length < trailer_length) return; results->version = *(pos++); while (pos != end) { diff --git a/src/outbound_group_session.c b/src/outbound_group_session.c index 2a6c220..a605fa6 100644 --- a/src/outbound_group_session.c +++ b/src/outbound_group_session.c @@ -20,6 +20,7 @@ #include "olm/base64.h" #include "olm/cipher.h" +#include "olm/crypto.h" #include "olm/error.h" #include "olm/megolm.h" #include "olm/memory.h" @@ -31,11 +32,15 @@ #define SESSION_ID_RANDOM_BYTES 4 #define GROUP_SESSION_ID_LENGTH (sizeof(struct timeval) + SESSION_ID_RANDOM_BYTES) #define PICKLE_VERSION 1 +#define SESSION_KEY_VERSION 1 struct OlmOutboundGroupSession { /** the Megolm ratchet providing the encryption keys */ Megolm ratchet; + /** The ed25519 keypair used for signing the messages */ + struct _olm_ed25519_key_pair signing_key; + /** unique identifier for this session */ uint8_t session_id[GROUP_SESSION_ID_LENGTH]; @@ -74,6 +79,7 @@ static size_t raw_pickle_length( size_t length = 0; length += _olm_pickle_uint32_length(PICKLE_VERSION); length += megolm_pickle_length(&(session->ratchet)); + length += _olm_pickle_ed25519_key_pair_length(&(session->signing_key)); length += _olm_pickle_bytes_length(session->session_id, GROUP_SESSION_ID_LENGTH); return length; @@ -101,6 +107,7 @@ size_t olm_pickle_outbound_group_session( pos = _olm_enc_output_pos(pickled, raw_length); pos = _olm_pickle_uint32(pos, PICKLE_VERSION); pos = megolm_pickle(&(session->ratchet), pos); + pos = _olm_pickle_ed25519_key_pair(pos, &(session->signing_key)); pos = _olm_pickle_bytes(pos, session->session_id, GROUP_SESSION_ID_LENGTH); return _olm_enc_output(key, key_length, pickled, raw_length); @@ -130,6 +137,7 @@ size_t olm_unpickle_outbound_group_session( return (size_t)-1; } pos = megolm_unpickle(&(session->ratchet), pos, end); + pos = _olm_unpickle_ed25519_key_pair(pos, end, &(session->signing_key)); pos = _olm_unpickle_bytes(pos, end, session->session_id, GROUP_SESSION_ID_LENGTH); if (end != pos) { @@ -148,7 +156,9 @@ size_t olm_init_outbound_group_session_random_length( /* we need data to initialize the megolm ratchet, plus some more for the * session id. */ - return MEGOLM_RATCHET_LENGTH + SESSION_ID_RANDOM_BYTES; + return MEGOLM_RATCHET_LENGTH + + ED25519_RANDOM_LENGTH + + SESSION_ID_RANDOM_BYTES; } size_t olm_init_outbound_group_session( @@ -164,6 +174,9 @@ size_t olm_init_outbound_group_session( megolm_init(&(session->ratchet), random, 0); random += MEGOLM_RATCHET_LENGTH; + _olm_crypto_ed25519_generate_key(random, &(session->signing_key)); + random += ED25519_RANDOM_LENGTH; + /* initialise the session id. This just has to be unique. We use the * current time plus some random data. */ @@ -188,7 +201,8 @@ static size_t raw_message_length( return _olm_encode_group_message_length( session->ratchet.counter, - ciphertext_length, mac_length); + ciphertext_length, mac_length, ED25519_SIGNATURE_LENGTH + ); } size_t olm_group_encrypt_message_length( @@ -241,6 +255,13 @@ static size_t _encrypt( megolm_advance(&(session->ratchet)); + /* sign the whole thing with the ed25519 key. */ + _olm_crypto_ed25519_sign( + &(session->signing_key), + buffer, message_length, + buffer + message_length + ); + return result; } @@ -302,23 +323,40 @@ uint32_t olm_outbound_group_session_message_index( return session->ratchet.counter; } +#define SESSION_KEY_RAW_LENGTH \ + (1 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH) + size_t olm_outbound_group_session_key_length( const OlmOutboundGroupSession *session ) { - return _olm_encode_base64_length(MEGOLM_RATCHET_LENGTH); + return _olm_encode_base64_length(SESSION_KEY_RAW_LENGTH); } size_t olm_outbound_group_session_key( OlmOutboundGroupSession *session, uint8_t * key, size_t key_length ) { - if (key_length < olm_outbound_group_session_key_length(session)) { + uint8_t *raw; + uint8_t *ptr; + size_t encoded_length = olm_outbound_group_session_key_length(session); + + if (key_length < encoded_length) { session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; return (size_t)-1; } - return _olm_encode_base64( - megolm_get_data(&session->ratchet), - MEGOLM_RATCHET_LENGTH, key + /* put the raw data at the end of the output buffer. */ + raw = ptr = key + encoded_length - SESSION_KEY_RAW_LENGTH; + *ptr++ = SESSION_KEY_VERSION; + + memcpy(ptr, megolm_get_data(&session->ratchet), MEGOLM_RATCHET_LENGTH); + ptr += MEGOLM_RATCHET_LENGTH; + + memcpy( + ptr, session->signing_key.public_key.public_key, + ED25519_PUBLIC_KEY_LENGTH ); + ptr += ED25519_PUBLIC_KEY_LENGTH; + + return _olm_encode_base64(raw, SESSION_KEY_RAW_LENGTH, key); } diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index 4a82154..7ac91c3 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -80,7 +80,6 @@ int main() { assert_equals(pickle1, pickle2, pickle_length); } - { TestCase test_case("Group message send/receive"); @@ -89,6 +88,7 @@ int main() { "0123456789ABDEF0123456789ABCDEF" "0123456789ABDEF0123456789ABCDEF" "0123456789ABDEF0123456789ABCDEF" + "0123456789ABDEF0123456789ABCDEF" "0123456789ABDEF0123456789ABCDEF"; @@ -97,7 +97,7 @@ int main() { uint8_t memory[size]; OlmOutboundGroupSession *session = olm_outbound_group_session(memory); - assert_equals((size_t)132, + assert_equals((size_t)164, olm_init_outbound_group_session_random_length(session)); size_t res = olm_init_outbound_group_session( @@ -109,7 +109,6 @@ int main() { uint8_t session_key[session_key_len]; olm_outbound_group_session_key(session, session_key, session_key_len); - /* encode the message */ uint8_t plaintext[] = "Message"; size_t plaintext_length = sizeof(plaintext) - 1; @@ -148,4 +147,73 @@ int main() { assert_equals(plaintext, plaintext_buf, res); } +{ + TestCase test_case("Invalid signature group message"); + + uint8_t plaintext[] = "Message"; + size_t plaintext_length = sizeof(plaintext) - 1; + + uint8_t session_key[] = + "ATAxMjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzNDU2Nzg5QUJERUYw" + "MTIzNDU2Nzg5QUJDREVGMDEyMzQ1Njc4OUFCREVGMDEyMzQ1Njc4OUFCQ0RFRjAx" + "MjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzDRt2DUEOrg/H+yUGjDTq" + "ryf8H1YF/BZjI04HwOVSZcY"; + + uint8_t message[] = + "AwgAEhAcbh6UpbByoyZxufQ+h2B+8XHMjhR69G8F4+qjMaFlnIXusJZX3r8LnROR" + "G9T3DXFdbVuvIWrLyRfm4i8QRbe8VPwGRFG57B1CtmxanuP8bHtnnYqlwPsD"; + size_t msglen = sizeof(message)-1; + + /* build the inbound session */ + size_t size = olm_inbound_group_session_size(); + uint8_t inbound_session_memory[size]; + OlmInboundGroupSession *inbound_session = + olm_inbound_group_session(inbound_session_memory); + + size_t res = olm_init_inbound_group_session( + inbound_session, 0U, session_key, sizeof(session_key)-1 + ); + assert_equals((size_t)0, res); + + /* decode the message */ + + /* olm_group_decrypt_max_plaintext_length destroys the input so we have to + copy it. */ + uint8_t msgcopy[msglen]; + memcpy(msgcopy, message, msglen); + size = olm_group_decrypt_max_plaintext_length( + inbound_session, msgcopy, msglen + ); + + memcpy(msgcopy, message, msglen); + uint8_t plaintext_buf[size]; + res = olm_group_decrypt( + inbound_session, msgcopy, msglen, plaintext_buf, size + ); + assert_equals(plaintext_length, res); + assert_equals(plaintext, plaintext_buf, res); + + /* now twiddle the signature */ + message[msglen-1] = 'E'; + memcpy(msgcopy, message, msglen); + assert_equals( + size, + olm_group_decrypt_max_plaintext_length( + inbound_session, msgcopy, msglen + ) + ); + + memcpy(msgcopy, message, msglen); + res = olm_group_decrypt( + inbound_session, msgcopy, msglen, + plaintext_buf, size + ); + assert_equals((size_t)-1, res); + assert_equals( + std::string("BAD_SIGNATURE"), + std::string(olm_inbound_group_session_last_error(inbound_session)) + ); +} + + } diff --git a/tests/test_message.cpp b/tests/test_message.cpp index 06b36dc..25693f5 100644 --- a/tests/test_message.cpp +++ b/tests/test_message.cpp @@ -67,8 +67,8 @@ assert_equals(message2, output, 35); TestCase test_case("Group message encode test"); - size_t length = _olm_encode_group_message_length(200, 10, 8); - size_t expected_length = 1 + (1+2) + (2+10) + 8; + size_t length = _olm_encode_group_message_length(200, 10, 8, 64); + size_t expected_length = 1 + (1+2) + (2+10) + 8 + 64; assert_equals(expected_length, length); uint8_t output[50]; @@ -99,9 +99,10 @@ assert_equals(message2, output, 35); "\x03" "\x08\xC8\x01" "\x12\x0A" "ciphertext" - "hmacsha2"; + "hmacsha2" + "ed25519signature"; - _olm_decode_group_message(message, sizeof(message)-1, 8, &results); + _olm_decode_group_message(message, sizeof(message)-1, 8, 16, &results); assert_equals(std::uint8_t(3), results.version); assert_equals(1, results.has_message_index); assert_equals(std::uint32_t(200), results.message_index);