examples/mnist2.rb in ruby_brain-0.1.2 vs examples/mnist2.rb in ruby_brain-0.1.3
- old
+ new
@@ -1,35 +1,28 @@
require 'ruby_brain'
require 'ruby_brain/dataset/mnist/data'
-NUM_TRAIN_DATA = 5000
NUM_TEST_DATA = 500
dataset = RubyBrain::DataSet::Mnist::data
-training_input = dataset[:input][0..(NUM_TRAIN_DATA-1)]
-training_supervisor = dataset[:output][0..(NUM_TRAIN_DATA-1)]
+test_dataset = dataset.last
+test_input = test_dataset[:input]
+test_supervisor = test_dataset[:output]
# test_input = dataset[:input][NUM_TRAIN_DATA..(NUM_TRAIN_DATA+NUM_TEST_DATA-1)]
# test_supervisor = dataset[:output][NUM_TRAIN_DATA..(NUM_TRAIN_DATA+NUM_TEST_DATA-1)]
-test_input = dataset[:input][NUM_TRAIN_DATA..-1]
-test_supervisor = dataset[:output][NUM_TRAIN_DATA..-1]
+test_input = test_dataset[:input][NUM_TRAIN_DATA..-1]
+test_supervisor = test_dataset[:output][NUM_TRAIN_DATA..-1]
-network = RubyBrain::Network.new([dataset[:input].first.size, 50, dataset[:output].first.size])
+network = RubyBrain::Network.new([test_input.first.size, 50, test_supervisor.first.size])
# network.learning_rate = 0.7
network.init_network
-network.load_weights_from_yaml_file(File.dirname(__FILE__) + '/../best_weights_1469044985.yml')
### You can initializes weights by loading weights from file if you want.
-# network.load_weights_from_yaml_file('path/to/weights.yml.file')
+network.load_weights_from_yaml_file(File.dirname(__FILE__) + '/../best_weights_1469999296.yml')
-# network.learn(training_input, training_supervisor, max_training_count=100, tolerance=0.0004, monitoring_channels=[:best_params_training])
-
-### You can save weights into a yml file if you want.
-# network.dump_weights_to_yaml('path/to/weights.yml.file')
-
-
class Array
def argmax
max_i = 0
max_val = self[max_i]
self.each_with_index do |v, i|
@@ -58,25 +51,6 @@
results << (supervisor_label == predicated_label)
puts "------------------------------------------------------------"
end
puts "accuracy: #{results.count(true).to_f/results.size}"
-
-
-
-### you can do above procedure simply by using Trainer
-
-# training_option = {
-# learning_rate: 0.5,
-# max_training_count: 50,
-# tolerance: 0.0004,
-# # initial_weights_file: 'weights_3_30_10_1429166740.yml',
-# # initial_weights_file: 'best_weights_1429544001.yml',
-# monitoring_channels: [:best_params_training]
-# }
-
-# RubyBrain::Trainer.normal_learning([dataset[:input].first.size, 50, dataset[:output].first.size],
-# training_input, training_supervisor,
-# training_option)
-
-