lib/onnxruntime/inference_session.rb in onnxruntime-0.2.1 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.2.2
- old
+ new
@@ -1,13 +1,35 @@
module OnnxRuntime
class InferenceSession
attr_reader :inputs, :outputs
- def initialize(path_or_bytes)
+ def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
# session options
session_options = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CreateSessionOptions].call(session_options)
+ check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
+ check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
+ check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
+ if execution_mode
+ mode =
+ case execution_mode
+ when :sequential
+ 0
+ when :parallel
+ 1
+ else
+ raise ArgumentError, "Invalid execution mode"
+ end
+ check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
+ end
+ check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level
+ check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
+ check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
+ check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
+ check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
+ check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
+ check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath
# session
@session = ::FFI::MemoryPointer.new(:pointer)
path_or_bytes = path_or_bytes.to_str
@@ -51,19 +73,28 @@
check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
@outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
end
end
- def run(output_names, input_feed)
+ # TODO support logid
+ def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
input_tensor = create_input_tensor(input_feed)
output_names ||= @outputs.map { |v| v[:name] }
output_tensor = ::FFI::MemoryPointer.new(:pointer, outputs.size)
input_node_names = create_node_names(input_feed.keys.map(&:to_s))
output_node_names = create_node_names(output_names.map(&:to_s))
- # TODO support run options
- check_status api[:Run].call(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
+
+ # run options
+ run_options = ::FFI::MemoryPointer.new(:pointer)
+ check_status api[:CreateRunOptions].call(run_options)
+ check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
+ check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
+ check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
+ check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate
+
+ check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)
output_names.size.times.map do |i|
create_from_onnx_value(output_tensor[i].read_pointer)
end
end