lib/onnxruntime/inference_session.rb in onnxruntime-0.3.3 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.4.0
- old
+ new
@@ -6,29 +6,29 @@
# session options
session_options = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CreateSessionOptions].call(session_options)
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
- check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
+ check_status api[:EnableProfiling].call(session_options.read_pointer, ort_string("onnxruntime_profile_")) if enable_profiling
if execution_mode
execution_modes = {sequential: 0, parallel: 1}
mode = execution_modes[execution_mode]
raise ArgumentError, "Invalid execution mode" unless mode
check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
end
if graph_optimization_level
optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
- # TODO raise error in 0.4.0
- level = optimization_levels[graph_optimization_level] || graph_optimization_level
+ level = optimization_levels[graph_optimization_level]
+ raise ArgumentError, "Invalid graph optimization level" unless level
check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
end
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
- check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, ort_string(optimized_model_filepath)) if optimized_model_filepath
# session
@session = ::FFI::MemoryPointer.new(:pointer)
from_memory =
if path_or_bytes.respond_to?(:read)
@@ -37,20 +37,14 @@
else
path_or_bytes = path_or_bytes.to_str
path_or_bytes.encoding == Encoding::BINARY
end
- # fix for Windows "File doesn't exist"
- if Gem.win_platform? && !from_memory
- path_or_bytes = File.binread(path_or_bytes)
- from_memory = true
- end
-
if from_memory
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
else
- check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
+ check_status api[:CreateSession].call(env.read_pointer, ort_string(path_or_bytes), session_options.read_pointer, @session)
end
ObjectSpace.define_finalizer(self, self.class.finalize(@session))
# input info
allocator = ::FFI::MemoryPointer.new(:pointer)
@@ -61,22 +55,22 @@
@outputs = []
# input
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
- read_size_t(num_input_nodes).times do |i|
+ num_input_nodes.read(:size_t).times do |i|
name_ptr = ::FFI::MemoryPointer.new(:string)
check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
typeinfo = ::FFI::MemoryPointer.new(:pointer)
check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo)
@inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
end
# output
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
- read_size_t(num_output_nodes).times do |i|
+ num_output_nodes.read(:size_t).times do |i|
name_ptr = ::FFI::MemoryPointer.new(:string)
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
typeinfo = ::FFI::MemoryPointer.new(:pointer)
check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
@@ -154,16 +148,32 @@
}
ensure
release :ModelMetadata, metadata
end
+ # return value has double underscore like Python
# Calls SessionEndProfiling for this session and returns the string
# that the call stores behind the out pointer.
def end_profiling
  profile_ptr = ::FFI::MemoryPointer.new(:string)
  status = api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, profile_ptr)
  check_status(status)
  profile_ptr.read_pointer.read_string
end
+ # returns every provider string reported by GetAvailableProviders
+ # (there is no way to set providers with the C API yet, so all available ones are returned)
+ def providers
+ out_ptr = ::FFI::MemoryPointer.new(:pointer)
+ length_ptr = ::FFI::MemoryPointer.new(:int)
+ check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)
+ length = length_ptr.read_int
+ providers = []
+ length.times do |i|
+ providers << out_ptr.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string # i-th pointer-sized slot, dereferenced and read as a string
+ end
+ api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length) # NOTE(review): result is not passed through check_status, unlike the other API calls here — confirm this is intended
+ providers
+ end
+
private
def create_input_tensor(input_feed)
allocator_info = ::FFI::MemoryPointer.new(:pointer)
check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
@@ -182,11 +192,11 @@
flat_input = input.flatten
input_tensor_size = flat_input.size
# TODO support more types
inp = @inputs.find { |i| i[:name] == input_name.to_s }
- raise "Unknown input: #{input_name}" unless inp
+ raise Error, "Unknown input: #{input_name}" unless inp
input_node_dims = ::FFI::MemoryPointer.new(:int64, shape.size)
input_node_dims.write_array_of_int64(shape)
if inp[:type] == "tensor(string)"
@@ -239,11 +249,11 @@
tensor_data = ::FFI::MemoryPointer.new(:pointer)
check_status api[:GetTensorMutableData].call(out_ptr, tensor_data)
out_size = ::FFI::MemoryPointer.new(:size_t)
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
- output_tensor_size = read_size_t(out_size)
+ output_tensor_size = out_size.read(:size_t)
release :TensorTypeAndShapeInfo, typeinfo
# TODO support more types
type = FFI::TensorElementDataType[type]
@@ -260,11 +270,11 @@
Utils.reshape(arr, shape)
when :sequence
out = ::FFI::MemoryPointer.new(:size_t)
check_status api[:GetValueCount].call(out_ptr, out)
- read_size_t(out).times.map do |i|
+ out.read(:size_t).times.map do |i|
seq = ::FFI::MemoryPointer.new(:pointer)
check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
create_from_onnx_value(seq.read_pointer)
end
when :map
@@ -304,11 +314,11 @@
# raises Error with the message from GetErrorMessage when status is a non-null pointer;
# does nothing when status is null
def check_status(status)
unless status.null?
message = api[:GetErrorMessage].call(status).read_string
# the message is read into a Ruby string first, then the status is released before raising
api[:ReleaseStatus].call(status)
- raise OnnxRuntime::Error, message
+ raise Error, message
end
end
def node_info(typeinfo)
onnx_type = ::FFI::MemoryPointer.new(:int)
@@ -366,49 +376,55 @@
type = ::FFI::MemoryPointer.new(:int)
check_status api[:GetTensorElementType].call(tensor_info.read_pointer, type)
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
- num_dims = read_size_t(num_dims_ptr)
+ num_dims = num_dims_ptr.read(:size_t)
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
[type.read_int, node_dims.read_array_of_int64(num_dims)]
end
# always raises; name and type are interpolated into the error message
def unsupported_type(name, type)
- raise "Unsupported #{name} type: #{type}"
+ raise Error, "Unsupported #{name} type: #{type}"
end
- # read(:size_t) is not supported by FFI on JRuby, so a long is read there instead
- def read_size_t(ptr)
- if RUBY_PLATFORM == "java"
- ptr.read_long
- else
- ptr.read(:size_t)
- end
- end
-
# Instance-level shortcut that forwards to the class-level api.
def api
  owner = self.class
  owner.api
end
# Instance-level shortcut that forwards every argument to the class-level release.
def release(*args)
  owner = self.class
  owner.release(*args)
end
# memoizes, at class level, the value returned by GetApi on OrtGetApiBase
def self.api
- @api ||= FFI.OrtGetApiBase[:GetApi].call(3)
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(4) # presumably the requested ORT API version (3 before this change) — confirm
end
# Calls the API's Release<type> function with the pointer read from +pointer+.
# Does nothing when +pointer+ is nil/false or reports itself as null.
def self.release(type, pointer)
  return unless pointer
  return if pointer.null?

  releaser = api[:"Release#{type}"]
  releaser.call(pointer.read_pointer)
end
# builds the proc that releases the given session pointer via release :Session
def self.finalize(session)
# must use proc instead of stabby lambda
proc { release :Session, session }
+ end
+
+ # converts str for the ORT C API: a wide (wchar_t) string on Windows,
+ # the unchanged char string on every other platform (e.g. Linux)
+ # see ORTCHAR_T in onnxruntime_c_api.h
+ def ort_string(str)
+ if Gem.win_platform?
+ max = str.size + 1 # for null byte
+ dest = ::FFI::MemoryPointer.new(:wchar_t, max)
+ ret = FFI::Libc.mbstowcs(dest, str, max) # presumably the count of wide characters written, as with C mbstowcs — confirm
+ raise Error, "Expected mbstowcs to return #{str.size}, got #{ret}" if ret != str.size
+ dest
+ else
+ str
+ end
end
def env
# use mutex for thread-safety
Utils.mutex.synchronize do