diff --git a/fuzzers/fuzz_group_decrypt.cpp b/fuzzers/fuzz_group_decrypt.cpp index 03b7bc4..e32c0ae 100644 --- a/fuzzers/fuzz_group_decrypt.cpp +++ b/fuzzers/fuzz_group_decrypt.cpp @@ -2,6 +2,18 @@ #include "fuzzing.hh" +#ifndef __AFL_FUZZ_TESTCASE_LEN + ssize_t fuzz_len; + #define __AFL_FUZZ_TESTCASE_LEN fuzz_len + unsigned char fuzz_buf[1024000]; + #define __AFL_FUZZ_TESTCASE_BUF fuzz_buf + #define __AFL_FUZZ_INIT() void sync(void); + #define __AFL_LOOP(x) ((fuzz_len = read(0, fuzz_buf, sizeof(fuzz_buf))) > 0 ? 1 : 0) + #define __AFL_INIT() sync() +#endif + +__AFL_FUZZ_INIT(); + int main(int argc, const char *argv[]) { if (argc <= 2) { const char * message = "Usage: decrypt \n"; @@ -22,15 +34,6 @@ int main(int argc, const char *argv[]) { "Error reading session file", read_file(session_fd, &session_buffer) ); - int message_fd = STDIN_FILENO; - uint8_t * message_buffer; - ssize_t message_length = check_errno( - "Error reading message file", read_file(message_fd, &message_buffer) - ); - - uint8_t * tmp_buffer = (uint8_t *) malloc(message_length); - memcpy(tmp_buffer, message_buffer, message_length); - uint8_t session_memory[olm_inbound_group_session_size()]; OlmInboundGroupSession * session = olm_inbound_group_session(session_memory); check_error( @@ -42,32 +45,54 @@ int main(int argc, const char *argv[]) { ) ); - size_t max_length = check_error( - olm_inbound_group_session_last_error, - session, - "Error getting plaintext length", - olm_group_decrypt_max_plaintext_length( - session, tmp_buffer, message_length - ) - ); +#ifdef __AFL_HAVE_MANUAL_CONTROL + __AFL_INIT(); +#endif - uint8_t plaintext[max_length]; + size_t test_case_buf_len = 1024; + uint8_t * message_buffer = (uint8_t *) malloc(test_case_buf_len); + uint8_t * tmp_buffer = (uint8_t *) malloc(test_case_buf_len); - uint32_t ratchet_index; + while (__AFL_LOOP(10000)) { + size_t message_length = __AFL_FUZZ_TESTCASE_LEN; - size_t length = check_error( - olm_inbound_group_session_last_error, - session, - "Error decrypting message", - olm_group_decrypt( + if (message_length > test_case_buf_len) { + message_buffer = (uint8_t *)realloc(message_buffer, message_length); + tmp_buffer = (uint8_t *)realloc(tmp_buffer, message_length); + + if (!message_buffer || !tmp_buffer) return 1; + } + + memcpy(message_buffer, __AFL_FUZZ_TESTCASE_BUF, message_length); + memcpy(tmp_buffer, message_buffer, message_length); + + size_t max_length = check_error( + olm_inbound_group_session_last_error, session, - message_buffer, message_length, - plaintext, max_length, &ratchet_index - ) - ); + "Error getting plaintext length", + olm_group_decrypt_max_plaintext_length( + session, tmp_buffer, message_length + ) + ); - (void)write(STDOUT_FILENO, plaintext, length); - (void)write(STDOUT_FILENO, "\n", 1); + uint8_t plaintext[max_length]; + + uint32_t ratchet_index; + + size_t length = check_error( + olm_inbound_group_session_last_error, + session, + "Error decrypting message", + olm_group_decrypt( + session, + message_buffer, message_length, + plaintext, max_length, &ratchet_index + ) + ); + + (void)write(STDOUT_FILENO, plaintext, length); + (void)write(STDOUT_FILENO, "\n", 1); + } free(session_buffer); free(message_buffer);