Sha256: 8aeb7990891b35780ea85ccf8aa82acaf8b40400c56c65640c2dae5240e16375
Contents?: true
Size: 558 Bytes
Versions: 4
Compression:
Stored size: 558 Bytes
Contents
require "dnn"
require "dnn/image"
require_relative "convnet8"

# Lazily builds the shared model and caches it in the global $model.
#
# Does nothing (returns nil) when $model is already set. Otherwise it creates
# a ConvNet for input shape [28, 28, 1], runs one prediction on an all-zero
# input, loads parameters from "trained_mnist_params.marshal" (resolved
# against the current working directory) and only then stores the model in
# $model.
#
# The model is kept in a local until load_params has returned, so a failed
# load (e.g. a missing parameter file) leaves $model unset and the next call
# retries instead of silently reusing a network without trained parameters.
#
# Returns whatever load_params returns, or nil if the model was already cached.
# Any exception raised by ConvNet.create, predict1 or load_params propagates.
def load_model
  return if $model

  model = ConvNet.create([28, 28, 1])
  # NOTE(review): presumably this dummy forward pass forces the layers to be
  # built so that load_params has shapes to fill in -- confirm against ruby-dnn.
  model.predict1(Numo::SFloat.zeros(28, 28, 1))
  result = model.load_params("trained_mnist_params.marshal")
  # Publish the model only after its parameters were loaded successfully.
  $model = model
  result
end

# Runs the cached model on one image and returns the scaled output values.
#
# img    - binary image data handed to DNN::Image.from_binary with the
#          DNN::Image::RGBA format (presumably raw RGBA pixel bytes -- verify
#          against the caller).
# width  - image width, passed to from_binary after height.
# height - image height, passed to from_binary before width.
#
# The image is converted to RGB, then to gray scale, cast to Numo::SFloat and
# divided by 255 before being fed to $model.predict1.
# NOTE(review): the model is created for [28, 28, 1] input, so width and
# height are presumably expected to be 28 -- this method does not check it.
#
# Returns an Array with each model output rounded to 4 decimal places and
# multiplied by 100 (presumably a percentage per class -- confirm).
def mnist_predict(img, width, height)
  load_model

  rgba = DNN::Image.from_binary(img, height, width, DNN::Image::RGBA)
  rgb = DNN::Image.to_rgb(rgba)
  gray = DNN::Image.to_gray_scale(rgb)
  x = Numo::SFloat.cast(gray) / 255
  out = $model.predict1(x)
  out.to_a.map { |v| v.round(4) * 100 }
end
Version data entries
4 entries across 4 versions & 1 rubygems