lib/onnxruntime/inference_session.rb in onnxruntime-0.2.3 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.3.0

- old
+ new

@@ -29,18 +29,26 @@ check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath # session @session = ::FFI::MemoryPointer.new(:pointer) - path_or_bytes = path_or_bytes.to_str + from_memory = + if path_or_bytes.respond_to?(:read) + path_or_bytes = path_or_bytes.read + true + else + path_or_bytes = path_or_bytes.to_str + path_or_bytes.encoding == Encoding::BINARY + end # fix for Windows "File doesn't exist" - if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY + if Gem.win_platform? && !from_memory path_or_bytes = File.binread(path_or_bytes) + from_memory = true end - if path_or_bytes.encoding == Encoding::BINARY + if from_memory check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session) else check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session) end @@ -98,10 +106,44 @@ output_names.size.times.map do |i| create_from_onnx_value(output_tensor[i].read_pointer) end end + def modelmeta + description = ::FFI::MemoryPointer.new(:string) + domain = ::FFI::MemoryPointer.new(:string) + graph_name = ::FFI::MemoryPointer.new(:string) + producer_name = ::FFI::MemoryPointer.new(:string) + version = ::FFI::MemoryPointer.new(:int64_t) + + metadata = ::FFI::MemoryPointer.new(:pointer) + check_status api[:SessionGetModelMetadata].call(read_pointer, metadata) + check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description) + check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain) + check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name) + check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name) + check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version) + api[:ReleaseModelMetadata].call(metadata.read_pointer) + + # TODO add custom_metadata_map + # need a way to get keys + + { + description: description.read_pointer.read_string, + domain: domain.read_pointer.read_string, + graph_name: graph_name.read_pointer.read_string, + producer_name: producer_name.read_pointer.read_string, + version: version.read(:int64_t) + } + end + + def end_profiling + out = ::FFI::MemoryPointer.new(:string) + check_status api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, out) + out.read_pointer.read_string + end + private def create_input_tensor(input_feed) allocator_info = ::FFI::MemoryPointer.new(:pointer) check_status = api[:CreateCpuMemoryInfo].call(1, 0, allocator_info) @@ -259,18 +301,35 @@ { type: "tensor(#{FFI::TensorElementDataType[type]})", shape: shape } when :sequence - # TODO show nested + sequence_type_info = ::FFI::MemoryPointer.new(:pointer) + check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info) + nested_type_info = ::FFI::MemoryPointer.new(:pointer) + check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info) + v = node_info(nested_type_info)[:type] + { - type: "seq", + type: "seq(#{v})", shape: [] } when :map - # TODO show nested + map_type_info = ::FFI::MemoryPointer.new(:pointer) + check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info) + + # key + key_type = ::FFI::MemoryPointer.new(:int) + check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type) + k = FFI::TensorElementDataType[key_type.read_int] + + # value + value_type_info = ::FFI::MemoryPointer.new(:pointer) + check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info) + v = node_info(value_type_info)[:type] + { - type: "map", + type: "map(#{k},#{v})", shape: [] } else unsupported_type("ONNX", type) end