lib/onnxruntime/inference_session.rb in onnxruntime-0.1.0 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.1.1
- old (onnxruntime-0.1.0)
+ new (onnxruntime-0.1.1)
@@ -8,10 +8,16 @@
check_status FFI.OrtCreateSessionOptions(session_options)
# session
@session = ::FFI::MemoryPointer.new(:pointer)
path_or_bytes = path_or_bytes.to_str
+
+ # fix for Windows "File doesn't exist"
+ if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
+ path_or_bytes = File.binread(path_or_bytes)
+ end
+
if path_or_bytes.encoding == Encoding::BINARY
check_status FFI.OrtCreateSessionFromArray(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
else
check_status FFI.OrtCreateSession(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
end
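
With this change a plain path is still handed to OrtCreateSession on non-Windows platforms, while Windows (or an explicit byte string) goes through OrtCreateSessionFromArray, sidestepping the "File doesn't exist" error. A minimal sketch of both entry points; the model path is a placeholder, not a real fixture:

require "onnxruntime"

OnnxRuntime::InferenceSession.new("model.onnx")               # path -> OrtCreateSession (on Windows the path is binread first)
OnnxRuntime::InferenceSession.new(File.binread("model.onnx")) # binary string -> OrtCreateSessionFromArray
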
@@ -48,23 +54,19 @@
end
def run(output_names, input_feed)
input_tensor = create_input_tensor(input_feed)
- outputs = @outputs
- if output_names
- output_names = output_names.map(&:to_s)
- outputs = outputs.select { |o| output_names.include?(o[:name]) }
- end
+ output_names ||= @outputs.map { |v| v[:name] }
output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size)
input_node_names = create_node_names(input_feed.keys.map(&:to_s))
- output_node_names = create_node_names(outputs.map { |v| v[:name] })
+ output_node_names = create_node_names(output_names.map(&:to_s))
# TODO support run options
- check_status FFI.OrtRun(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, outputs.size, output_tensor)
+ check_status FFI.OrtRun(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
- outputs.size.times.map do |i|
+ output_names.size.times.map do |i|
create_from_onnx_value(output_tensor[i].read_pointer)
end
end
private
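
With the filtering removed, run now defaults a nil output_names to every name in @outputs and forwards an explicit list as given, returning one decoded value per requested name. A usage sketch; the model path and the input/output names are invented for illustration:

require "onnxruntime"

session = OnnxRuntime::InferenceSession.new("model.onnx")  # placeholder model
input = { "x" => [[1.0, 2.0, 3.0]] }                       # hypothetical input name and shape

session.run(nil, input)        # nil -> all names from @outputs
session.run(["label"], input)  # only the requested output ("label" is hypothetical)
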
@@ -84,25 +86,40 @@
flat_input = input.flatten
input_tensor_size = flat_input.size
# TODO support more types
- inp = @inputs.find { |i| i[:name] == input_name.to_s } || {}
- case inp[:type]
- when "tensor(bool)"
- input_tensor_values = ::FFI::MemoryPointer.new(:uchar, input_tensor_size)
- input_tensor_values.write_array_of_uchar(flat_input.map { |v| v ? 1 : 0 })
- type_enum = FFI::TensorElementDataType[:bool]
- else
- input_tensor_values = ::FFI::MemoryPointer.new(:float, input_tensor_size)
- input_tensor_values.write_array_of_float(flat_input)
- type_enum = FFI::TensorElementDataType[:float]
- end
+ inp = @inputs.find { |i| i[:name] == input_name.to_s }
+ raise "Unknown input: #{input_name}" unless inp
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
input_node_dims.write_array_of_int64(shape)
- check_status FFI.OrtCreateTensorWithDataAsOrtValue(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx])
+
+ if inp[:type] == "tensor(string)"
+ input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
+ input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
+ type_enum = FFI::TensorElementDataType[:string]
+ check_status FFI.OrtCreateTensorAsOrtValue(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
+ check_status FFI.OrtFillStringTensor(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size)
+ else
+ tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
+ tensor_type = tensor_types[inp[:type]]
+
+ if tensor_type
+ input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size)
+ if tensor_type == :bool
+ tensor_type = :uchar
+ flat_input = flat_input.map { |v| v ? 1 : 0 }
+ end
+ input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
+ type_enum = FFI::TensorElementDataType[tensor_type]
+ else
+ unsupported_type("input", inp[:type])
+ end
+
+ check_status FFI.OrtCreateTensorWithDataAsOrtValue(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx])
+ end
end
input_tensor
end
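
The numeric branch derives the FFI element type from the ONNX type string, so one write_array_of_* call now covers every listed element type (bool is rewritten to :uchar with 0/1 values first), while tensor(string) takes the separate OrtCreateTensorAsOrtValue/OrtFillStringTensor path. The lookup table on its own, evaluated standalone:

tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64]
  .map { |v| ["tensor(#{v})", v] }.to_h

tensor_types["tensor(int64)"]   # => :int64
tensor_types["tensor(bool)"]    # => :bool (stored as uchar 0/1)
tensor_types["tensor(float16)"] # => nil -> unsupported_type("input", ...)
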
@@ -133,18 +150,16 @@
# TODO support more types
type = FFI::TensorElementDataType[type]
arr =
case type
- when :float
- tensor_data.read_pointer.read_array_of_float(output_tensor_size)
- when :int64
- tensor_data.read_pointer.read_array_of_int64(output_tensor_size)
+ when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
+ tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
when :bool
tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
else
- raise "Unsupported element type: #{type}"
+ unsupported_type("element", type)
end
Utils.reshape(arr, shape)
when :sequence
out = ::FFI::MemoryPointer.new(:size_t)
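
On the read side the per-type when arms collapse into one dynamic dispatch: the element type symbol selects the matching FFI read_array_of_* reader, with :bool keeping its own branch because ORT stores booleans as uchar 0/1. A self-contained check of that equivalence, using the ffi gem directly (no ONNX involved):

require "ffi"

ptr = ::FFI::MemoryPointer.new(:int32, 3)
ptr.write_array_of_int32([1, 2, 3])

type = :int32
ptr.send("read_array_of_#{type}", 3) # => [1, 2, 3], same as ptr.read_array_of_int32(3)
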
@@ -176,14 +191,14 @@
keys.zip(values).each do |k, v|
ret[k] = v
end
ret
else
- raise "Unsupported element type: #{elem_type}"
+ unsupported_type("element", elem_type)
end
else
- raise "Unsupported ONNX type: #{type}"
+ unsupported_type("ONNX", type)
end
end
def read_pointer
@session.read_pointer
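
For map outputs the keys and values arrive as two parallel arrays, and the keys.zip(values) loop in the hunk above folds them into a plain Ruby hash; in isolation that is just:

keys   = ["cat", "dog"] # made-up key tensor contents
values = [0.25, 0.75]   # made-up value tensor contents

ret = {}
keys.zip(values).each { |k, v| ret[k] = v }
ret # => {"cat"=>0.25, "dog"=>0.75}
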
@@ -223,11 +238,11 @@
{
type: "map",
shape: []
}
else
- raise "Unsupported ONNX type: #{type}"
+ unsupported_type("ONNX", type)
end
ensure
FFI.OrtReleaseTypeInfo(typeinfo.read_pointer)
end
@@ -241,9 +256,13 @@
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims)
[type.read_int, node_dims.read_array_of_int64(num_dims)]
+ end
+
+ def unsupported_type(name, type)
+ raise "Unsupported #{name} type: #{type}"
end
# share env
# TODO mutex around creation?
def env