diff --git a/include/olm/base64.hh b/include/olm/base64.hh index dfdccd0..c0cbc0a 100644 --- a/include/olm/base64.hh +++ b/include/olm/base64.hh @@ -51,8 +51,12 @@ std::size_t decode_base64_length( * Writes decode_base64_length(input_length) bytes to the output buffer. * The output can overlap with the first three quarters of the input buffer. * That is, the input pointers and output pointer may be the same. + * + * Returns the number of bytes of raw data the base64 input decoded to. If the + * input length supplied is not a valid length for base64, returns + * std::size_t(-1) and does not decode. */ -std::uint8_t const * decode_base64( +std::size_t decode_base64( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ); diff --git a/src/base64.cpp b/src/base64.cpp index bbfb210..0e195fb 100644 --- a/src/base64.cpp +++ b/src/base64.cpp @@ -12,6 +12,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include + #include "olm/base64.h" #include "olm/base64.hh" @@ -101,12 +103,19 @@ std::size_t olm::decode_base64_length( } -std::uint8_t const * olm::decode_base64( +std::size_t olm::decode_base64( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ) { + size_t raw_length = olm::decode_base64_length(input_length); + + if (raw_length == std::size_t(-1)) { + return std::size_t(-1); + } + std::uint8_t const * end = input + (input_length / 4) * 4; std::uint8_t const * pos = input; + while (pos != end) { unsigned value = DECODE_BASE64[pos[0] & 0x7F]; value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F]; @@ -118,8 +127,19 @@ std::uint8_t const * olm::decode_base64( value >>= 8; output[0] = value; output += 3; } + unsigned remainder = input + input_length - pos; if (remainder) { + /* A base64 payload with a single byte remainder cannot occur because + * a single base64 character only encodes 6 bits, which is less than + * a full byte. Therefore, a minimum of two base64 characters are + * required to construct a single output byte and payloads with + * a remainder of 1 are illegal. + * + * Should never be the case due to length check above. + */ + assert(remainder != 1); + unsigned value = DECODE_BASE64[pos[0] & 0x7F]; value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F]; if (remainder == 3) { @@ -132,7 +152,8 @@ std::uint8_t const * olm::decode_base64( } output[0] = value; } - return input + input_length; + + return raw_length; } @@ -162,6 +183,5 @@ size_t _olm_decode_base64( uint8_t const * input, size_t input_length, uint8_t * output ) { - olm::decode_base64(input, input_length, output); - return olm::decode_base64_length(input_length); + return olm::decode_base64(input, input_length, output); } diff --git a/tests/test_base64.cpp b/tests/test_base64.cpp index 6f80acf..9e49bef 100644 --- a/tests/test_base64.cpp +++ b/tests/test_base64.cpp @@ -66,5 +66,24 @@ assert_equals(std::size_t(11), output_length); assert_equals(expected_output, output, output_length); } +{ +TestCase test_case("Decoding base64 of invalid length fails with -1"); +#include +std::uint8_t input[] = "SGVsbG8gV29ybGQab"; +std::size_t input_length = sizeof(input) - 1; + +/* We use a longer but valid input length here so that we don't get back -1. + * Nothing will be written to the output buffer anyway because the input is + * invalid. */ +std::size_t buf_length = olm::decode_base64_length(input_length + 1); +std::uint8_t output[buf_length]; +std::uint8_t expected_output[buf_length]; +memset(output, 0, buf_length); +memset(expected_output, 0, buf_length); + +std::size_t output_length = ::_olm_decode_base64(input, input_length, output); +assert_equals(std::size_t(-1), output_length); +assert_equals(0, memcmp(output, expected_output, buf_length)); +} }