Sha256: ccb89c5b61ff2108e356ce5d21ed7c47e2aca16607fc0dc39adc79147162eb3e

Size: 587 Bytes

Versions: 2

Contents

module OnnxRuntime
  class Model
    def initialize(path_or_bytes, **session_options)
      @session = InferenceSession.new(path_or_bytes, **session_options)
    end

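    def predict(input_feed, output_names: nil, **run_options)
      # A nil output_names asks the session for every model output
      predictions = @session.run(output_names, input_feed, **run_options)
      # Recover the output names in the same order so they pair with predictions
      output_names ||= outputs.map { |o| o[:name] }

      # Build a hash keyed by output name (symbol names become strings)
      result = {}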
      output_names.zip(predictions).each do |k, v|
        result[k.to_s] = v
      end
      result
    end

    def inputs
      @session.inputs
    end

    def outputs
      @session.outputs
    end
  end
end
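The pairing logic in `predict` can be exercised without the gem or a model file. The sketch below is a hypothetical illustration: `StubSession` stands in for `OnnxRuntime::InferenceSession` and only mimics the two methods `Model` relies on (`outputs` and `run`). The output names `"label"` and `"probabilities"`, the input name `"x"` and all values are made up for the example, and `predict_with` is a hypothetical helper that repeats `Model#predict`'s logic against an injected session.

```ruby
# Hypothetical stand-in for OnnxRuntime::InferenceSession; names and values
# are invented so the example runs without the onnxruntime gem.
class StubSession
  # Mimics the shape used by Model#outputs: an array of hashes with :name
  def outputs
    [{ name: "label" }, { name: "probabilities" }]
  end

  # Mimics InferenceSession#run: one value per requested output, in the
  # requested order; nil means "return every output"
  def run(output_names, _input_feed, **_run_options)
    values = { "label" => [1], "probabilities" => [[0.1, 0.9]] }
    names = output_names || outputs.map { |o| o[:name] }
    names.map { |n| values.fetch(n.to_s) }
  end
end

# Same steps as Model#predict, but taking the session as an argument
def predict_with(session, input_feed, output_names: nil, **run_options)
  predictions = session.run(output_names, input_feed, **run_options)
  output_names ||= session.outputs.map { |o| o[:name] }
  # Stringify keys so :label and "label" both end up as "label"
  output_names.zip(predictions).map { |name, value| [name.to_s, value] }.to_h
end

session = StubSession.new
feed = { "x" => [[5.1, 3.5, 1.4, 0.2]] }

p predict_with(session, feed)
p predict_with(session, feed, output_names: [:probabilities])
```

The first call returns both outputs keyed by name; the second returns only `"probabilities"`, with the symbol name converted to a string key. With the real gem the equivalent call is `OnnxRuntime::Model.new("model.onnx").predict(feed)`, where `"model.onnx"` is a placeholder path.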

Version data entries

2 entries across 2 versions of 1 rubygem

Version            Path
onnxruntime-0.2.3  lib/onnxruntime/model.rb
onnxruntime-0.2.2  lib/onnxruntime/model.rb