lib/onnxruntime/inference_session.rb in onnxruntime-0.3.0 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.3.1
- old
+ new
@@ -107,28 +107,38 @@
create_from_onnx_value(output_tensor[i].read_pointer)
end
end
def modelmeta
+ keys = ::FFI::MemoryPointer.new(:pointer)
+ num_keys = ::FFI::MemoryPointer.new(:int64_t)
description = ::FFI::MemoryPointer.new(:string)
domain = ::FFI::MemoryPointer.new(:string)
graph_name = ::FFI::MemoryPointer.new(:string)
producer_name = ::FFI::MemoryPointer.new(:string)
version = ::FFI::MemoryPointer.new(:int64_t)
metadata = ::FFI::MemoryPointer.new(:pointer)
check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)
+
+ custom_metadata_map = {}
+ check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
+ num_keys.read(:int64_t).times do |i|
+ key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
+ value = ::FFI::MemoryPointer.new(:string)
+ check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
+ custom_metadata_map[key] = value.read_pointer.read_string
+ end
+
check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)
api[:ReleaseModelMetadata].call(metadata.read_pointer)
- # TODO add custom_metadata_map
- # need a way to get keys
-
{
+ custom_metadata_map: custom_metadata_map,
description: description.read_pointer.read_string,
domain: domain.read_pointer.read_string,
graph_name: graph_name.read_pointer.read_string,
producer_name: producer_name.read_pointer.read_string,
version: version.read(:int64_t)