lib/onnxruntime/inference_session.rb in onnxruntime-0.3.1 vs lib/onnxruntime/inference_session.rb in onnxruntime-0.3.2

- old
+ new

@@ -8,22 +8,21 @@ check_status api[:CreateSessionOptions].call(session_options) check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling if execution_mode - mode = - case execution_mode - when :sequential - 0 - when :parallel - 1 - else - raise ArgumentError, "Invalid execution mode" - end + execution_modes = {sequential: 0, parallel: 1} + mode = execution_modes[execution_mode] + raise ArgumentError, "Invalid execution mode" unless mode check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode) end - check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level + if graph_optimization_level + optimization_levels = {none: 0, basic: 1, extended: 2, all: 99} + # TODO raise error in 0.4.0 + level = optimization_levels[graph_optimization_level] || graph_optimization_level + check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, level) + end check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid @@ -289,10 +288,10 @@ @session.read_pointer end def check_status(status) unless status.null? - message = api[:GetErrorMessage].call(status) + message = api[:GetErrorMessage].call(status).read_string api[:ReleaseStatus].call(status) raise OnnxRuntime::Error, message end end