Sha256: 5863d540a8f5849c34f790579a672c48da66a7ae2e36f5771bc114c4a6d848d3
Contents?: true
Size: 1.64 KB
Versions: 2
Compression:
Stored size: 1.64 KB
Contents
# Example: score a RubyBrain network, with weights loaded from a YAML file,
# against the MNIST data set. For every sample it prints an ASCII rendering of
# the input, the supervisor label and the network's prediction, then reports
# the overall accuracy.
require 'ruby_brain'
require 'ruby_brain/dataset/mnist/data'

# Not read by the evaluation below. For a quicker run, slice both arrays with
# `.first(NUM_TEST_DATA)`.
NUM_TEST_DATA = 500

dataset = RubyBrain::DataSet::Mnist.data

# The last element of the data set is used as the test split, in full.
# NOTE(review): earlier revisions sliced these arrays with NUM_TRAIN_DATA,
# which this script never defines (NameError); the slice has been dropped.
test_dataset = dataset.last
test_input = test_dataset[:input]
test_supervisor = test_dataset[:output]

# One hidden layer of 50 units; input/output widths are taken from the data.
network = RubyBrain::Network.new([test_input.first.size, 50, test_supervisor.first.size])
# network.learning_rate = 0.7
network.init_network

### You can initialize the weights by loading them from a file if you want.
network.load_weights_from_yaml_file(File.join(File.dirname(__FILE__), '..', 'best_weights_1469999296.yml'))

class Array
  # Index of the largest element.
  #
  # Ties resolve to the earliest index because the comparison is strict.
  #
  # @return [Integer] index of the maximum, or 0 for an empty array
  def argmax
    each_with_index.inject(0) do |best_index, (value, index)|
      value > self[best_index] ? index : best_index
    end
  end
end

# Rendering parameters for the ASCII preview of each input.
pixels_per_row = 28 # inputs are drawn as rows of 28 values
ink_threshold = 0.3 # values above this are drawn as 'x'

results = test_input.zip(test_supervisor).map do |input, supervisor|
  ### Shows the test input, followed by its label and the predicted label.
  # Emit one complete row per line; the previous `j % 28 == 0` check broke the
  # line after the first pixel and shifted every following row by one column.
  input.each_slice(pixels_per_row) do |row|
    puts row.map { |pixel| pixel > ink_threshold ? 'x' : ' ' }.join
  end

  supervisor_label = supervisor.argmax
  predicated_label = network.get_forward_outputs(input).argmax
  puts "test_supervisor: #{supervisor_label}"
  puts "predicate: #{predicated_label}"
  puts "------------------------------------------------------------"

  supervisor_label == predicated_label
end

puts "accuracy: #{results.count(true).to_f / results.size}"
Version data entries
2 entries across 2 versions & 1 rubygems
Version | Path |
---|---|
ruby_brain-0.1.4 | examples/mnist2.rb |
ruby_brain-0.1.3 | examples/mnist2.rb |