lib/onnxruntime/inference_session.rb in onnxruntime-0.5.0 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.5.1
- old
+ new
@@ -78,11 +78,11 @@
ensure
# release :SessionOptions, session_options
end
# TODO support logid
- def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
+ def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
input_tensor = create_input_tensor(input_feed)
output_names ||= @outputs.map { |v| v[:name] }
output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size)
@@ -98,11 +98,11 @@
check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate
check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
output_names.size.times.map do |i|
- create_from_onnx_value(output_tensor[i].read_pointer)
+ create_from_onnx_value(output_tensor[i].read_pointer, output_type)
end
ensure
release :RunOptions, run_options
if input_tensor
input_feed.size.times do |i|
@@ -178,46 +178,60 @@
allocator_info = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
input_feed.each_with_index do |(input_name, input), idx|
- input = input.to_a unless input.is_a?(Array)
+ if numo_array?(input)
+ shape = input.shape
+ else
+ input = input.to_a unless input.is_a?(Array)
- shape = []
- s = input
- while s.is_a?(Array)
- shape << s.size
- s = s.first
+ shape = []
+ s = input
+ while s.is_a?(Array)
+ shape << s.size
+ s = s.first
+ end
end
- flat_input = input.flatten
- input_tensor_size = flat_input.size
-
# TODO support more types
inp = @inputs.find { |i| i[:name] == input_name.to_s }
raise Error, "Unknown input: #{input_name}" unless inp
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
input_node_dims.write_array_of_int64(shape)
if inp[:type] == "tensor(string)"
- input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
- input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
+ if numo_array?(input)
+ input_tensor_size = input.size
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input.size)
+ input_tensor_values.write_array_of_pointer(input_tensor_size.times.map { |i| ::FFI::MemoryPointer.from_string(input[i]) })
+ else
+ flat_input = input.flatten.to_a
+ input_tensor_size = flat_input.size
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
+ input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
+ end
type_enum = FFI::TensorElementDataType[:string]
check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
- check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size)
+ check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, input_tensor_size)
else
- tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
tensor_type = tensor_types[inp[:type]]
if tensor_type
- input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size)
- if tensor_type == :bool
- tensor_type = :uchar
- flat_input = flat_input.map { |v| v ? 1 : 0 }
+ if numo_array?(input)
+ input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
+ else
+ flat_input = input.flatten.to_a
+ input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
+ if tensor_type == :bool
+ tensor_type = :uchar
+ flat_input = flat_input.map { |v| v ? 1 : 0 }
+ end
+ input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
end
- input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
+
type_enum = FFI::TensorElementDataType[tensor_type]
else
unsupported_type("input", inp[:type])
end
@@ -232,11 +246,11 @@
ptr = ::FFI::MemoryPointer.new(:pointer, names.size)
ptr.write_array_of_pointer(names.map { |v| ::FFI::MemoryPointer.from_string(v) })
ptr
end
- def create_from_onnx_value(out_ptr)
+ def create_from_onnx_value(out_ptr, output_type)
out_type = ::FFI::MemoryPointer.new(:int)
check_status api[:GetValueType].call(out_ptr, out_type)
type = FFI::OnnxType[out_type.read_int]
case type
@@ -255,29 +269,48 @@
release :TensorTypeAndShapeInfo, typeinfo
# TODO support more types
type = FFI::TensorElementDataType[type]
- arr =
+
+ case output_type
+ when :numo
case type
- when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
- tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
- when :bool
- tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
+ when :string
+ result = Numo::RObject.new(shape)
+ result.allocate
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
else
- unsupported_type("element", type)
+ numo_type = numo_types[type]
+ unsupported_type("element", type) unless numo_type
+ numo_type.from_binary(tensor_data.read_pointer.read_bytes(output_tensor_size * numo_type::ELEMENT_BYTE_SIZE), shape)
end
+ when :ruby
+ arr =
+ case type
+ when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
+ tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
+ when :bool
+ tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
+ when :string
+ create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
+ else
+ unsupported_type("element", type)
+ end
- Utils.reshape(arr, shape)
+ Utils.reshape(arr, shape)
+ else
+ raise ArgumentError, "Invalid output type: #{output_type}"
+ end
when :sequence
out = ::FFI::MemoryPointer.new(:size_t)
check_status api[:GetValueCount].call(out_ptr, out)
out.read(:size_t).times.map do |i|
seq = ::FFI::MemoryPointer.new(:pointer)
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
- create_from_onnx_value(seq.read_pointer)
+ create_from_onnx_value(seq.read_pointer, output_type)
end
when :map
type_shape = ::FFI::MemoryPointer.new(:pointer)
map_keys = ::FFI::MemoryPointer.new(:pointer)
map_values = ::FFI::MemoryPointer.new(:pointer)
@@ -292,12 +325,12 @@
# TODO support more types
elem_type = FFI::TensorElementDataType[elem_type.read_int]
case elem_type
when :int64
ret = {}
- keys = create_from_onnx_value(map_keys.read_pointer)
- values = create_from_onnx_value(map_values.read_pointer)
+ keys = create_from_onnx_value(map_keys.read_pointer, output_type)
+ values = create_from_onnx_value(map_values.read_pointer, output_type)
keys.zip(values).each do |k, v|
ret[k] = v
end
ret
else
@@ -306,10 +339,27 @@
else
unsupported_type("ONNX", type)
end
end
+ def create_strings_from_onnx_value(out_ptr, output_tensor_size, result)
+ len = ::FFI::MemoryPointer.new(:size_t)
+ check_status api[:GetStringTensorDataLength].call(out_ptr, len)
+
+ s_len = len.read(:size_t)
+ s = ::FFI::MemoryPointer.new(:uchar, s_len)
+ offsets = ::FFI::MemoryPointer.new(:size_t, output_tensor_size)
+ check_status api[:GetStringTensorContent].call(out_ptr, s, s_len, offsets, output_tensor_size)
+
+ offsets = output_tensor_size.times.map { |i| offsets[i].read(:size_t) }
+ offsets << s_len
+ output_tensor_size.times do |i|
+ result[i] = s.get_bytes(offsets[i], offsets[i + 1] - offsets[i])
+ end
+ result
+ end
+
def read_pointer
@session.read_pointer
end
def check_status(status)
@@ -386,9 +436,33 @@
[type.read_int, node_dims.read_array_of_int64(num_dims)]
end
def unsupported_type(name, type)
raise Error, "Unsupported #{name} type: #{type}"
+ end
+
+ def tensor_types
+ @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
+ end
+
+ def numo_array?(obj)
+ defined?(Numo::NArray) && obj.is_a?(Numo::NArray)
+ end
+
+ def numo_types
+ @numo_types ||= {
+ float: Numo::SFloat,
+ uint8: Numo::UInt8,
+ int8: Numo::Int8,
+ uint16: Numo::UInt16,
+ int16: Numo::Int16,
+ int32: Numo::Int32,
+ int64: Numo::Int64,
+ bool: Numo::UInt8,
+ double: Numo::DFloat,
+ uint32: Numo::UInt32,
+ uint64: Numo::UInt64
+ }
end
def api
self.class.api
end