Sha256: 6c0a92367fa10b8bccb1e32aaf4816a7980e03c1cf6a1c57553d441cd6496a91

Contents?: true

Size: 1.96 KB

Versions: 1

Compression:

Stored size: 1.96 KB

Contents

# coding: utf-8
require 'ruby_brain'
require 'ruby_brain/dataset/mnist/data'


dataset = RubyBrain::DataSet::Mnist::data
training_dataset = dataset.first
test_dataset = dataset.last


training_dataset.keys # => [:input, :output]
test_dataset.keys     # => [:input, :output]

training_dataset[:input].size       # => 60000
training_dataset[:input].first.size # => 784

training_dataset[:output].size       # => 60000
training_dataset[:output].first.size # => 10

test_dataset[:input].size       # => 10000
test_dataset[:input].first.size # => 784

test_dataset[:output].size       # => 10000
test_dataset[:output].first.size # => 10


# use first 5000 pictures for training
NUM_TRAIN_DATA = 5000
training_input = training_dataset[:input][0..(NUM_TRAIN_DATA-1)]
training_supervisor = training_dataset[:output][0..(NUM_TRAIN_DATA-1)]

# use all pictures within test_dataset
test_input = test_dataset[:input]
test_supervisor = test_dataset[:output]

# network structure [784, 50, 10]
network = RubyBrain::Network.new([training_input.first.size, 50, training_supervisor.first.size])
# learning rate is 0.7
network.learning_rate = 0.7
# initialize network
network.init_network

network.learn(training_input, training_supervisor, max_training_count=100, tolerance=0.0004, monitoring_channels=[:best_params_training])

### turn on this snippet to print pictures as ascii art. 
# 
# test_input.each_with_index do |input, i|
#   input.each_with_index do |e, j|
#     print(e > 0.3 ? 'x' : ' ')
#     puts if (j % 28) == 0
#   end
#   puts
#   supervisor_label = test_supervisor[i].index(test_supervisor[i].max)
#   predicated_output = network.get_forward_outputs(input)
#   predicated_label = predicated_output.index(predicated_output.max)
#   puts "test_supervisor: #{supervisor_label}"
#   puts "predicate: #{predicated_label}"
#   results << (supervisor_label == predicated_label)
#   puts "------------------------------------------------------------"
# end

puts "accuracy: #{results.count(true).to_f/results.size}"

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
ruby_brain-0.1.4 examples/mnist_standalone.rb