Skip to content

Commit

Permalink
add more to iobinding
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzawa-san committed Jan 9, 2024
1 parent 58f3edb commit 083b5d0
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ final class ApiImpl implements Api {
final SessionOptionsAppendExecutionProvider_OpenVINO SessionOptionsAppendExecutionProvider_OpenVINO;
final SessionOptionsAppendExecutionProvider_ROCM SessionOptionsAppendExecutionProvider_ROCM;
final SessionOptionsAppendExecutionProvider_TensorRT_V2 SessionOptionsAppendExecutionProvider_TensorRT_V2;
final SynchronizeBoundInputs SynchronizeBoundInputs;
final SynchronizeBoundOutputs SynchronizeBoundOutputs;
final UpdateCUDAProviderOptions UpdateCUDAProviderOptions;
final UpdateDnnlProviderOptions UpdateDnnlProviderOptions;
final UpdateTensorRTProviderOptions UpdateTensorRTProviderOptions;
Expand Down Expand Up @@ -290,6 +292,8 @@ final class ApiImpl implements Api {
OrtApi.SessionOptionsAppendExecutionProvider_ROCM(memorySegment, memorySession);
this.SessionOptionsAppendExecutionProvider_TensorRT_V2 =
OrtApi.SessionOptionsAppendExecutionProvider_TensorRT_V2(memorySegment, memorySession);
this.SynchronizeBoundInputs = OrtApi.SynchronizeBoundInputs(memorySegment, memorySession);
this.SynchronizeBoundOutputs = OrtApi.SynchronizeBoundOutputs(memorySegment, memorySession);
this.UpdateCUDAProviderOptions = OrtApi.UpdateCUDAProviderOptions(memorySegment, memorySession);
this.UpdateDnnlProviderOptions = OrtApi.UpdateDnnlProviderOptions(memorySegment, memorySession);
this.UpdateTensorRTProviderOptions = OrtApi.UpdateTensorRTProviderOptions(memorySegment, memorySession);
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/IoBinding.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ public interface IoBinding extends AutoCloseable {
*/
IoBinding setRunTag(String runTag);

IoBinding synchronizeBoundInputs();

IoBinding synchronizeBoundOutputs();

NamedCollection<OnnxValue> getInputs();

NamedCollection<OnnxValue> getOutputs();
Expand Down
16 changes: 14 additions & 2 deletions src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ static final class Builder implements IoBinding.Builder {
public Builder(SessionImpl session) {
this.api = session.api;
this.session = session;
this.inputs = new ArrayList<>();
this.outputs = new ArrayList<>();
this.inputs = new ArrayList<>(session.inputs.size());
this.outputs = new ArrayList<>(session.outputs.size());
}

@Override
Expand Down Expand Up @@ -178,4 +178,16 @@ public NamedCollection<OnnxValue> getInputs() {
public NamedCollection<OnnxValue> getOutputs() {
return outputs;
}

@Override
public IoBinding synchronizeBoundInputs() {
api.checkStatus(api.SynchronizeBoundInputs.apply(ioBinding));
return this;
}

@Override
public IoBinding synchronizeBoundOutputs() {
api.checkStatus(api.SynchronizeBoundOutputs.apply(ioBinding));
return this;
}
}
2 changes: 2 additions & 0 deletions src/test/java/com/jyuzawa/onnxruntime/SessionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,9 @@ public void ioBindingTest() throws IOException {
ThreadLocalRandom.current().nextInt()
};
inputBuf.clear().put(rawInput);
txn.synchronizeBoundInputs();
txn.run();
txn.synchronizeBoundOutputs();
outputBuf.rewind().get(rawOutput);
assertArrayEquals(rawInput, rawOutput);
}
Expand Down

0 comments on commit 083b5d0

Please sign in to comment.