diff --git a/include/olm/cipher.h b/include/olm/cipher.h index 5f7185c..b26f8ba 100644 --- a/include/olm/cipher.h +++ b/include/olm/cipher.h @@ -102,28 +102,33 @@ struct _olm_cipher { struct _olm_cipher_aes_sha_256 { struct _olm_cipher base_cipher; + /** context string for the HKDF used for deriving the AES256 key, HMAC key, + * and AES IV, from the key material passed to encrypt/decrypt. + */ uint8_t const * kdf_info; + + /** length of context string kdf_info */ size_t kdf_info_length; }; +extern const struct _olm_cipher_ops _olm_cipher_aes_sha_256_ops; /** - * initialises a cipher type which uses AES256 for encryption and SHA256 for - * authentication. + * get an initializer for an instance of struct _olm_cipher_aes_sha_256. * - * cipher: structure to be initialised + * To use it, declare: * - * kdf_info: context string for the HKDF used for deriving the AES256 key, HMAC - * key, and AES IV, from the key material passed to encrypt/decrypt. Note that - * this is NOT copied so must have a lifetime at least as long as the cipher - * instance. - * - * kdf_info_length: length of context string kdf_info + * struct _olm_cipher_aes_sha_256 MY_CIPHER = + * OLM_CIPHER_INIT_AES_SHA_256("MY_KDF"); + * struct _olm_cipher *cipher = OLM_CIPHER_BASE(&MY_CIPHER); */ -struct _olm_cipher *_olm_cipher_aes_sha_256_init( - struct _olm_cipher_aes_sha_256 *cipher, - uint8_t const * kdf_info, - size_t kdf_info_length); +#define OLM_CIPHER_INIT_AES_SHA_256(KDF_INFO) { \ + .base_cipher = { &_olm_cipher_aes_sha_256_ops },\ + .kdf_info = (uint8_t *)(KDF_INFO), \ + .kdf_info_length = sizeof(KDF_INFO) - 1 \ +} +#define OLM_CIPHER_BASE(CIPHER) \ + (&((CIPHER)->base_cipher)) #ifdef __cplusplus diff --git a/src/cipher.cpp b/src/cipher.cpp index 7830f6c..8c3de92 100644 --- a/src/cipher.cpp +++ b/src/cipher.cpp @@ -130,25 +130,12 @@ size_t aes_sha_256_cipher_decrypt( return plaintext_length; } +} // namespace -const _olm_cipher_ops aes_sha_256_cipher_ops = { +const struct _olm_cipher_ops _olm_cipher_aes_sha_256_ops = { aes_sha_256_cipher_mac_length, aes_sha_256_cipher_encrypt_ciphertext_length, aes_sha_256_cipher_encrypt, aes_sha_256_cipher_decrypt_max_plaintext_length, aes_sha_256_cipher_decrypt, }; - -} // namespace - - -_olm_cipher *_olm_cipher_aes_sha_256_init( - struct _olm_cipher_aes_sha_256 *cipher, - uint8_t const * kdf_info, - size_t kdf_info_length -) { - cipher->base_cipher.ops = &aes_sha_256_cipher_ops; - cipher->kdf_info = kdf_info; - cipher->kdf_info_length = kdf_info_length; - return &(cipher->base_cipher); -} diff --git a/src/olm.cpp b/src/olm.cpp index b34a1dc..fcd033a 100644 --- a/src/olm.cpp +++ b/src/olm.cpp @@ -57,24 +57,13 @@ static std::uint8_t const * from_c(void const * bytes) { return reinterpret_cast(bytes); } -static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle"; - -const _olm_cipher *get_pickle_cipher() { - static _olm_cipher *cipher = NULL; - static _olm_cipher_aes_sha_256 PICKLE_CIPHER; - if (!cipher) { - cipher = _olm_cipher_aes_sha_256_init( - &PICKLE_CIPHER, - CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1 - ); - } - return cipher; -} +static const struct _olm_cipher_aes_sha_256 PICKLE_CIPHER = + OLM_CIPHER_INIT_AES_SHA_256("Pickle"); std::size_t enc_output_length( size_t raw_length ) { - auto *cipher = get_pickle_cipher(); + auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); length += cipher->ops->mac_length(cipher); return olm::encode_base64_length(length); @@ -85,7 +74,7 @@ std::uint8_t * enc_output_pos( std::uint8_t * output, size_t raw_length ) { - auto *cipher = get_pickle_cipher(); + auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); length += cipher->ops->mac_length(cipher); return output + olm::encode_base64_length(length) - length; @@ -95,7 +84,7 @@ std::size_t enc_output( std::uint8_t const * key, std::size_t key_length, std::uint8_t * output, size_t raw_length ) { - auto *cipher = get_pickle_cipher(); + auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); std::size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length( cipher, raw_length ); @@ -124,7 +113,7 @@ std::size_t enc_input( return std::size_t(-1); } olm::decode_base64(input, b64_length, input); - auto *cipher = get_pickle_cipher(); + auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER); std::size_t raw_length = enc_length - cipher->ops->mac_length(cipher); std::size_t result = cipher->ops->decrypt( cipher, diff --git a/src/session.cpp b/src/session.cpp index 19b9f21..c148c97 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -35,22 +35,13 @@ static const olm::KdfInfo OLM_KDF_INFO = { RATCHET_KDF_INFO, sizeof(RATCHET_KDF_INFO) - 1 }; -const _olm_cipher *get_cipher() { - static _olm_cipher *cipher; - static _olm_cipher_aes_sha_256 OLM_CIPHER; - if (!cipher) { - cipher = _olm_cipher_aes_sha_256_init( - &OLM_CIPHER, - CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1 - ); - } - return cipher; -} +static const struct _olm_cipher_aes_sha_256 OLM_CIPHER = + OLM_CIPHER_INIT_AES_SHA_256(CIPHER_KDF_INFO); } // namespace olm::Session::Session( -) : ratchet(OLM_KDF_INFO, get_cipher()), +) : ratchet(OLM_KDF_INFO, OLM_CIPHER_BASE(&OLM_CIPHER)), last_error(OlmErrorCode::OLM_SUCCESS), received_message(false) { diff --git a/tests/test_ratchet.cpp b/tests/test_ratchet.cpp index 3997eb3..2f8412e 100644 --- a/tests/test_ratchet.cpp +++ b/tests/test_ratchet.cpp @@ -28,10 +28,8 @@ olm::KdfInfo kdf_info = { ratchet_info, sizeof(ratchet_info) - 1 }; -_olm_cipher_aes_sha_256 cipher0; -_olm_cipher *cipher = _olm_cipher_aes_sha_256_init( - &cipher0, message_info, sizeof(message_info) - 1 -); +_olm_cipher_aes_sha_256 cipher0 = OLM_CIPHER_INIT_AES_SHA_256(message_info); +_olm_cipher *cipher = OLM_CIPHER_BASE(&cipher0); std::uint8_t random_bytes[] = "0123456789ABDEF0123456789ABCDEF"; olm::Curve25519KeyPair alice_key;