diff --git a/java/android/OlmLibSdk/olm-sdk/src/main/jni/olm_jni_helper.cpp b/java/android/OlmLibSdk/olm-sdk/src/main/jni/olm_jni_helper.cpp index d2ecce3..3fddf62 100644 --- a/java/android/OlmLibSdk/olm-sdk/src/main/jni/olm_jni_helper.cpp +++ b/java/android/OlmLibSdk/olm-sdk/src/main/jni/olm_jni_helper.cpp @@ -29,91 +29,96 @@ using namespace AndroidOlmSdk; **/ bool setRandomInBuffer(JNIEnv *env, uint8_t **aBuffer2Ptr, size_t aRandomSize) { - bool retCode = false; - int bufferLen = aRandomSize*sizeof(uint8_t); + bool retCode = false; + int bufferLen = aRandomSize*sizeof(uint8_t); - if(NULL == aBuffer2Ptr) + if(NULL == aBuffer2Ptr) + { + LOGE("## setRandomInBuffer(): failure - aBuffer=NULL"); + } + else if(0 == aRandomSize) + { + LOGE("## setRandomInBuffer(): failure - random size=0"); + } + else if(NULL == (*aBuffer2Ptr = (uint8_t*)malloc(bufferLen))) + { + LOGE("## setRandomInBuffer(): failure - alloc mem OOM"); + } + else + { + LOGD("## setRandomInBuffer(): randomSize=%lu",static_cast(aRandomSize)); + + bool secureRandomSucceeds = false; + + // use the secureRandom class + jclass cls = env->FindClass("java/security/SecureRandom"); + + if (cls) { - LOGE("## setRandomInBuffer(): failure - aBuffer=NULL"); - } - else if(0 == aRandomSize) - { - LOGE("## setRandomInBuffer(): failure - random size=0"); - } - else if(NULL == (*aBuffer2Ptr = (uint8_t*)malloc(bufferLen))) - { - LOGE("## setRandomInBuffer(): failure - alloc mem OOM"); - } - else - { - LOGD("## setRandomInBuffer(): randomSize=%lu",static_cast(aRandomSize)); + jobject newObj = 0; + jmethodID constructor = env->GetMethodID(cls, "", "()V"); + jmethodID nextByteMethod = env->GetMethodID(cls, "nextBytes", "([B)V"); - bool secureRandomSucceeds = false; + if (constructor) + { + newObj = env->NewObject(cls, constructor); + jbyteArray tempByteArray = env->NewByteArray(bufferLen); - // clear the buffer - memset(*aBuffer2Ptr, 0, bufferLen); - - // use the secureRandom class - jclass cls = env->FindClass("java/security/SecureRandom"); - - if (cls) + if (newObj && tempByteArray) { - jobject newObj = 0; - jmethodID constructor = env->GetMethodID(cls, "", "()V"); - jmethodID nextByteMethod = env->GetMethodID(cls, "nextBytes", "([B)V"); + env->CallVoidMethod(newObj, nextByteMethod, tempByteArray); - if (constructor) - { - newObj = env->NewObject(cls, constructor); - jbyteArray tempByteArray = env->NewByteArray(bufferLen); + jbyte* buffer = env->GetByteArrayElements(tempByteArray, NULL); - if (newObj && tempByteArray) - { - env->CallVoidMethod(newObj, nextByteMethod, tempByteArray); + if (buffer) + { + memcpy(*aBuffer2Ptr, buffer, bufferLen); + secureRandomSucceeds = true; - jbyte* buffer = env->GetByteArrayElements(tempByteArray,0); + // clear tempByteArray to hide sensitive data. + memset(buffer, 0, bufferLen); + env->SetByteArrayRegion(tempByteArray, 0, bufferLen, buffer); - if (buffer) - { - memcpy(*aBuffer2Ptr, buffer, bufferLen); - secureRandomSucceeds = true; - } - } - - if (tempByteArray) - { - env->DeleteLocalRef(tempByteArray); - } - - if (newObj) - { - env->DeleteLocalRef(newObj); - } - } + // ensure that the buffer is released + env->ReleaseByteArrayElements(tempByteArray, buffer, JNI_ABORT); + } } - if (!secureRandomSucceeds) + if (tempByteArray) { - LOGE("## setRandomInBuffer(): SecureRandom failed, use a fallback"); - struct timeval timeValue; - gettimeofday(&timeValue, NULL); - srand(timeValue.tv_usec); // init seed - - for(size_t i=0;iDeleteLocalRef(tempByteArray); } - // debug purpose - /*for(int i = 0; i < aRandomSize; i++) + if (newObj) { - LOGD("## setRandomInBuffer(): randomBuffPtr[%ld]=%d",i, (*aBuffer2Ptr)[i]); - }*/ - - retCode = true; + env->DeleteLocalRef(newObj); + } + } } - return retCode; + + if (!secureRandomSucceeds) + { + LOGE("## setRandomInBuffer(): SecureRandom failed, use a fallback"); + struct timeval timeValue; + gettimeofday(&timeValue, NULL); + srand(timeValue.tv_usec); // init seed + + for(size_t i=0;iNewByteArray(aMsgLength))) + { + LOGE("## javaCStringToUtf8(): failure - return byte array OOM"); + } + else + { + env->SetByteArrayRegion(tempByteArray, 0, aMsgLength, (const jbyte*)aCStringMsgPtr); + + // UTF-8 conversion from JAVA + jstring strEncode = (env)->NewStringUTF("UTF-8"); + jclass jClass = env->FindClass("java/lang/String"); + jmethodID cstor = env->GetMethodID(jClass, "", "([BLjava/lang/String;)V"); + + if((0!=jClass) && (0!=jClass) && (0!=strEncode)) { - LOGE("## javaCStringToUtf8(): failure - invalid parameters (null)"); - } - else if(NULL == (tempByteArray=env->NewByteArray(aMsgLength))) - { - LOGE("## javaCStringToUtf8(): failure - return byte array OOM"); + convertedRetValue = (jstring) env->NewObject(jClass, cstor, tempByteArray, strEncode); + LOGD(" ## javaCStringToUtf8(): succeed"); + env->DeleteLocalRef(tempByteArray); } else { - env->SetByteArrayRegion(tempByteArray, 0, aMsgLength, (const jbyte*)aCStringMsgPtr); - - // UTF-8 conversion from JAVA - jstring strEncode = (env)->NewStringUTF("UTF-8"); - jclass jClass = env->FindClass("java/lang/String"); - jmethodID cstor = env->GetMethodID(jClass, "", "([BLjava/lang/String;)V"); - - if((0!=jClass) && (0!=jClass) && (0!=strEncode)) - { - convertedRetValue = (jstring) env->NewObject(jClass, cstor, tempByteArray, strEncode); - LOGD(" ## javaCStringToUtf8(): succeed"); - env->DeleteLocalRef(tempByteArray); - } - else - { - LOGE(" ## javaCStringToUtf8(): failure - invalid Java references"); - } + LOGE(" ## javaCStringToUtf8(): failure - invalid Java references"); } + } - return convertedRetValue; + return convertedRetValue; }