lib/ruby_brain/dataset/mnist/data.rb in ruby_brain-0.1.2 vs lib/ruby_brain/dataset/mnist/data.rb in ruby_brain-0.1.3
- old
+ new
@@ -12,23 +12,37 @@
end
def data
train_images_path = Dir.pwd + '/train-images-idx3-ubyte.gz'
train_labels_path = Dir.pwd + '/train-labels-idx1-ubyte.gz'
-
+ test_images_path = Dir.pwd + '/t10k-images-idx3-ubyte.gz'
+ test_labels_path = Dir.pwd + '/t10k-labels-idx1-ubyte.gz'
+
unless File.exist?(train_images_path)
puts 'downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz ...'
download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', train_images_path)
end
unless File.exist?(train_labels_path)
- puts 'downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
+ puts 'downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz ...'
download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', train_labels_path)
end
+ unless File.exist?(test_images_path)
+ puts 'downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz ...'
+ download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', test_images_path)
+ end
+
+ unless File.exist?(test_labels_path)
+ puts 'downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz ...'
+ download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', test_labels_path)
+ end
+
train_images = Mnist.load_images(train_images_path)
train_labels = Mnist.load_labels(train_labels_path)
+ test_images = Mnist.load_images(test_images_path)
+ test_labels = Mnist.load_labels(test_labels_path)
input_training_set = train_images[2].map do |image|
image.unpack('C*').map {|e| e / 255.0}
end
@@ -36,10 +50,20 @@
one_hot_vector = Array.new(10, 0)
one_hot_vector[label] = 1
one_hot_vector
end
+ input_test_set = test_images[2].map do |image|
+ image.unpack('C*').map {|e| e / 255.0}
+ end
+
+ output_test_set = test_labels.map do |label|
+ one_hot_vector = Array.new(10, 0)
+ one_hot_vector[label] = 1
+ one_hot_vector
+ end
+
# puts train_images[0].class
# puts train_images[1].class
# puts train_images[2].size
# puts train_images[2][0].size
# puts train_images[2][59999][783].class
@@ -53,10 +77,10 @@
# end
# puts
# puts train_labels[j]
# end
- {input: input_training_set, output: output_training_set}
+ [{input: input_training_set, output: output_training_set}, {input: input_test_set, output: output_test_set}]
end
module_function :data, :download_file
end