
~ Calling JNI code from multiple Java threads: sharing state

[Figure: Mapping Java threads to C++ states in a JNI bridge. Java threads 1, 2 and 3 on one side, C++ states 1, 2 and 3 on the other, connected through the JNI bridge.]

This post deals with the problem of using stateful C++ code from multiple Java threads. With JNI (Java Native Interface) it is possible to glue C++ code to a Java environment. There are many helpful tutorials on how to call C++ code and receive results. JNI helps to reuse existing, often highly complex and computationally expensive, C++ code.

The introductory tutorials often stop once it is clear how to repackage (simple) datatypes and do not mention threads. It is, however, reasonable to expect JNI code to be thread-safe and to behave properly when called from multiple threads. In all but the simplest cases it is not that straightforward to keep state on the C++ side and allow JNI code to be called from multiple Java threads. Sharing state incorrectly can lead to memory leaks and to segmentation faults (segfaults) that crash the application. In what follows, a way to keep thread-local state is presented.

It is quite common to have an init, a work and a dispose method: init creates a state, work uses that state to do some work, and dispose finally releases the resources that were used. Each Java thread calls these methods independently and expects results. These results should not change when multiple Java threads call the same methods at the same time. In other words: the state should remain local to a Java thread. A typical Java class could look like the Java code below.

With the Java code in mind, the C++ code needs to know which Java thread is calling and which state belongs to that thread. Luckily there is a way to find out: the JNI specification states that a JNIEnv pointer is only valid in the thread it belongs to, and that the same pointer is passed on every native call made from the same Java thread. So we can use the JNIEnv pointer to identify a thread. This is the idea that is used below.

The code maps a JNIEnv pointer to a structure with (any) state information. An unordered map is used for this mapping. There is, however, still a problem: multiple threads can call the init method at once. Multiple threads would then write to the unordered_map at the same time, which is a data race and can corrupt the map. To prevent this from happening a mutex is used. The mutex, together with a unique lock, makes sure that only a single thread writes to the unordered map at a time. The same holds for the dispose method.
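To try the locking pattern in isolation, without a JVM, the following is a minimal sketch that uses plain C++ threads. Everything in it is hypothetical and only for illustration: `FakeEnv` stands in for the `JNIEnv` that the JVM would supply, and `State`, `initState`, `doWork`, `disposeState` and `runThreads` are made-up names. As a conservative choice, the sketch also takes the lock briefly for the lookup in `doWork`.

```cpp
#include <cstdint>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for JNIEnv: only its address matters here.
struct FakeEnv { char unused; };

// Hypothetical per-thread state.
struct State { int counter = 0; };

std::unordered_map<std::uintptr_t, State*> states;  // env address -> state
std::mutex statesMutex;                              // guards all access to `states`

std::uintptr_t keyOf(FakeEnv* env) {
  return reinterpret_cast<std::uintptr_t>(env);
}

void initState(FakeEnv* env) {
  State* s = new State();
  std::lock_guard<std::mutex> lock(statesMutex);  // one writer at a time
  states[keyOf(env)] = s;
}

int doWork(FakeEnv* env, int rounds) {
  State* s = nullptr;
  {
    // Lock only for the lookup; find() never inserts, unlike operator[].
    std::lock_guard<std::mutex> lock(statesMutex);
    auto it = states.find(keyOf(env));
    if (it == states.end()) return -1;  // initState was not called
    s = it->second;
  }
  // The state is only touched by its own thread, so no lock is needed here.
  for (int i = 0; i < rounds; i++) s->counter++;
  return s->counter;
}

void disposeState(FakeEnv* env) {
  State* s = nullptr;
  {
    std::lock_guard<std::mutex> lock(statesMutex);  // one writer at a time
    auto it = states.find(keyOf(env));
    if (it == states.end()) return;  // nothing to dispose
    s = it->second;
    states.erase(it);
  }
  delete s;
}

// Runs init, work and dispose on `threads` threads, each with its own FakeEnv,
// and returns the result each thread computed.
std::vector<int> runThreads(int threads, int rounds) {
  std::vector<FakeEnv> envs(threads);
  std::vector<int> results(threads, 0);
  std::vector<std::thread> pool;
  for (int t = 0; t < threads; t++) {
    pool.emplace_back([&, t] {
      initState(&envs[t]);
      results[t] = doWork(&envs[t], rounds);
      disposeState(&envs[t]);
    });
  }
  for (auto& th : pool) th.join();
  return results;
}
```

Calling `runThreads(8, 1000)` should give every thread the same result, independent of how the threads interleave, because each thread only ever sees the state registered under its own key.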

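The work method does not write to the unordered map, but its lookup can still race with an init or dispose call on another thread: an insert may rehash the table while it is being read. Reading from multiple threads is only safe as long as no thread writes at the same time, so the lookup should be guarded by the same mutex, held just long enough to fetch the state pointer. The actual work then runs without the lock, since each state is only touched by its own thread.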
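If lookups need protection against concurrent init or dispose calls but should not block each other, a reader/writer lock is one option. The sketch below is not the approach used in this post but a variation on it, assuming C++17 (`std::shared_mutex`). The `StateRegistry` class and its `add`, `lookup` and `remove` methods are hypothetical names chosen for illustration.

```cpp
#include <cstdint>
#include <mutex>
#include <shared_mutex>
#include <unordered_map>

// Hypothetical per-thread state.
struct State { int value = 0; };

// Hypothetical registry: many concurrent readers, one writer at a time.
class StateRegistry {
 public:
  // Writer: exclusive lock while inserting. Returns false if the key exists.
  bool add(std::uintptr_t key, State* state) {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    return map_.emplace(key, state).second;
  }

  // Reader: shared lock, so lookups from different threads do not block
  // each other. Returns nullptr if the key is unknown.
  State* lookup(std::uintptr_t key) const {
    std::shared_lock<std::shared_mutex> lock(mutex_);
    auto it = map_.find(key);
    return it == map_.end() ? nullptr : it->second;
  }

  // Writer: exclusive lock while erasing. Returns the removed pointer so the
  // caller can free it outside the lock, or nullptr if the key is unknown.
  State* remove(std::uintptr_t key) {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    auto it = map_.find(key);
    if (it == map_.end()) return nullptr;
    State* state = it->second;
    map_.erase(it);
    return state;
  }

 private:
  mutable std::shared_mutex mutex_;
  std::unordered_map<std::uintptr_t, State*> map_;
};
```

In a bridge like the one in this post, init would call `add`, work would call `lookup` once at the start, and dispose would call `remove` and then delete the returned state. Whether this is worth it over a plain mutex depends on how often work is called compared to how long it runs.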

#include <jni.h>
#include <cstdint>
#include <mutex>
#include <unordered_map>

const int DATA_ARRAY_SIZE = 300000 * 2;

struct BridgeState{
  jfloat *data;
};

//A hash map with a JNIEnv * as key and a BridgeState * as value
std::unordered_map<uintptr_t, uintptr_t> stateMap;

//A mutex to ensure that access to the stateMap is synchronized.
std::mutex stateMutex;

//JNI resolves native methods by unmangled names of the form
//Java_<class>_<method>; the class here is Bridge in the default package.
extern "C" {

JNIEXPORT jint JNICALL Java_Bridge_init(JNIEnv * env, jobject object){
  //Makes sure only one thread writes to the stateMap
  std::unique_lock<std::mutex> lck (stateMutex);

  BridgeState * state = new BridgeState();
  uintptr_t env_address = reinterpret_cast<uintptr_t>(env);

  //the trailing () zero-initializes the array
  state->data = new jfloat[DATA_ARRAY_SIZE]();

  uintptr_t state_address = reinterpret_cast<uintptr_t>(state);
  stateMap[env_address] = state_address;

  return 1;
}

JNIEXPORT jint JNICALL Java_Bridge_work(JNIEnv * env, jobject object){
  uintptr_t env_address = reinterpret_cast<uintptr_t>(env);
  BridgeState * state = nullptr;
  {
    //Hold the lock only for the lookup: another thread may be inserting
    //or erasing at the same time. find() never inserts, unlike operator[].
    std::unique_lock<std::mutex> lck (stateMutex);
    auto it = stateMap.find(env_address);
    if(it == stateMap.end()) return -1; //init was not called on this thread
    state = reinterpret_cast<BridgeState *>(it->second);
  }

  //do something with state->data, e.g. calculate the sum
  int sum = 0;
  for(int i = 0 ; i < DATA_ARRAY_SIZE ; i++){
    state->data[i] = state->data[i] + 1;
    sum += (int) state->data[i];
  }
  return sum;
}

JNIEXPORT jint JNICALL Java_Bridge_dispose(JNIEnv * env, jobject object){
  //Makes sure only one thread writes to the stateMap
  std::unique_lock<std::mutex> lck (stateMutex);

  uintptr_t env_address = reinterpret_cast<uintptr_t>(env);
  auto it = stateMap.find(env_address);
  if(it == stateMap.end()) return 0; //nothing to dispose
  BridgeState * state = reinterpret_cast<BridgeState *>(it->second);
  stateMap.erase(it);

  //cleanup memory
  delete [] state->data;
  delete state;
  return 0;
}

} //extern "C"

The Java class that calls these native methods:

public class Bridge{

  //Load the native library
  static {
    try {
      System.loadLibrary("bridge");
    } catch (UnsatisfiedLinkError e){
      e.printStackTrace();
    }
  }

  private native int init();

  private native int work();

  private native int dispose();

  public static void main (String[] args){
    //start work on 20 threads
    for(int i = 0 ; i<20 ; i++){
      new Thread(new Runnable() {
        @Override
        public void run() {
          //init, work and dispose run on the same thread,
          //so the native side sees the same JNIEnv pointer each time
          Bridge b = new Bridge();
          b.init();
          b.work();
          b.dispose();
        }
      }).start();
    }
  }
}

This conceptual code has been lifted from a JNI library doing actual work: the JGaborator JNI bridge. If you need more information on how to compile and use this construct in actual code, please have a look at the JGaborator GitHub repository.