lib/onnxruntime/inference_session.rb in onnxruntime-0.2.1 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.2.2

- old
+ new

@@ -1,13 +1,35 @@ module OnnxRuntime class InferenceSession attr_reader :inputs, :outputs - def initialize(path_or_bytes) + def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil) # session options session_options = ::FFI::MemoryPointer.new(:pointer) check_status api[:CreateSessionOptions].call(session_options) + check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena + check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern + check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling + if execution_mode + mode = + case execution_mode + when :sequential + 0 + when :parallel + 1 + else + raise ArgumentError, "Invalid execution mode" + end + check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode) + end + check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level + check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads + check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads + check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level + check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level + check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid + check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath # session @session = ::FFI::MemoryPointer.new(:pointer) path_or_bytes = path_or_bytes.to_str @@ -51,19 +73,28 @@ check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo) @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) end end - def run(output_names, input_feed) + # TODO support logid + def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) input_tensor = create_input_tensor(input_feed) output_names ||= @outputs.map { |v| v[:name] } output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size) input_node_names = create_node_names(input_feed.keys.map(&:to_s)) output_node_names = create_node_names(output_names.map(&:to_s)) - # TODO support run options - check_status api[:Run].call(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor) + + # run options + run_options = ::FFI::MemoryPointer.new(:pointer) + check_status api[:CreateRunOptions].call(run_options) + check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level + check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level + check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid + check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate + + check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor) output_names.size.times.map do |i| create_from_onnx_value(output_tensor[i].read_pointer) end end