module OnnxRuntime class InferenceSession attr_reader :inputs, :outputs def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil) # session options session_options = ::FFI::MemoryPointer.new(:pointer) check_status api[:CreateSessionOptions].call(session_options) check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling if execution_mode mode = case execution_mode when :sequential 0 when :parallel 1 else raise ArgumentError, "Invalid execution mode" end check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode) end check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath # session @session = ::FFI::MemoryPointer.new(:pointer) from_memory = if path_or_bytes.respond_to?(:read) path_or_bytes = path_or_bytes.read true else path_or_bytes = path_or_bytes.to_str path_or_bytes.encoding == Encoding::BINARY end # fix for Windows "File doesn't exist" if Gem.win_platform? && !from_memory path_or_bytes = File.binread(path_or_bytes) from_memory = true end if from_memory check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session) else check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session) end # input info allocator = ::FFI::MemoryPointer.new(:pointer) check_status api[:GetAllocatorWithDefaultOptions].call(allocator) @allocator = allocator @inputs = [] @outputs = [] # input num_input_nodes = ::FFI::MemoryPointer.new(:size_t) check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes) read_size_t(num_input_nodes).times do |i| name_ptr = ::FFI::MemoryPointer.new(:string) check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr) typeinfo = ::FFI::MemoryPointer.new(:pointer) check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo) @inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) end # output num_output_nodes = ::FFI::MemoryPointer.new(:size_t) check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes) read_size_t(num_output_nodes).times do |i| name_ptr = ::FFI::MemoryPointer.new(:string) check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr) typeinfo = ::FFI::MemoryPointer.new(:pointer) check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo) @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) end end # TODO support logid def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) input_tensor = create_input_tensor(input_feed) output_names ||= @outputs.map { |v| v[:name] } output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size) input_node_names = create_node_names(input_feed.keys.map(&:to_s)) output_node_names = create_node_names(output_names.map(&:to_s)) # run options run_options = ::FFI::MemoryPointer.new(:pointer) check_status api[:CreateRunOptions].call(run_options) check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor) output_names.size.times.map do |i| create_from_onnx_value(output_tensor[i].read_pointer) end end def modelmeta keys = ::FFI::MemoryPointer.new(:pointer) num_keys = ::FFI::MemoryPointer.new(:int64_t) description = ::FFI::MemoryPointer.new(:string) domain = ::FFI::MemoryPointer.new(:string) graph_name = ::FFI::MemoryPointer.new(:string) producer_name = ::FFI::MemoryPointer.new(:string) version = ::FFI::MemoryPointer.new(:int64_t) metadata = ::FFI::MemoryPointer.new(:pointer) check_status api[:SessionGetModelMetadata].call(read_pointer, metadata) custom_metadata_map = {} check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys) num_keys.read(:int64_t).times do |i| key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string value = ::FFI::MemoryPointer.new(:string) check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value) custom_metadata_map[key] = value.read_pointer.read_string end check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description) check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain) check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name) check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name) check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version) api[:ReleaseModelMetadata].call(metadata.read_pointer) { custom_metadata_map: custom_metadata_map, description: description.read_pointer.read_string, domain: domain.read_pointer.read_string, graph_name: graph_name.read_pointer.read_string, producer_name: producer_name.read_pointer.read_string, version: version.read(:int64_t) } end def end_profiling out = ::FFI::MemoryPointer.new(:string) check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out) out.read_pointer.read_string end private def create_input_tensor(input_feed) allocator_info = ::FFI::MemoryPointer.new(:pointer) check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info) input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size) input_feed.each_with_index do |(input_name, input), idx| input = input.to_a unless input.is_a?(Array) shape = [] s = input while s.is_a?(Array) shape << s.size s = s.first end flat_input = input.flatten input_tensor_size = flat_input.size # TODO support more types inp = @inputs.find { |i| i[:name] == input_name.to_s } raise "Unknown input: #{input_name}" unless inp input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size) input_node_dims.write_array_of_int64(shape) if inp[:type] == "tensor(string)" input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size) input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) }) type_enum = FFI::TensorElementDataType[:string] check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx]) check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size) else tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h tensor_type = tensor_types[inp[:type]] if tensor_type input_tensor_values = ::FFI::MemoryPointer.new(tensor_type, input_tensor_size) if tensor_type == :bool tensor_type = :uchar flat_input = flat_input.map { |v| v ? 1 : 0 } end input_tensor_values.send("write_array_of_#{tensor_type}", flat_input) type_enum = FFI::TensorElementDataType[tensor_type] else unsupported_type("input", inp[:type]) end check_status api[:CreateTensorWithDataAsOrtValue].call(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx]) end end input_tensor end def create_node_names(names) ptr = ::FFI::MemoryPointer.new(:pointer, names.size) ptr.write_array_of_pointer(names.map { |v| ::FFI::MemoryPointer.from_string(v) }) ptr end def create_from_onnx_value(out_ptr) out_type = ::FFI::MemoryPointer.new(:int) check_status = api[:GetValueType].call(out_ptr, out_type) type = FFI::OnnxType[out_type.read_int] case type when :tensor typeinfo = ::FFI::MemoryPointer.new(:pointer) check_status api[:GetTensorTypeAndShape].call(out_ptr, typeinfo) type, shape = tensor_type_and_shape(typeinfo) tensor_data = ::FFI::MemoryPointer.new(:pointer) check_status api[:GetTensorMutableData].call(out_ptr, tensor_data) out_size = ::FFI::MemoryPointer.new(:size_t) output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size) output_tensor_size = read_size_t(out_size) # TODO support more types type = FFI::TensorElementDataType[type] arr = case type when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64 tensor_data.read_pointer.send("read_array_of_#{type}", output_tensor_size) when :bool tensor_data.read_pointer.read_array_of_uchar(output_tensor_size).map { |v| v == 1 } else unsupported_type("element", type) end Utils.reshape(arr, shape) when :sequence out = ::FFI::MemoryPointer.new(:size_t) check_status api[:GetValueCount].call(out_ptr, out) read_size_t(out).times.map do |i| seq = ::FFI::MemoryPointer.new(:pointer) check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq) create_from_onnx_value(seq.read_pointer) end when :map type_shape = ::FFI::MemoryPointer.new(:pointer) map_keys = ::FFI::MemoryPointer.new(:pointer) map_values = ::FFI::MemoryPointer.new(:pointer) elem_type = ::FFI::MemoryPointer.new(:int) check_status api[:GetValue].call(out_ptr, 0, @allocator.read_pointer, map_keys) check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values) check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape) check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type) # TODO support more types elem_type = FFI::TensorElementDataType[elem_type.read_int] case elem_type when :int64 ret = {} keys = create_from_onnx_value(map_keys.read_pointer) values = create_from_onnx_value(map_values.read_pointer) keys.zip(values).each do |k, v| ret[k] = v end ret else unsupported_type("element", elem_type) end else unsupported_type("ONNX", type) end end def read_pointer @session.read_pointer end def check_status(status) unless status.null? message = api[:GetErrorMessage].call(status) api[:ReleaseStatus].call(status) raise OnnxRuntime::Error, message end end def node_info(typeinfo) onnx_type = ::FFI::MemoryPointer.new(:int) check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type) type = FFI::OnnxType[onnx_type.read_int] case type when :tensor tensor_info = ::FFI::MemoryPointer.new(:pointer) check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info) type, shape = tensor_type_and_shape(tensor_info) { type: "tensor(#{FFI::TensorElementDataType[type]})", shape: shape } when :sequence sequence_type_info = ::FFI::MemoryPointer.new(:pointer) check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info) nested_type_info = ::FFI::MemoryPointer.new(:pointer) check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info) v = node_info(nested_type_info)[:type] { type: "seq(#{v})", shape: [] } when :map map_type_info = ::FFI::MemoryPointer.new(:pointer) check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info) # key key_type = ::FFI::MemoryPointer.new(:int) check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type) k = FFI::TensorElementDataType[key_type.read_int] # value value_type_info = ::FFI::MemoryPointer.new(:pointer) check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info) v = node_info(value_type_info)[:type] { type: "map(#{k},#{v})", shape: [] } else unsupported_type("ONNX", type) end ensure api[:ReleaseTypeInfo].call(typeinfo.read_pointer) end def tensor_type_and_shape(tensor_info) type = ::FFI::MemoryPointer.new(:int) check_status api[:GetTensorElementType].call(tensor_info.read_pointer, type) num_dims_ptr = ::FFI::MemoryPointer.new(:size_t) check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr) num_dims = read_size_t(num_dims_ptr) node_dims = ::FFI::MemoryPointer.new(:int64, num_dims) check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims) [type.read_int, node_dims.read_array_of_int64(num_dims)] end def unsupported_type(name, type) raise "Unsupported #{name} type: #{type}" end # read(:size_t) not supported in FFI JRuby def read_size_t(ptr) if RUBY_PLATFORM == "java" ptr.read_long else ptr.read(:size_t) end end def api @api ||= FFI.OrtGetApiBase[:GetApi].call(1) end def env # use mutex for thread-safety Utils.mutex.synchronize do @@env ||= begin env = ::FFI::MemoryPointer.new(:pointer) check_status api[:CreateEnv].call(3, "Default", env) at_exit { api[:ReleaseEnv].call(env.read_pointer) } # disable telemetry # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md check_status api[:DisableTelemetryEvents].call(env) env end end end end end