lib/onnxruntime/inference_session.rb in onnxruntime-0.2.3 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.3.0
- old
+ new
@@ -29,18 +29,26 @@
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
# session
@session = ::FFI::MemoryPointer.new(:pointer)
- path_or_bytes = path_or_bytes.to_str
+ from_memory =
+ if path_or_bytes.respond_to?(:read)
+ path_or_bytes = path_or_bytes.read
+ true
+ else
+ path_or_bytes = path_or_bytes.to_str
+ path_or_bytes.encoding == Encoding::BINARY
+ end
# fix for Windows "File doesn't exist"
- if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
+ if Gem.win_platform? && !from_memory
path_or_bytes = File.binread(path_or_bytes)
+ from_memory = true
end
- if path_or_bytes.encoding == Encoding::BINARY
+ if from_memory
check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
else
check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
end
@@ -98,10 +106,44 @@
output_names.size.times.map do |i|
create_from_onnx_value(output_tensor[i].read_pointer)
end
end
+ def modelmeta
+ description = ::FFI::MemoryPointer.new(:string)
+ domain = ::FFI::MemoryPointer.new(:string)
+ graph_name = ::FFI::MemoryPointer.new(:string)
+ producer_name = ::FFI::MemoryPointer.new(:string)
+ version = ::FFI::MemoryPointer.new(:int64_t)
+
+ metadata = ::FFI::MemoryPointer.new(:pointer)
+ check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
+ check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
+ check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
+ check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
+ check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
+ check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
+ api[:ReleaseModelMetadata].call(metadata.read_pointer)
+
+ # TODO add custom_metadata_map
+ # need a way to get keys
+
+ {
+ description: description.read_pointer.read_string,
+ domain: domain.read_pointer.read_string,
+ graph_name: graph_name.read_pointer.read_string,
+ producer_name: producer_name.read_pointer.read_string,
+ version: version.read(:int64_t)
+ }
+ end
+
+ def end_profiling
+ out = ::FFI::MemoryPointer.new(:string)
+ check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out)
+ out.read_pointer.read_string
+ end
+
private
def create_input_tensor(input_feed)
allocator_info = ::FFI::MemoryPointer.new(:pointer)
check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info)
@@ -259,18 +301,35 @@
{
type: "tensor(#{FFI::TensorElementDataType[type]})",
shape: shape
}
when :sequence
- # TODO show nested
+ sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
+ check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
+ nested_type_info = ::FFI::MemoryPointer.new(:pointer)
+ check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
+ v = node_info(nested_type_info)[:type]
+
{
- type: "seq",
+ type: "seq(#{v})",
shape: []
}
when :map
- # TODO show nested
+ map_type_info = ::FFI::MemoryPointer.new(:pointer)
+ check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)
+
+ # key
+ key_type = ::FFI::MemoryPointer.new(:int)
+ check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
+ k = FFI::TensorElementDataType[key_type.read_int]
+
+ # value
+ value_type_info = ::FFI::MemoryPointer.new(:pointer)
+ check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
+ v = node_info(value_type_info)[:type]
+
{
- type: "map",
+ type: "map(#{k},#{v})",
shape: []
}
else
unsupported_type("ONNX", type)
end