Sha256: 8aeb7990891b35780ea85ccf8aa82acaf8b40400c56c65640c2dae5240e16375

Contents?: true

Size: 558 Bytes

Versions: 4

Compression:

Stored size: 558 Bytes

Contents

require "dnn"
require "dnn/image"
require_relative "convnet8"

# Lazily builds the MNIST ConvNet and loads its trained parameters, caching
# the ready model in the global $model so later calls reuse it.
#
# The model is published to $model only after load_params succeeds. If
# loading raises, $model stays unset and the next call retries, instead of
# caching (and silently serving) a network with untrained weights.
#
# NOTE(review): the parameter file is a ".marshal" file; if load_params
# deserializes it with Marshal, only ever point this at trusted files.
#
# @param params_path [String] path of the trained-parameter file
#   (defaults to "trained_mnist_params.marshal").
# @return [nil, Object] nil when a model is already cached; otherwise the
#   newly loaded model (also stored in $model).
def load_model(params_path = "trained_mnist_params.marshal")
  return if $model

  model = ConvNet.create([28, 28, 1])
  # One dummy forward pass on a 28x28x1 zero input before loading params;
  # presumably required so the layers are built before their weights are
  # assigned — TODO confirm against DNN's load_params requirements.
  model.predict1(Numo::SFloat.zeros(28, 28, 1))
  model.load_params(params_path)
  $model = model
end

# Classifies a raw RGBA image with the cached MNIST model.
#
# The image is decoded as RGBA, converted to RGB and then gray scale,
# divided by 255, and fed to the model. Each output score is returned as a
# percentage rounded to two decimal places.
#
# @param img [String] raw image data, decoded by DNN::Image.from_binary
#   using the RGBA format
# @param width [Integer] image width in pixels
# @param height [Integer] image height in pixels
# @return [Array<Float>] model output scores * 100, rounded to 2 decimals
#   (presumably one entry per digit class — TODO confirm with ConvNet)
def mnist_predict(img, width, height)
  load_model
  rgba = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
  gray = DNN::Image.to_gray_scale(DNN::Image.to_rgb(rgba))
  # Presumably 8-bit pixels, so this normalizes them into [0, 1].
  x = Numo::SFloat.cast(gray) / 255
  # Scale first, then round: rounding to 4 places and then multiplying by
  # 100 reintroduces float noise (e.g. 0.07 * 100 == 7.000000000000001).
  $model.predict1(x).to_a.map { |v| (v * 100).round(2) }
end

Version data entries

4 entries across 4 versions & 1 rubygem

Version Path
ruby-dnn-1.3.0 examples/judge-number/mnist_predict.rb
ruby-dnn-1.2.3 examples/judge-number/mnist_predict.rb
ruby-dnn-1.2.2 examples/judge-number/mnist_predict.rb
ruby-dnn-1.2.1 examples/judge-number/mnist_predict.rb