lib/onnxruntime/inference_session.rb in onnxruntime-0.1.2 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.2.0

- old
+ new

@@ -3,11 +3,11 @@ attr_reader :inputs, :outputs def initialize(path_or_bytes) # session options session_options = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtCreateSessionOptions(session_options) + check_status api[:CreateSessionOptions].call(session_options) # session @session = ::FFI::MemoryPointer.new(:pointer) path_or_bytes = path_or_bytes.to_str @@ -15,42 +15,42 @@ if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY path_or_bytes = File.binread(path_or_bytes) end if path_or_bytes.encoding == Encoding::BINARY - check_status FFI.OrtCreateSessionFromArray(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session) + check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session) else - check_status FFI.OrtCreateSession(env.read_pointer, path_or_bytes, session_options.read_pointer, @session) + check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session) end # input info allocator = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtCreateDefaultAllocator(allocator) + check_status api[:GetAllocatorWithDefaultOptions].call(allocator) @allocator = allocator @inputs = [] @outputs = [] # input num_input_nodes = ::FFI::MemoryPointer.new(:size_t) - check_status FFI.OrtSessionGetInputCount(read_pointer, num_input_nodes) + check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes) read_size_t(num_input_nodes).times do |i| name_ptr = ::FFI::MemoryPointer.new(:string) - check_status FFI.OrtSessionGetInputName(read_pointer, i, @allocator.read_pointer, name_ptr) + check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr) typeinfo = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtSessionGetInputTypeInfo(read_pointer, i, typeinfo) + check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo) @inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) end # output num_output_nodes = ::FFI::MemoryPointer.new(:size_t) - check_status FFI.OrtSessionGetOutputCount(read_pointer, num_output_nodes) + check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes) read_size_t(num_output_nodes).times do |i| name_ptr = ::FFI::MemoryPointer.new(:string) - check_status FFI.OrtSessionGetOutputName(read_pointer, i, allocator.read_pointer, name_ptr) + check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr) typeinfo = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtSessionGetOutputTypeInfo(read_pointer, i, typeinfo) + check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo) @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) end end def run(output_names, input_feed) @@ -60,22 +60,22 @@ output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size) input_node_names = create_node_names(input_feed.keys.map(&:to_s)) output_node_names = create_node_names(output_names.map(&:to_s)) # TODO support run options - check_status FFI.OrtRun(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor) + check_status api[:Run].call(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor) output_names.size.times.map do |i| create_from_onnx_value(output_tensor[i].read_pointer) end end private def create_input_tensor(input_feed) allocator_info = ::FFI::MemoryPointer.new(:pointer) - check_status = FFI.OrtCreateCpuAllocatorInfo(1, 0, allocator_info) + check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info) input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size) input_feed.each_with_index do |(input_name, input), idx| input = input.to_a unless input.is_a?(Array) @@ -98,12 +98,12 @@ if inp[:type] == "tensor(string)" input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size) input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) }) type_enum = FFI::TensorElementDataType[:string] - check_status FFI.OrtCreateTensorAsOrtValue(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx]) - check_status FFI.OrtFillStringTensor(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size) + check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx]) + check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size) else tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h tensor_type = tensor_types[inp[:type]] if tensor_type @@ -116,11 +116,11 @@ type_enum = FFI::TensorElementDataType[tensor_type] else unsupported_type("input", inp[:type]) end - check_status FFI.OrtCreateTensorWithDataAsOrtValue(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx]) + check_status api[:CreateTensorWithDataAsOrtValue].call(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx]) end end input_tensor end @@ -131,25 +131,25 @@ ptr end def create_from_onnx_value(out_ptr) out_type = ::FFI::MemoryPointer.new(:int) - check_status = FFI.OrtGetValueType(out_ptr, out_type) + check_status = api[:GetValueType].call(out_ptr, out_type) type = FFI::OnnxType[out_type.read_int] case type when :tensor typeinfo = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtGetTensorTypeAndShape(out_ptr, typeinfo) + check_status api[:GetTensorTypeAndShape].call(out_ptr, typeinfo) type, shape = tensor_type_and_shape(typeinfo) tensor_data = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtGetTensorMutableData(out_ptr, tensor_data) + check_status api[:GetTensorMutableData].call(out_ptr, tensor_data) out_size = ::FFI::MemoryPointer.new(:size_t) - output_tensor_size = FFI.OrtGetTensorShapeElementCount(typeinfo.read_pointer, out_size) + output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size) output_tensor_size = read_size_t(out_size) # TODO support more types type = FFI::TensorElementDataType[type] arr = @@ -163,27 +163,27 @@ end Utils.reshape(arr, shape) when :sequence out = ::FFI::MemoryPointer.new(:size_t) - check_status FFI.OrtGetValueCount(out_ptr, out) + check_status api[:GetValueCount].call(out_ptr, out) read_size_t(out).times.map do |i| seq = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtGetValue(out_ptr, i, @allocator.read_pointer, seq) + check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq) create_from_onnx_value(seq.read_pointer) end when :map type_shape = ::FFI::MemoryPointer.new(:pointer) map_keys = ::FFI::MemoryPointer.new(:pointer) map_values = ::FFI::MemoryPointer.new(:pointer) elem_type = ::FFI::MemoryPointer.new(:int) - check_status FFI.OrtGetValue(out_ptr, 0, @allocator.read_pointer, map_keys) - check_status FFI.OrtGetValue(out_ptr, 1, @allocator.read_pointer, map_values) - check_status FFI.OrtGetTensorTypeAndShape(map_keys.read_pointer, type_shape) - check_status FFI.OrtGetTensorElementType(type_shape.read_pointer, elem_type) + check_status api[:GetValue].call(out_ptr, 0, @allocator.read_pointer, map_keys) + check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values) + check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape) + check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type) # TODO support more types elem_type = FFI::TensorElementDataType[elem_type.read_int] case elem_type when :int64 @@ -206,25 +206,25 @@ @session.read_pointer end def check_status(status) unless status.null? - message = FFI.OrtGetErrorMessage(status) - FFI.OrtReleaseStatus(status) + message = api[:GetErrorMessage].call(status) + api[:ReleaseStatus].call(status) raise OnnxRuntime::Error, message end end def node_info(typeinfo) onnx_type = ::FFI::MemoryPointer.new(:int) - check_status FFI.OrtOnnxTypeFromTypeInfo(typeinfo.read_pointer, onnx_type) + check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type) type = FFI::OnnxType[onnx_type.read_int] case type when :tensor tensor_info = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtCastTypeInfoToTensorInfo(typeinfo.read_pointer, tensor_info) + check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info) type, shape = tensor_type_and_shape(tensor_info) { type: "tensor(#{FFI::TensorElementDataType[type]})", shape: shape @@ -243,23 +243,23 @@ } else unsupported_type("ONNX", type) end ensure - FFI.OrtReleaseTypeInfo(typeinfo.read_pointer) + api[:ReleaseTypeInfo].call(typeinfo.read_pointer) end def tensor_type_and_shape(tensor_info) type = ::FFI::MemoryPointer.new(:int) - check_status FFI.OrtGetTensorElementType(tensor_info.read_pointer, type) + check_status api[:GetTensorElementType].call(tensor_info.read_pointer, type) num_dims_ptr = ::FFI::MemoryPointer.new(:size_t) - check_status FFI.OrtGetDimensionsCount(tensor_info.read_pointer, num_dims_ptr) + check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr) num_dims = read_size_t(num_dims_ptr) node_dims = ::FFI::MemoryPointer.new(:int64, num_dims) - check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims) + check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims) [type.read_int, node_dims.read_array_of_int64(num_dims)] end def unsupported_type(name, type) @@ -273,16 +273,23 @@ else ptr.read(:size_t) end end + def api + @api ||= FFI.OrtGetApiBase[:GetApi].call(1) + end + def env # use mutex for thread-safety Utils.mutex.synchronize do @@env ||= begin env = ::FFI::MemoryPointer.new(:pointer) - check_status FFI.OrtCreateEnv(3, "Default", env) - at_exit { FFI.OrtReleaseEnv(env.read_pointer) } + check_status api[:CreateEnv].call(3, "Default", env) + at_exit { api[:ReleaseEnv].call(env.read_pointer) } + # disable telemetry + # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md + check_status api[:DisableTelemetryEvents].call(env) env end end end end