lib/onnxruntime/inference_session.rb in onnxruntime-0.5.1 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.5.2
- old
+ new
@@ -222,14 +222,14 @@
input_tensor_values = input.cast_to(numo_types[tensor_type]).to_binary
else
flat_input = input.flatten.to_a
input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, flat_input.size)
if tensor_type == :bool
- tensor_type = :uchar
- flat_input = flat_input.map { |v| v ? 1 : 0 }
+ input_tensor_values.write_array_of_uint8(flat_input.map { |v| v ? 1 : 0 })
+ else
+ input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
end
- input_tensor_values.send("write_array_of_#{tensor_type}", flat_input)
end
type_enum = FFI::TensorElementDataType[tensor_type]
else
unsupported_type("input", inp[:type])
@@ -288,10 +288,10 @@
arr =
case type
when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size)
when :bool
- tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 }
+ tensor_data.read_pointer.read_array_of_uint8(output_tensor_size).map { |v| v == 1 }
when :string
create_strings_from_onnx_value(out_ptr, output_tensor_size, [])
else
unsupported_type("element", type)
end