diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h index 59146c2..f8a0bc3 100644 --- a/include/olm/inbound_group_session.h +++ b/include/olm/inbound_group_session.h @@ -140,7 +140,8 @@ size_t olm_group_decrypt( uint8_t * message, size_t message_length, /* output */ - uint8_t * plaintext, size_t max_plaintext_length + uint8_t * plaintext, size_t max_plaintext_length, + uint32_t * message_index ); diff --git a/javascript/demo/group_demo.js b/javascript/demo/group_demo.js index 1b8f7ab..42a3d84 100644 --- a/javascript/demo/group_demo.js +++ b/javascript/demo/group_demo.js @@ -403,8 +403,8 @@ DemoUser.prototype.decryptGroup = function(jsonpacket, callback) { throw new Error("Unknown session id " + session_id); } - var plaintext = session.decrypt(packet.body); - done(plaintext); + var result = session.decrypt(packet.body); + done(result.plaintext); }, callback); }; diff --git a/javascript/olm_inbound_group_session.js b/javascript/olm_inbound_group_session.js index 6058233..1b7fcfe 100644 --- a/javascript/olm_inbound_group_session.js +++ b/javascript/olm_inbound_group_session.js @@ -73,10 +73,12 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function( // So we copy the array to a new buffer var message_buffer = stack(message_array); var plaintext_buffer = stack(max_plaintext_length + NULL_BYTE_PADDING_LENGTH); + var message_index = stack(4); var plaintext_length = inbound_group_session_method(Module["_olm_group_decrypt"])( this.ptr, message_buffer, message_array.length, - plaintext_buffer, max_plaintext_length + plaintext_buffer, max_plaintext_length, + message_index ); // Pointer_stringify requires a null-terminated argument (the optional @@ -86,7 +88,10 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function( 0, "i8" ); - return Pointer_stringify(plaintext_buffer); + return { + "plaintext": Pointer_stringify(plaintext_buffer), + "message_index": Module['getValue'](message_index, "i32") + } }); InboundGroupSession.prototype['session_id'] = restore_stack(function() { diff --git a/python/olm/__main__.py b/python/olm/__main__.py index cf9158d..eb76301 100755 --- a/python/olm/__main__.py +++ b/python/olm/__main__.py @@ -328,7 +328,7 @@ def do_group_decrypt(args): session = InboundGroupSession() session.unpickle(args.key, read_base64_file(args.session_file)) message = args.message_file.read() - plaintext = session.decrypt(message) + plaintext, message_index = session.decrypt(message) with open(args.session_file, "wb") as f: f.write(session.pickle(args.key)) args.plaintext_file.write(plaintext) diff --git a/python/olm/inbound_group_session.py b/python/olm/inbound_group_session.py index d5547fd..27a569c 100644 --- a/python/olm/inbound_group_session.py +++ b/python/olm/inbound_group_session.py @@ -43,6 +43,7 @@ inbound_group_session_function( lib.olm_group_decrypt, c_void_p, c_size_t, # message c_void_p, c_size_t, # plaintext + POINTER(c_uint32), # message_index ) inbound_group_session_function(lib.olm_inbound_group_session_id_length) @@ -82,11 +83,14 @@ class InboundGroupSession(object): ) plaintext_buffer = create_string_buffer(max_plaintext_length) message_buffer = create_string_buffer(message) + + message_index = c_uint32() plaintext_length = lib.olm_group_decrypt( self.ptr, message_buffer, len(message), - plaintext_buffer, max_plaintext_length + plaintext_buffer, max_plaintext_length, + byref(message_index) ) - return plaintext_buffer.raw[:plaintext_length] + return plaintext_buffer.raw[:plaintext_length], message_index def session_id(self): id_length = lib.olm_inbound_group_session_id_length(self.ptr) diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c index bf00008..ed313a0 100644 --- a/src/inbound_group_session.c +++ b/src/inbound_group_session.c @@ -263,7 +263,8 @@ size_t olm_group_decrypt_max_plaintext_length( static size_t _decrypt( OlmInboundGroupSession *session, uint8_t * message, size_t message_length, - uint8_t * plaintext, size_t max_plaintext_length + uint8_t * plaintext, size_t max_plaintext_length, + uint32_t * message_index ) { struct _OlmDecodeGroupMessageResults decoded_results; size_t max_length, r; @@ -286,6 +287,8 @@ static size_t _decrypt( return (size_t)-1; } + *message_index = decoded_results.message_index; + /* verify the signature. We could do this before decoding the message, but * we allow for the possibility of future protocol versions which use a * different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION" @@ -349,7 +352,8 @@ static size_t _decrypt( size_t olm_group_decrypt( OlmInboundGroupSession *session, uint8_t * message, size_t message_length, - uint8_t * plaintext, size_t max_plaintext_length + uint8_t * plaintext, size_t max_plaintext_length, + uint32_t * message_index ) { size_t raw_message_length; @@ -361,7 +365,8 @@ size_t olm_group_decrypt( return _decrypt( session, message, raw_message_length, - plaintext, max_plaintext_length + plaintext, max_plaintext_length, + message_index ); } diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp index 9930927..b15875c 100644 --- a/tests/test_group_session.cpp +++ b/tests/test_group_session.cpp @@ -161,8 +161,9 @@ int main() { memcpy(msgcopy, msg, msglen); size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen); uint8_t plaintext_buf[size]; + uint32_t message_index; res = olm_group_decrypt(inbound_session, msg, msglen, - plaintext_buf, size); + plaintext_buf, size, &message_index); assert_equals(plaintext_length, res); assert_equals(plaintext, plaintext_buf, res); } @@ -208,8 +209,9 @@ int main() { memcpy(msgcopy, message, msglen); uint8_t plaintext_buf[size]; + uint32_t message_index; res = olm_group_decrypt( - inbound_session, msgcopy, msglen, plaintext_buf, size + inbound_session, msgcopy, msglen, plaintext_buf, size, &message_index ); assert_equals(plaintext_length, res); assert_equals(plaintext, plaintext_buf, res); @@ -227,7 +229,7 @@ int main() { memcpy(msgcopy, message, msglen); res = olm_group_decrypt( inbound_session, msgcopy, msglen, - plaintext_buf, size + plaintext_buf, size, &message_index ); assert_equals((size_t)-1, res); assert_equals(