lib/onnxruntime/inference_session.rb in onnxruntime-0.3.1 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.3.2
- old
+ new
@@ -8,22 +8,21 @@
check_status api[:CreateSessionOptions].call(session_options)
check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
if execution_mode
- mode =
- case execution_mode
- when :sequential
- 0
- when :parallel
- 1
- else
- raise ArgumentError, "Invalid execution mode"
- end
+ execution_modes = {sequential: 0, parallel: 1}
+ mode = execution_modes[execution_mode]
+ raise ArgumentError, "Invalid execution mode" unless mode
check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
end
- check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level
+ if graph_optimization_level
+ optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
+ # TODO raise error in 0.4.0
+ level = optimization_levels[graph_optimization_level] || graph_optimization_level
+ check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level)
+ end
check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
@@ -289,10 +288,10 @@
@session.read_pointer
end
def check_status(status)
unless status.null?
- message = api[:GetErrorMessage].call(status)
+ message = api[:GetErrorMessage].call(status).read_string
api[:ReleaseStatus].call(status)
raise OnnxRuntime::Error, message
end
end