diff --git a/src/megolm.c b/src/megolm.c index a969b36..affd3cb 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -116,7 +116,11 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { ((advance_to >> shift) - (megolm->counter >> shift)) & 0xff; if (steps == 0) { - continue; + if (advance_to < megolm->counter) { + steps = 0x100; + } else { + continue; + } } /* for all but the last step, we can just bump R(j) without regard diff --git a/tests/test_megolm.cpp b/tests/test_megolm.cpp index bf53346..3048fa3 100644 --- a/tests/test_megolm.cpp +++ b/tests/test_megolm.cpp @@ -98,4 +98,37 @@ std::uint8_t random_bytes[] = assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); } +{ + TestCase test_case("Megolm::advance overflow by one"); + + Megolm mr1, mr2; + + megolm_init(&mr1, random_bytes, 0xffffffffUL); + megolm_advance_to(&mr1, 0x0); + assert_equals(0x0U, mr1.counter); + + megolm_init(&mr2, random_bytes, 0xffffffffUL); + megolm_advance(&mr2); + assert_equals(0x0U, mr2.counter); + + assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); +} + +{ + TestCase test_case("Megolm::advance overflow"); + + Megolm mr1, mr2; + + megolm_init(&mr1, random_bytes, 0x1UL); + megolm_advance_to(&mr1, 0x80000000UL); + megolm_advance_to(&mr1, 0x0); + assert_equals(0x0U, mr1.counter); + + megolm_init(&mr2, random_bytes, 0x1UL); + megolm_advance_to(&mr2, 0x0UL); + assert_equals(0x0U, mr2.counter); + + assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); +} + }