Diff of lib/onnxruntime/inference_session.rb between onnxruntime 0.3.2 and onnxruntime 0.3.3
- old (line removed in 0.3.3)
+ new (line added in 0.3.3)
@@ -48,10 +48,11 @@
if from_memory
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
else
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
end
+ ObjectSpace.define_finalizer(self, self.class.finalize(@session))
# input info
allocator = ::FFI::MemoryPointer.new(:pointer)
check_status api[:GetAllocatorWithDefaultOptions].call(allocator)
@allocator = allocator
@@ -78,10 +79,12 @@
check_status api[:SessionGetOutputName].call(read_pointer, i, allocator.read_pointer, name_ptr)
typeinfo = ::FFI::MemoryPointer.new(:pointer)
check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
end
+ ensure
+ # release :SessionOptions, session_options
end
# TODO support logid
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
input_tensor = create_input_tensor(input_feed)
@@ -103,10 +106,17 @@
check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
output_names.size.times.map do |i|
create_from_onnx_value(output_tensor[i].read_pointer)
end
+ ensure
+ release :RunOptions, run_options
+ if input_tensor
+ input_feed.size.times do |i|
+ release :Value, input_tensor[i]
+ end
+ end
end
def modelmeta
keys = ::FFI::MemoryPointer.new(:pointer)
num_keys = ::FFI::MemoryPointer.new(:int64_t)
@@ -131,20 +141,21 @@
check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
- api[:ReleaseModelMetadata].call(metadata.read_pointer)
{
custom_metadata_map: custom_metadata_map,
description: description.read_pointer.read_string,
domain: domain.read_pointer.read_string,
graph_name: graph_name.read_pointer.read_string,
producer_name: producer_name.read_pointer.read_string,
version: version.read(:int64_t)
}
+ ensure
+ release :ModelMetadata, metadata
end
def end_profiling
out = ::FFI::MemoryPointer.new(:string)
check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
@@ -230,10 +241,12 @@
out_size = ::FFI::MemoryPointer.new(:size_t)
output_tensor_size = api[:GetTensorShapeElementCount].call(typeinfo.read_pointer, out_size)
output_tensor_size = read_size_t(out_size)
+ release :TensorTypeAndShapeInfo, typeinfo
+
# TODO support more types
type = FFI::TensorElementDataType[type]
arr =
case type
when :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :double, :uint32, :uint64
@@ -262,10 +275,11 @@
check_status api[:GetValue].call(out_ptr, 0, @allocator.read_pointer, map_keys)
check_status api[:GetValue].call(out_ptr, 1, @allocator.read_pointer, map_values)
check_status api[:GetTensorTypeAndShape].call(map_keys.read_pointer, type_shape)
check_status api[:GetTensorElementType].call(type_shape.read_pointer, elem_type)
+ release :TensorTypeAndShapeInfo, type_shape
# TODO support more types
elem_type = FFI::TensorElementDataType[elem_type.read_int]
case elem_type
when :int64
@@ -302,10 +316,11 @@
type = FFI::OnnxType[onnx_type.read_int]
case type
when :tensor
tensor_info = ::FFI::MemoryPointer.new(:pointer)
+ # don't free tensor_info
check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)
type, shape = tensor_type_and_shape(tensor_info)
{
type: "tensor(#{FFI::TensorElementDataType[type]})",
@@ -342,11 +357,11 @@
}
else
unsupported_type("ONNX", type)
end
ensure
- api[:ReleaseTypeInfo].call(typeinfo.read_pointer)
+ release :TypeInfo, typeinfo
end
def tensor_type_and_shape(tensor_info)
type = ::FFI::MemoryPointer.new(:int)
check_status api[:GetTensorElementType].call(tensor_info.read_pointer, type)
@@ -373,19 +388,36 @@
ptr.read(:size_t)
end
end
def api
- @api ||= FFI.OrtGetApiBase[:GetApi].call(1)
+ self.class.api
end
+ def release(*args)
+ self.class.release(*args)
+ end
+
+ def self.api
+ @api ||= FFI.OrtGetApiBase[:GetApi].call(3)
+ end
+
+ def self.release(type, pointer)
+ api[:"Release#{type}"].call(pointer.read_pointer) if pointer && !pointer.null?
+ end
+
+ def self.finalize(session)
+ # must use proc instead of stabby lambda
+ proc { release :Session, session }
+ end
+
def env
# use mutex for thread-safety
Utils.mutex.synchronize do
@@env ||= begin
env = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CreateEnv].call(3, "Default", env)
- at_exit { api[:ReleaseEnv].call(env.read_pointer) }
+ at_exit { release :Env, env }
# disable telemetry
# https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md
check_status api[:DisableTelemetryEvents].call(env)
env
end