diff --git a/src/megolm.c b/src/megolm.c index 6d8af08..a969b36 100644 --- a/src/megolm.c +++ b/src/megolm.c @@ -108,8 +108,12 @@ void megolm_advance_to(Megolm *megolm, uint32_t advance_to) { uint32_t mask = (~(uint32_t)0) << shift; int k; - /* how many times to we need to rehash this part? */ - int steps = (advance_to >> shift) - (megolm->counter >> shift); + /* how many times do we need to rehash this part? + * + * '& 0xff' ensures we handle integer wraparound correctly + */ + unsigned int steps = + ((advance_to >> shift) - (megolm->counter >> shift)) & 0xff; if (steps == 0) { continue; diff --git a/tests/test_megolm.cpp b/tests/test_megolm.cpp index 871de36..bf53346 100644 --- a/tests/test_megolm.cpp +++ b/tests/test_megolm.cpp @@ -82,4 +82,20 @@ std::uint8_t random_bytes[] = assert_equals(expected3, megolm_get_data(&mr), MEGOLM_RATCHET_LENGTH); } +{ + TestCase test_case("Megolm::advance wraparound"); + + Megolm mr1, mr2; + + megolm_init(&mr1, random_bytes, 0xffffffffUL); + megolm_advance_to(&mr1, 0x1000000); + assert_equals(0x1000000U, mr1.counter); + + megolm_init(&mr2, random_bytes, 0); + megolm_advance_to(&mr2, 0x2000000); + assert_equals(0x2000000U, mr2.counter); + + assert_equals(megolm_get_data(&mr2), megolm_get_data(&mr1), MEGOLM_RATCHET_LENGTH); +} + }