lib/onnxruntime/inference_session.rb in onnxruntime-0.1.1 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.1.2
- old
+ new
@@ -31,22 +31,22 @@
@outputs = []
# input
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
check_status FFI.OrtSessionGetInputCount(read_pointer, num_input_nodes)
- num_input_nodes.read(:size_t).times do |i|
+ read_size_t(num_input_nodes).times do |i|
name_ptr = ::FFI::MemoryPointer.new(:string)
check_status FFI.OrtSessionGetInputName(read_pointer, i, @allocator.read_pointer, name_ptr)
typeinfo = ::FFI::MemoryPointer.new(:pointer)
check_status FFI.OrtSessionGetInputTypeInfo(read_pointer, i, typeinfo)
@inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
end
# output
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
check_status FFI.OrtSessionGetOutputCount(read_pointer, num_output_nodes)
- num_output_nodes.read(:size_t).times do |i|
+ read_size_t(num_output_nodes).times do |i|
name_ptr = ::FFI::MemoryPointer.new(:string)
check_status FFI.OrtSessionGetOutputName(read_pointer, i, allocator.read_pointer, name_ptr)
typeinfo = ::FFI::MemoryPointer.new(:pointer)
check_status FFI.OrtSessionGetOutputTypeInfo(read_pointer, i, typeinfo)
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
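
Note on the change repeated throughout this diff: 0.1.1 read the size_t out-parameters with MemoryPointer#read(:size_t), which JRuby's FFI does not support (see the comment on the new helper in the last hunk), so 0.1.2 routes every such read through read_size_t. A minimal sketch of the two code paths, assuming a typical 64-bit platform:

    require "ffi"

    # 0.1.2 behaviour in miniature: read a size_t-sized value portably.
    ptr = ::FFI::MemoryPointer.new(:size_t)
    ptr.write_long(42)
    value = RUBY_PLATFORM == "java" ? ptr.read_long : ptr.read(:size_t)
    value # => 42
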
@@ -75,10 +75,12 @@
allocator_info = ::FFI::MemoryPointer.new(:pointer)
check_status = FFI.OrtCreateCpuAllocatorInfo(1, 0, allocator_info)
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
input_feed.each_with_index do |(input_name, input), idx|
+ input = input.to_a unless input.is_a?(Array)
+
shape = []
s = input
while s.is_a?(Array)
shape << s.size
s = s.first
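
The added to_a coercion lets inputs that are not plain Ruby arrays be passed, as long as they convert to nested arrays, before the shape inference shown above runs. A small sketch of that flow, using Matrix purely as a stand-in for "any object responding to to_a" (the diff does not name the intended input types):

    require "matrix"

    input = Matrix[[1, 2, 3], [4, 5, 6]]          # anything with #to_a
    input = input.to_a unless input.is_a?(Array)  # => [[1, 2, 3], [4, 5, 6]]

    # Shape inference as in the hunk above: walk the first element of
    # each nesting level to recover the tensor dimensions.
    shape = []
    s = input
    while s.is_a?(Array)
      shape << s.size
      s = s.first
    end
    shape # => [2, 3]
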
@@ -144,11 +146,11 @@
tensor_data = ::FFI::MemoryPointer.new(:pointer)
check_status FFI.OrtGetTensorMutableData(out_ptr, tensor_data)
out_size = ::FFI::MemoryPointer.new(:size_t)
output_tensor_size = FFI.OrtGetTensorShapeElementCount(typeinfo.read_pointer, out_size)
- output_tensor_size = out_size.read(:size_t)
+ output_tensor_size = read_size_t(out_size)
# TODO support more types
type = FFI::TensorElementDataType[type]
arr =
case type
@@ -163,11 +165,11 @@
Utils.reshape(arr, shape)
when :sequence
out = ::FFI::MemoryPointer.new(:size_t)
check_status FFI.OrtGetValueCount(out_ptr, out)
- out.read(:size_t).times.map do |i|
+ read_size_t(out).times.map do |i|
seq = ::FFI::MemoryPointer.new(:pointer)
check_status FFI.OrtGetValue(out_ptr, i, @allocator.read_pointer, seq)
create_from_onnx_value(seq.read_pointer)
end
when :map
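
Utils.reshape, called in the hunk above to turn the flat tensor buffer back into a nested array, is not part of this diff. Presumably it nests the buffer according to the reported dims, innermost dimension last; a hypothetical sketch of that idea:

    # Hypothetical reshape: fold a flat buffer into nested arrays by dims.
    def reshape(arr, dims)
      dims[1..-1].reverse_each do |dim|
        arr = arr.each_slice(dim).to_a
      end
      arr
    end

    reshape([1, 2, 3, 4, 5, 6], [2, 3]) # => [[1, 2, 3], [4, 5, 6]]
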
@@ -250,11 +252,11 @@
type = ::FFI::MemoryPointer.new(:int)
check_status FFI.OrtGetTensorElementType(tensor_info.read_pointer, type)
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
check_status FFI.OrtGetDimensionsCount(tensor_info.read_pointer, num_dims_ptr)
- num_dims = num_dims_ptr.read(:size_t)
+ num_dims = read_size_t(num_dims_ptr)
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims)
[type.read_int, node_dims.read_array_of_int64(num_dims)]
@@ -262,17 +264,27 @@
def unsupported_type(name, type)
raise "Unsupported #{name} type: #{type}"
end
- # share env
- # TODO mutex around creation?
+ # read(:size_t) not supported in FFI JRuby
+ def read_size_t(ptr)
+ if RUBY_PLATFORM == "java"
+ ptr.read_long
+ else
+ ptr.read(:size_t)
+ end
+ end
+
def env
- @@env ||= begin
- env = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtCreateEnv(3, "Default", env)
- at_exit { FFI.OrtReleaseEnv(env.read_pointer) }
- env
+ # use mutex for thread-safety
+ Utils.mutex.synchronize do
+ @@env ||= begin
+ env = ::FFI::MemoryPointer.new(:pointer)
+ check_status FFI.OrtCreateEnv(3, "Default", env)
+ at_exit { FFI.OrtReleaseEnv(env.read_pointer) }
+ env
+ end
end
end
end
end
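
The env change guards shared-state initialization: @@env ||= is a check-then-set, so without a lock two threads building sessions at the same time could both see nil and both create an OrtEnv instead of sharing one; the synchronize block serializes that first initialization. Utils.mutex itself is not shown in this diff; presumably it exposes a single process-wide Mutex, roughly along these lines:

    # Presumed shape of the helper the new code relies on (not in this diff).
    module OnnxRuntime
      module Utils
        MUTEX = Mutex.new

        def self.mutex
          MUTEX
        end
      end
    end

Note that, as written, the lock is taken on every env call, not just the first; a small constant cost in exchange for a simple thread-safety guarantee.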