lib/onnxruntime/inference_session.rb in onnxruntime-0.3.3 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.4.0

- old
+ new

@@ -6,29 +6,29 @@ # session options session_options = ::FFI::MemoryPointer.new(:pointer) check_status api[:CreateSessionOptions].call(session_options) check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern - check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling + check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling if execution_mode execution_modes = {sequential: 0, parallel: 1} mode = execution_modes[execution_mode] raise ArgumentError, "Invalid execution mode" unless mode check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode) end if graph_optimization_level optimization_levels = {none: 0, basic: 1, extended: 2, all: 99} - # TODO raise error in 0.4.0 - level = optimization_levels[graph_optimization_level] || graph_optimization_level + level = optimization_levels[graph_optimization_level] + raise ArgumentError, "Invalid graph optimization level" unless level check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level) end check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid - check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath + check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath # session @session = ::FFI::MemoryPointer.new(:pointer) from_memory = if path_or_bytes.respond_to?(:read) @@ -37,20 +37,14 @@ else path_or_bytes = path_or_bytes.to_str path_or_bytes.encoding == Encoding::BINARY end - # fix for Windows "File doesn't exist" - if Gem.win_platform? && !from_memory - path_or_bytes = File.binread(path_or_bytes) - from_memory = true - end - if from_memory check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session) else - check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session) + check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session) end ObjectSpace.define_finalizer(self, self.class.finalize(@session)) # input info allocator = ::FFI::MemoryPointer.new(:pointer) @@ -61,22 +55,22 @@ @outputs = [] # input num_input_nodes = ::FFI::MemoryPointer.new(:size_t) check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes) - read_size_t(num_input_nodes).times do |i| + num_input_nodes.read(:size_t).times do |i| name_ptr = ::FFI::MemoryPointer.new(:string) check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr) typeinfo = ::FFI::MemoryPointer.new(:pointer) check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo) @inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) end # output num_output_nodes = ::FFI::MemoryPointer.new(:size_t) check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes) - read_size_t(num_output_nodes).times do |i| + num_output_nodes.read(:size_t).times do |i| name_ptr = ::FFI::MemoryPointer.new(:string) check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr) typeinfo = ::FFI::MemoryPointer.new(:pointer) check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo) @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) @@ -154,16 +148,32 @@ } ensure release :ModelMetadata, metadata end + # return value has double underscore like Python def end_profiling out = ::FFI::MemoryPointer.new(:string) check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out) out.read_pointer.read_string end + # no way to set providers with C API yet + # so we can return all available providers + def providers + out_ptr = ::FFI::MemoryPointer.new(:pointer) + length_ptr = ::FFI::MemoryPointer.new(:int) + check_status api[:GetAvailableProviders].call(out_ptr, length_ptr) + length = length_ptr.read_int + providers = [] + length.times do |i| + providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string + end + api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length) + providers + end + private def create_input_tensor(input_feed) allocator_info = ::FFI::MemoryPointer.new(:pointer) check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info) @@ -182,11 +192,11 @@ flat_input = input.flatten input_tensor_size = flat_input.size # TODO support more types inp = @inputs.find { |i| i[:name] == input_name.to_s } - raise "Unknown input: #{input_name}" unless inp + raise Error, "Unknown input: #{input_name}" unless inp input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size) input_node_dims.write_array_of_int64(shape) if inp[:type] == "tensor(string)" @@ -239,11 +249,11 @@ tensor_data = ::FFI::MemoryPointer.new(:pointer) check_status api[:GetTensorMutableData].call(out_ptr, tensor_data) out_size = ::FFI::MemoryPointer.new(:size_t) output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size) - output_tensor_size = read_size_t(out_size) + output_tensor_size = out_size.read(:size_t) release :TensorTypeAndShapeInfo, typeinfo # TODO support more types type = FFI::TensorElementDataType[type] @@ -260,11 +270,11 @@ Utils.reshape(arr, shape) when :sequence out = ::FFI::MemoryPointer.new(:size_t) check_status api[:GetValueCount].call(out_ptr, out) - read_size_t(out).times.map do |i| + out.read(:size_t).times.map do |i| seq = ::FFI::MemoryPointer.new(:pointer) check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq) create_from_onnx_value(seq.read_pointer) end when :map @@ -304,11 +314,11 @@ def check_status(status) unless status.null? message = api[:GetErrorMessage].call(status).read_string api[:ReleaseStatus].call(status) - raise OnnxRuntime::Error, message + raise Error, message end end def node_info(typeinfo) onnx_type = ::FFI::MemoryPointer.new(:int) @@ -366,49 +376,55 @@ type = ::FFI::MemoryPointer.new(:int) check_status api[:GetTensorElementType].call(tensor_info.read_pointer, type) num_dims_ptr = ::FFI::MemoryPointer.new(:size_t) check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr) - num_dims = read_size_t(num_dims_ptr) + num_dims = num_dims_ptr.read(:size_t) node_dims = ::FFI::MemoryPointer.new(:int64, num_dims) check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims) [type.read_int, node_dims.read_array_of_int64(num_dims)] end def unsupported_type(name, type) - raise "Unsupported #{name} type: #{type}" + raise Error, "Unsupported #{name} type: #{type}" end - # read(:size_t) not supported in FFI JRuby - def read_size_t(ptr) - if RUBY_PLATFORM == "java" - ptr.read_long - else - ptr.read(:size_t) - end - end - def api self.class.api end def release(*args) self.class.release(*args) end def self.api - @api ||= FFI.OrtGetApiBase[:GetApi].call(3) + @api ||= FFI.OrtGetApiBase[:GetApi].call(4) end def self.release(type, pointer) api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null? end def self.finalize(session) # must use proc instead of stabby lambda proc { release :Session, session } + end + + # wide string on Windows + # char string on Linux + # see ORTCHAR_T in onnxruntime_c_api.h + def ort_string(str) + if Gem.win_platform? + max = str.size + 1 # for null byte + dest = ::FFI::MemoryPointer.new(:wchar_t, max) + ret = FFI::Libc.mbstowcs(dest, str, max) + raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size + dest + else + str + end end def env # use mutex for thread-safety Utils.mutex.synchronize do