lib/onnxruntime/inference_session.rb in onnxruntime-0.3.0 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.3.1

- old
+ new

@@ -107,28 +107,38 @@ create_from_onnx_value(output_tensor[i].read_pointer) end end def modelmeta + keys = ::FFI::MemoryPointer.new(:pointer) + num_keys = ::FFI::MemoryPointer.new(:int64_t) description = ::FFI::MemoryPointer.new(:string) domain = ::FFI::MemoryPointer.new(:string) graph_name = ::FFI::MemoryPointer.new(:string) producer_name = ::FFI::MemoryPointer.new(:string) version = ::FFI::MemoryPointer.new(:int64_t) metadata = ::FFI::MemoryPointer.new(:pointer) check_status api[:SessionGetModelMetadata].call(read_pointer, metadata) + + custom_metadata_map = {} + check_status = api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys) + num_keys.read(:int64_t).times do |i| + key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string + value = ::FFI::MemoryPointer.new(:string) + check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value) + custom_metadata_map[key] = value.read_pointer.read_string + end + check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description) check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain) check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name) check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name) check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version) api[:ReleaseModelMetadata].call(metadata.read_pointer) - # TODO add custom_metadata_map - # need a way to get keys - { + custom_metadata_map: custom_metadata_map, description: description.read_pointer.read_string, domain: domain.read_pointer.read_string, graph_name: graph_name.read_pointer.read_string, producer_name: producer_name.read_pointer.read_string, version: version.read(:int64_t)