examples/mnist.rb in ruby_brain-0.1.2 vs examples/mnist.rb in ruby_brain-0.1.3
- old
+ new
@@ -1,23 +1,41 @@
require 'ruby_brain'
require 'ruby_brain/dataset/mnist/data'
-# NUM_TEST_DATA = 50000
+NUM_TRAIN_DATA = 5000
dataset = RubyBrain::DataSet::Mnist::data
-NUM_TRAIN_DATA = 5000
+training_dataset = dataset.first
+test_dataset = dataset.last
-training_input = dataset[:input][0..(NUM_TRAIN_DATA-1)]
-training_supervisor = dataset[:output][0..(NUM_TRAIN_DATA-1)]
+puts "[training data info]"
+puts " [in]"
+puts "#{training_dataset[:input].size} samples (use first #{NUM_TRAIN_DATA} for training)"
+puts "#{training_dataset[:input].first.size} features"
+puts " [out]"
+puts "#{training_dataset[:output].size} samples"
+puts "#{training_dataset[:output].first.size} features"
-# test_input = dataset[:input][NUM_TRAIN_DATA..(NUM_TRAIN_DATA+NUM_TEST_DATA-1)]
-# test_supervisor = dataset[:output][NUM_TRAIN_DATA..(NUM_TRAIN_DATA+NUM_TEST_DATA-1)]
-test_input = dataset[:input][NUM_TRAIN_DATA..-1]
-test_supervisor = dataset[:output][NUM_TRAIN_DATA..-1]
+puts "[test data info]"
+puts " [in]"
+puts "#{test_dataset[:input].size} samples"
+puts "#{test_dataset[:input].first.size} features"
+puts " [out]"
+puts "#{test_dataset[:output].size} samples"
+puts "#{test_dataset[:output].first.size} features"
-network = RubyBrain::Network.new([dataset[:input].first.size, 50, dataset[:output].first.size])
+# training_input = training_dataset[:input]
+# training_supervisor = training_dataset[:output]
+
+training_input = training_dataset[:input][0..(NUM_TRAIN_DATA-1)]
+training_supervisor = training_dataset[:output][0..(NUM_TRAIN_DATA-1)]
+
+test_input = test_dataset[:input]
+test_supervisor = test_dataset[:output]
+
+network = RubyBrain::Network.new([training_input.first.size, 50, training_supervisor.first.size])
network.learning_rate = 0.7
network.init_network
### You can load weights from file in this timing if you want.
# network.load_weights_from_yaml_file(File.dirname(__FILE__) + '/../best_weights_1469044985.yml')
@@ -26,10 +44,11 @@
network.learn(training_input, training_supervisor, max_training_count=100, tolerance=0.0004, monitoring_channels=[:best_params_training])
### You can save weights into a yml file if you want.
# network.dump_weights_to_yaml('path/to/weights.yml.file')
+network.dump_weights_to_yaml('./weights_xxx.yml')
class Array
def argmax
max_i, max_val = 0, self.first
@@ -69,10 +88,10 @@
# # initial_weights_file: 'weights_3_30_10_1429166740.yml',
# # initial_weights_file: 'best_weights_1429544001.yml',
# monitoring_channels: [:best_params_training]
# }
-# RubyBrain::Trainer.normal_learning([dataset[:input].first.size, 50, dataset[:output].first.size],
+# RubyBrain::Trainer.normal_learning([training_dataset[:input].first.size, 50, training_dataset[:output].first.size],
# training_input, training_supervisor,
# training_option)