lib/onnxruntime/inference_session.rb in onnxruntime-0.1.1 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.1.2

- old
+ new

@@ -31,22 +31,22 @@ @outputs = [] # input num_input_nodes = ::FFI::MemoryPointer.new(:size_t) check_status FFI.OrtSessionGetInputCount(read_pointer, num_input_nodes) - num_input_nodes.read(:size_t).times do |i| + read_size_t(num_input_nodes).times do |i| name_ptr = ::FFI::MemoryPointer.new(:string) check_status FFI.OrtSessionGetInputName(read_pointer, i, @allocator.read_pointer, name_ptr) typeinfo = ::FFI::MemoryPointer.new(:pointer) check_status FFI.OrtSessionGetInputTypeInfo(read_pointer, i, typeinfo) @inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) end # output num_output_nodes = ::FFI::MemoryPointer.new(:size_t) check_status FFI.OrtSessionGetOutputCount(read_pointer, num_output_nodes) - num_output_nodes.read(:size_t).times do |i| + read_size_t(num_output_nodes).times do |i| name_ptr = ::FFI::MemoryPointer.new(:string) check_status FFI.OrtSessionGetOutputName(read_pointer, i, allocator.read_pointer, name_ptr) typeinfo = ::FFI::MemoryPointer.new(:pointer) check_status FFI.OrtSessionGetOutputTypeInfo(read_pointer, i, typeinfo) @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) @@ -75,10 +75,12 @@ allocator_info = ::FFI::MemoryPointer.new(:pointer) check_status = FFI.OrtCreateCpuAllocatorInfo(1, 0, allocator_info) input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size) input_feed.each_with_index do |(input_name, input), idx| + input = input.to_a unless input.is_a?(Array) + shape = [] s = input while s.is_a?(Array) shape << s.size s = s.first @@ -144,11 +146,11 @@ tensor_data = ::FFI::MemoryPointer.new(:pointer) check_status FFI.OrtGetTensorMutableData(out_ptr, tensor_data) out_size = ::FFI::MemoryPointer.new(:size_t) output_tensor_size = FFI.OrtGetTensorShapeElementCount(typeinfo.read_pointer, out_size) - output_tensor_size = out_size.read(:size_t) + output_tensor_size = read_size_t(out_size) # TODO support more types type = FFI::TensorElementDataType[type] arr = case type @@ -163,11 +165,11 @@ Utils.reshape(arr, shape) when :sequence out = ::FFI::MemoryPointer.new(:size_t) check_status FFI.OrtGetValueCount(out_ptr, out) - out.read(:size_t).times.map do |i| + read_size_t(out).times.map do |i| seq = ::FFI::MemoryPointer.new(:pointer) check_status FFI.OrtGetValue(out_ptr, i, @allocator.read_pointer, seq) create_from_onnx_value(seq.read_pointer) end when :map @@ -250,11 +252,11 @@ type = ::FFI::MemoryPointer.new(:int) check_status FFI.OrtGetTensorElementType(tensor_info.read_pointer, type) num_dims_ptr = ::FFI::MemoryPointer.new(:size_t) check_status FFI.OrtGetDimensionsCount(tensor_info.read_pointer, num_dims_ptr) - num_dims = num_dims_ptr.read(:size_t) + num_dims = read_size_t(num_dims_ptr) node_dims = ::FFI::MemoryPointer.new(:int64, num_dims) check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims) [type.read_int, node_dims.read_array_of_int64(num_dims)] @@ -262,17 +264,27 @@ def unsupported_type(name, type) raise "Unsupported #{name} type: #{type}" end - # share env - # TODO mutex around creation? + # read(:size_t) not supported in FFI JRuby + def read_size_t(ptr) + if RUBY_PLATFORM == "java" + ptr.read_long + else + ptr.read(:size_t) + end + end + def env - @@env ||= begin - env = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtCreateEnv(3, "Default", env) - at_exit { FFI.OrtReleaseEnv(env.read_pointer) } - env + # use mutex for thread-safety + Utils.mutex.synchronize do + @@env ||= begin + env = ::FFI::MemoryPointer.new(:pointer) + check_status FFI.OrtCreateEnv(3, "Default", env) + at_exit { FFI.OrtReleaseEnv(env.read_pointer) } + env + end end end end end