lib/onnxruntime/inference_session.rb in onnxruntime-0.1.2 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.2.0
- old (onnxruntime-0.1.2)
+ new (onnxruntime-0.2.0)
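The change below moves every native call from the pre-1.0 ONNX Runtime C API, where each function is a named export bound on the gem's FFI module, to the versioned OrtApi table of function pointers returned by OrtGetApiBase. A new private api helper memoizes that table, and telemetry is disabled when the shared environment is created. The calling convention before and after, restated from the diff (session_options is the same FFI::MemoryPointer the constructor builds):

    # 0.1.2: each C function is called directly on the FFI module
    check_status FFI.OrtCreateSessionOptions(session_options)

    # 0.2.0: functions are looked up in the OrtApi struct for API version 1
    api = FFI.OrtGetApiBase[:GetApi].call(1)
    check_status api[:CreateSessionOptions].call(session_options)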
@@ -3,11 +3,11 @@
attr_reader :inputs, :outputs
def initialize(path_or_bytes)
# session options
session_options = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtCreateSessionOptions(session_options)
+ check_status api[:CreateSessionOptions].call(session_options)
# session
@session = ::FFI::MemoryPointer.new(:pointer)
path_or_bytes = path_or_bytes.to_str
@@ -15,42 +15,42 @@
if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
path_or_bytes = File.binread(path_or_bytes)
end
if path_or_bytes.encoding == Encoding::BINARY
- check_status FFI.OrtCreateSessionFromArray(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
+ check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
else
- check_status FFI.OrtCreateSession(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
+ check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
end
# input info
allocator = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtCreateDefaultAllocator(allocator)
+ check_status api[:GetAllocatorWithDefaultOptions].call(allocator)
@allocator = allocator
@inputs = []
@outputs = []
# input
num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
- check_status FFI.OrtSessionGetInputCount(read_pointer, num_input_nodes)
+ check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
read_size_t(num_input_nodes).times do |i|
name_ptr = ::FFI::MemoryPointer.new(:string)
- check_status FFI.OrtSessionGetInputName(read_pointer, i, @allocator.read_pointer, name_ptr)
+ check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
typeinfo = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtSessionGetInputTypeInfo(read_pointer, i, typeinfo)
+ check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo)
@inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
end
# output
num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
- check_status FFI.OrtSessionGetOutputCount(read_pointer, num_output_nodes)
+ check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
read_size_t(num_output_nodes).times do |i|
name_ptr = ::FFI::MemoryPointer.new(:string)
- check_status FFI.OrtSessionGetOutputName(read_pointer, i, allocator.read_pointer, name_ptr)
+ check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
typeinfo = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtSessionGetOutputTypeInfo(read_pointer, i, typeinfo)
+ check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
end
end
def run(output_names, input_feed)
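In the constructor above, session creation (CreateSession / CreateSessionFromArray), the default allocator (OrtCreateDefaultAllocator, renamed upstream to GetAllocatorWithDefaultOptions) and the input/output introspection calls all move onto the api table; the Ruby-facing behaviour is unchanged. A minimal usage sketch, with a hypothetical model file and input:

    session = OnnxRuntime::InferenceSession.new("model.onnx")  # file path or raw bytes
    session.inputs   # e.g. [{name: "x", type: "tensor(float)", shape: [1, 3]}]
    session.outputs  # same hash layout, built by node_info further down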
@@ -60,22 +60,22 @@
output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size)
input_node_names = create_node_names(input_feed.keys.map(&:to_s))
output_node_names = create_node_names(output_names.map(&:to_s))
# TODO support run options
- check_status FFI.OrtRun(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
+ check_status api[:Run].call(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
output_names.size.times.map do |i|
create_from_onnx_value(output_tensor[i].read_pointer)
end
end
private
def create_input_tensor(input_feed)
allocator_info = ::FFI::MemoryPointer.new(:pointer)
- check_status = FFI.OrtCreateCpuAllocatorInfo(1, 0, allocator_info)
+ check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
input_feed.each_with_index do |(input_name, input), idx|
input = input.to_a unless input.is_a?(Array)
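run swaps OrtRun for api[:Run] with the argument list unchanged, and create_input_tensor picks up the upstream rename of OrtCreateCpuAllocatorInfo to CreateCpuMemoryInfo (same arguments). Note that on that line the returned status is assigned to a local check_status variable rather than passed through the check_status helper, in both versions. A call sketch, assuming a float input named "x":

    output_names = session.outputs.map { |o| o[:name] }
    result = session.run(output_names, {"x" => [[1.0, 2.0, 3.0]]})
    result.first  # first requested output, already reshaped to a nested Ruby array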
@@ -98,12 +98,12 @@
if inp[:type] == "tensor(string)"
input_tensor_values = ::FFI::MemoryPointer.new(:pointer, input_tensor_size)
input_tensor_values.write_array_of_pointer(flat_input.map { |v| ::FFI::MemoryPointer.from_string(v) })
type_enum = FFI::TensorElementDataType[:string]
- check_status FFI.OrtCreateTensorAsOrtValue(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
- check_status FFI.OrtFillStringTensor(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size)
+ check_status api[:CreateTensorAsOrtValue].call(@allocator.read_pointer, input_node_dims, shape.size, type_enum, input_tensor[idx])
+ check_status api[:FillStringTensor].call(input_tensor[idx].read_pointer, input_tensor_values, flat_input.size)
else
tensor_types = [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
tensor_type = tensor_types[inp[:type]]
if tensor_type
@@ -116,11 +116,11 @@
type_enum = FFI::TensorElementDataType[tensor_type]
else
unsupported_type("input", inp[:type])
end
- check_status FFI.OrtCreateTensorWithDataAsOrtValue(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx])
+ check_status api[:CreateTensorWithDataAsOrtValue].call(allocator_info.read_pointer, input_tensor_values, input_tensor_values.size, input_node_dims, shape.size, type_enum, input_tensor[idx])
end
end
input_tensor
end
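Two input paths, both now going through the api table: string tensors are allocated by the runtime (CreateTensorAsOrtValue) and populated with FillStringTensor, while numeric tensors hand the Ruby-side buffer to CreateTensorWithDataAsOrtValue together with the CPU memory info created above. Hypothetical feeds for each path:

    session.run(output_names, {"text" => [["hello", "world"]]})  # tensor(string) input
    session.run(output_names, {"x" => [[1.0, 2.0, 3.0]]})        # tensor(float) input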
@@ -131,25 +131,25 @@
ptr
end
def create_from_onnx_value(out_ptr)
out_type = ::FFI::MemoryPointer.new(:int)
- check_status = FFI.OrtGetValueType(out_ptr, out_type)
+ check_status = api[:GetValueType].call(out_ptr, out_type)
type = FFI::OnnxType[out_type.read_int]
case type
when :tensor
typeinfo = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtGetTensorTypeAndShape(out_ptr, typeinfo)
+ check_status api[:GetTensorTypeAndShape].call(out_ptr, typeinfo)
type, shape = tensor_type_and_shape(typeinfo)
tensor_data = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtGetTensorMutableData(out_ptr, tensor_data)
+ check_status api[:GetTensorMutableData].call(out_ptr, tensor_data)
out_size = ::FFI::MemoryPointer.new(:size_t)
- output_tensor_size = FFI.OrtGetTensorShapeElementCount(typeinfo.read_pointer, out_size)
+ output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
output_tensor_size = read_size_t(out_size)
# TODO support more types
type = FFI::TensorElementDataType[type]
arr =
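For tensor outputs the flow is unchanged apart from the api lookups: read the type-and-shape info, take the mutable data pointer, and get the element count from GetTensorShapeElementCount (whose status return is immediately overwritten by read_size_t). The elided lines fill arr from the raw buffer according to the element type, and Utils.reshape in the next hunk folds it back into the reported shape:

    # e.g. a float output of shape [2, 3] comes back as a nested array,
    # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]  (illustrative values)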
@@ -163,27 +163,27 @@
end
Utils.reshape(arr, shape)
when :sequence
out = ::FFI::MemoryPointer.new(:size_t)
- check_status FFI.OrtGetValueCount(out_ptr, out)
+ check_status api[:GetValueCount].call(out_ptr, out)
read_size_t(out).times.map do |i|
seq = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtGetValue(out_ptr, i, @allocator.read_pointer, seq)
+ check_status api[:GetValue].call(out_ptr, i, @allocator.read_pointer, seq)
create_from_onnx_value(seq.read_pointer)
end
when :map
type_shape = ::FFI::MemoryPointer.new(:pointer)
map_keys = ::FFI::MemoryPointer.new(:pointer)
map_values = ::FFI::MemoryPointer.new(:pointer)
elem_type = ::FFI::MemoryPointer.new(:int)
- check_status FFI.OrtGetValue(out_ptr, 0, @allocator.read_pointer, map_keys)
- check_status FFI.OrtGetValue(out_ptr, 1, @allocator.read_pointer, map_values)
- check_status FFI.OrtGetTensorTypeAndShape(map_keys.read_pointer, type_shape)
- check_status FFI.OrtGetTensorElementType(type_shape.read_pointer, elem_type)
+ check_status api[:GetValue].call(out_ptr, 0, @allocator.read_pointer, map_keys)
+ check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
+ check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
+ check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
# TODO support more types
elem_type = FFI::TensorElementDataType[elem_type.read_int]
case elem_type
when :int64
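Sequence outputs recurse: GetValueCount gives the length and each element is fetched with GetValue and converted by another create_from_onnx_value call. Map outputs are read as two parallel tensors, keys at index 0 and values at index 1, whose element types are checked before the elided lines presumably zip them into a Hash:

    # e.g. a map(int64, float) output would come back roughly as {1 => 0.25, 2 => 0.75}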
@@ -206,25 +206,25 @@
@session.read_pointer
end
def check_status(status)
unless status.null?
- message = FFI.OrtGetErrorMessage(status)
- FFI.OrtReleaseStatus(status)
+ message = api[:GetErrorMessage].call(status)
+ api[:ReleaseStatus].call(status)
raise OnnxRuntime::Error, message
end
end
def node_info(typeinfo)
onnx_type = ::FFI::MemoryPointer.new(:int)
- check_status FFI.OrtOnnxTypeFromTypeInfo(typeinfo.read_pointer, onnx_type)
+ check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type)
type = FFI::OnnxType[onnx_type.read_int]
case type
when :tensor
tensor_info = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtCastTypeInfoToTensorInfo(typeinfo.read_pointer, tensor_info)
+ check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
type, shape = tensor_type_and_shape(tensor_info)
{
type: "tensor(#{FFI::TensorElementDataType[type]})",
shape: shape
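check_status itself now reads the message with GetErrorMessage and frees the status with ReleaseStatus before raising OnnxRuntime::Error, and node_info uses GetOnnxTypeFromTypeInfo and CastTypeInfoToTensorInfo to build the {type:, shape:} hashes exposed through #inputs and #outputs, releasing the TypeInfo in the ensure block.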
@@ -243,23 +243,23 @@
}
else
unsupported_type("ONNX", type)
end
ensure
- FFI.OrtReleaseTypeInfo(typeinfo.read_pointer)
+ api[:ReleaseTypeInfo].call(typeinfo.read_pointer)
end
def tensor_type_and_shape(tensor_info)
type = ::FFI::MemoryPointer.new(:int)
- check_status FFI.OrtGetTensorElementType(tensor_info.read_pointer, type)
+ check_status api[:GetTensorElementType].call(tensor_info.read_pointer, type)
num_dims_ptr = ::FFI::MemoryPointer.new(:size_t)
- check_status FFI.OrtGetDimensionsCount(tensor_info.read_pointer, num_dims_ptr)
+ check_status api[:GetDimensionsCount].call(tensor_info.read_pointer, num_dims_ptr)
num_dims = read_size_t(num_dims_ptr)
node_dims = ::FFI::MemoryPointer.new(:int64, num_dims)
- check_status FFI.OrtGetDimensions(tensor_info.read_pointer, node_dims, num_dims)
+ check_status api[:GetDimensions].call(tensor_info.read_pointer, node_dims, num_dims)
[type.read_int, node_dims.read_array_of_int64(num_dims)]
end
def unsupported_type(name, type)
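tensor_type_and_shape is a direct transliteration of the C sequence: read the element type enum, then the dimension count, then the dims array, and return [type_int, dims]. A dynamic dimension typically comes back as -1 here, e.g. a model input with a dynamic batch size reports a shape like [-1, 3, 224, 224].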
@@ -273,16 +273,23 @@
else
ptr.read(:size_t)
end
end
+ def api
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(1)
+ end
+
def env
# use mutex for thread-safety
Utils.mutex.synchronize do
@@env ||= begin
env = ::FFI::MemoryPointer.new(:pointer)
- check_status FFI.OrtCreateEnv(3, "Default", env)
- at_exit { FFI.OrtReleaseEnv(env.read_pointer) }
+ check_status api[:CreateEnv].call(3, "Default", env)
+ at_exit { api[:ReleaseEnv].call(env.read_pointer) }
+ # disable telemetry
+ # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
+ check_status api[:DisableTelemetryEvents].call(env)
env
end
end
end
end
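Finally, the new api helper memoizes FFI.OrtGetApiBase[:GetApi].call(1), the OrtApi struct for API version 1, and the shared environment is still created once under a mutex (logging severity 3, i.e. errors and above, with the id "Default") and released at_exit. The only behavioural addition is the DisableTelemetryEvents call documented in the linked Privacy page.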