lib/onnxruntime/inference_session.rb in onnxruntime-0.5.0 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.5.1

- old
+ new

@@ -78,11 +78,11 @@ ensure # release :SessionOptions, session_options end # TODO support logid - def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) + def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby) input_tensor = create_input_tensor(input_feed) output_names ||= @outputs.map { |v| v[:name] } output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size) @@ -98,11 +98,11 @@ check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor) output_names.size.times.map do |i| - create_from_onnx_value(output_tensor[i].read_pointer) + create_from_onnx_value(output_tensor[i].read_pointer, output_type) end ensure release :RunOptions, run_options if input_tensor input_feed.size.times do |i| @@ -178,46 +178,60 @@ allocator_info = ::FFI::MemoryPointer.new(:pointer) check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info) input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size) input_feed.each_with_index do |(input_name, input), idx| - input = input.to_a unless input.is_a?(Array) + if numo_array?(input) + shape = input.shape + else + input = input.to_a unless input.is_a?(Array) - shape = [] - s = input - while s.is_a?(Array) - shape << s.size - s = s.first + shape = [] + s = input + while s.is_a?(Array) + shape << s.size + s = s.first + end end - flat_input = input.flatten - input_tensor_size = flat_input.size - # TODO support more types inp = @inputs.find { |i| i[:name] == input_name.to_s } raise Error, "Unknown input: #{input_name}" unless inp input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size) input_node_dims.write_array_of_int64(shape) if inp[:type] == "tensor(string)" - input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size) - input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) }) + if numo_array?(input) + input_tensor_size = input.size + input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size) + input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) }) + else + flat_input = input.flatten.to_a + input_tensor_size = flat_input.size + input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size) + input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) }) + end type_enum = FFI::TensorElementDataType[:string] check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx]) - check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size) + check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size) else - tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h tensor_type = tensor_types[inp[:type]] if tensor_type - input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size) - if tensor_type == :bool - tensor_type = :uchar - flat_input = flat_input.map { |v| v ? 1 : 0 } + if numo_array?(input) + input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary + else + flat_input = input.flatten.to_a + input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size) + if tensor_type == :bool + tensor_type = :uchar + flat_input = flat_input.map { |v| v ? 1 : 0 } + end + input_tensor_values.send("write_array_of_#{tensor_type}", flat_input) end - input_tensor_values.send("write_array_of_#{tensor_type}", flat_input) + type_enum = FFI::TensorElementDataType[tensor_type] else unsupported_type("input", inp[:type]) end @@ -232,11 +246,11 @@ ptr = ::FFI::MemoryPointer.new(:pointer, names.size) ptr.write_array_of_pointer(names.map { |v| ::FFI::MemoryPointer.from_string(v) }) ptr end - def create_from_onnx_value(out_ptr) + def create_from_onnx_value(out_ptr, output_type) out_type = ::FFI::MemoryPointer.new(:int) check_status api[:GetValueType].call(out_ptr, out_type) type = FFI::OnnxType[out_type.read_int] case type @@ -255,29 +269,48 @@ release :TensorTypeAndShapeInfo, typeinfo # TODO support more types type = FFI::TensorElementDataType[type] - arr = + + case output_type + when :numo case type - when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64 - tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size) - when :bool - tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 } + when :string + result = Numo::RObject.new(shape) + result.allocate + create_strings_from_onnx_value(out_ptr, output_tensor_size, result) else - unsupported_type("element", type) + numo_type = numo_types[type] + unsupported_type("element", type) unless numo_type + numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape) end + when :ruby + arr = + case type + when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64 + tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size) + when :bool + tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 } + when :string + create_strings_from_onnx_value(out_ptr, output_tensor_size, []) + else + unsupported_type("element", type) + end - Utils.reshape(arr, shape) + Utils.reshape(arr, shape) + else + raise ArgumentError, "Invalid output type: #{output_type}" + end when :sequence out = ::FFI::MemoryPointer.new(:size_t) check_status api[:GetValueCount].call(out_ptr, out) out.read(:size_t).times.map do |i| seq = ::FFI::MemoryPointer.new(:pointer) check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq) - create_from_onnx_value(seq.read_pointer) + create_from_onnx_value(seq.read_pointer, output_type) end when :map type_shape = ::FFI::MemoryPointer.new(:pointer) map_keys = ::FFI::MemoryPointer.new(:pointer) map_values = ::FFI::MemoryPointer.new(:pointer) @@ -292,12 +325,12 @@ # TODO support more types elem_type = FFI::TensorElementDataType[elem_type.read_int] case elem_type when :int64 ret = {} - keys = create_from_onnx_value(map_keys.read_pointer) - values = create_from_onnx_value(map_values.read_pointer) + keys = create_from_onnx_value(map_keys.read_pointer, output_type) + values = create_from_onnx_value(map_values.read_pointer, output_type) keys.zip(values).each do |k, v| ret[k] = v end ret else @@ -306,10 +339,27 @@ else unsupported_type("ONNX", type) end end + def create_strings_from_onnx_value(out_ptr, output_tensor_size, result) + len = ::FFI::MemoryPointer.new(:size_t) + check_status api[:GetStringTensorDataLength].call(out_ptr, len) + + s_len = len.read(:size_t) + s = ::FFI::MemoryPointer.new(:uchar, s_len) + offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size) + check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size) + + offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) } + offsets << s_len + output_tensor_size.times do |i| + result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i]) + end + result + end + def read_pointer @session.read_pointer end def check_status(status) @@ -386,9 +436,33 @@ [type.read_int, node_dims.read_array_of_int64(num_dims)] end def unsupported_type(name, type) raise Error, "Unsupported #{name} type: #{type}" + end + + def tensor_types + @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h + end + + def numo_array?(obj) + defined?(Numo::NArray) && obj.is_a?(Numo::NArray) + end + + def numo_types + @numo_types ||= { + float: Numo::SFloat, + uint8: Numo::UInt8, + int8: Numo::Int8, + uint16: Numo::UInt16, + int16: Numo::Int16, + int32: Numo::Int32, + int64: Numo::Int64, + bool: Numo::UInt8, + double: Numo::DFloat, + uint32: Numo::UInt32, + uint64: Numo::UInt64 + } end def api self.class.api end