lib/onnxruntime/inference_session.rb in onnxruntime-0.1.0 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.1.1

- old
+ new

@@ -8,10 +8,16 @@ check_status FFI.OrtCreateSessionOptions(session_options) # session @session = ::FFI::MemoryPointer.new(:pointer) path_or_bytes = path_or_bytes.to_str + + # fix for Windows "File doesn't exist" + if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY + path_or_bytes = File.binread(path_or_bytes) + end + if path_or_bytes.encoding == Encoding::BINARY check_status FFI.OrtCreateSessionFromArray(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session) else check_status FFI.OrtCreateSession(env.read_pointer, path_or_bytes, session_options.read_pointer, @session) end @@ -48,23 +54,19 @@ end def run(output_names, input_feed) input_tensor = create_input_tensor(input_feed) - outputs = @outputs - if output_names - output_names = output_names.map(&:to_s) - outputs = outputs.select { |o| output_names.include?(o[:name]) } - end + output_names ||= @outputs.map { |v| v[:name] } output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size) input_node_names = create_node_names(input_feed.keys.map(&:to_s)) - output_node_names = create_node_names(outputs.map { |v| v[:name] }) + output_node_names = create_node_names(output_names.map(&:to_s)) # TODO support run options - check_status FFI.OrtRun(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, outputs.size, output_tensor) + check_status FFI.OrtRun(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor) - outputs.size.times.map do |i| + output_names.size.times.map do |i| create_from_onnx_value(output_tensor[i].read_pointer) end end private @@ -84,25 +86,40 @@ flat_input = input.flatten input_tensor_size = flat_input.size # TODO support more types - inp = @inputs.find { |i| i[:name] == input_name.to_s } || {} - case inp[:type] - when "tensor(bool)" - input_tensor_values = ::FFI::MemoryPointer.new(:uchar, input_tensor_size) - input_tensor_values.write_array_of_uchar(flat_input.map { |v| v ? 1 : 0 }) - type_enum = FFI::TensorElementDataType[:bool] - else - input_tensor_values = ::FFI::MemoryPointer.new(:float, input_tensor_size) - input_tensor_values.write_array_of_float(flat_input) - type_enum = FFI::TensorElementDataType[:float] - end + inp = @inputs.find { |i| i[:name] == input_name.to_s } + raise "Unknown input: #{input_name}" unless inp input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size) input_node_dims.write_array_of_int64(shape) - check_status FFI.OrtCreateTensorWithDataAsOrtValue(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx]) + + if inp[:type] == "tensor(string)" + input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size) + input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) }) + type_enum = FFI::TensorElementDataType[:string] + check_status FFI.OrtCreateTensorAsOrtValue(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx]) + check_status FFI.OrtFillStringTensor(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size) + else + tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h + tensor_type = tensor_types[inp[:type]] + + if tensor_type + input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size) + if tensor_type == :bool + tensor_type = :uchar + flat_input = flat_input.map { |v| v ? 1 : 0 } + end + input_tensor_values.send("write_array_of_#{tensor_type}", flat_input) + type_enum = FFI::TensorElementDataType[tensor_type] + else + unsupported_type("input", inp[:type]) + end + + check_status FFI.OrtCreateTensorWithDataAsOrtValue(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx]) + end end input_tensor end @@ -133,18 +150,16 @@ # TODO support more types type = FFI::TensorElementDataType[type] arr = case type - when :float - tensor_data.read_pointer.read_array_of_float(output_tensor_size) - when :int64 - tensor_data.read_pointer.read_array_of_int64(output_tensor_size) + when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64 + tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size) when :bool tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 } else - raise "Unsupported element type: #{type}" + unsupported_type("element", type) end Utils.reshape(arr, shape) when :sequence out = ::FFI::MemoryPointer.new(:size_t) @@ -176,14 +191,14 @@ keys.zip(values).each do |k, v| ret[k] = v end ret else - raise "Unsupported element type: #{elem_type}" + unsupported_type("element", elem_type) end else - raise "Unsupported ONNX type: #{type}" + unsupported_type("ONNX", type) end end def read_pointer @session.read_pointer @@ -223,11 +238,11 @@ { type: "map", shape: [] } else - raise "Unsupported ONNX type: #{type}" + unsupported_type("ONNX", type) end ensure FFI.OrtReleaseTypeInfo(typeinfo.read_pointer) end @@ -241,9 +256,13 @@ node_dims = ::FFI::MemoryPointer.new(:int64, num_dims) check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims) [type.read_int, node_dims.read_array_of_int64(num_dims)] + end + + def unsupported_type(name, type) + raise "Unsupported #{name} type: #{type}" end # share env # TODO mutex around creation? def env