lib/ruby_brain/dataset/mnist/data.rb in ruby_brain-0.1.2 vs lib/ruby_brain/dataset/mnist/data.rb in ruby_brain-0.1.3

- old
+ new

@@ -12,23 +12,37 @@ end def data train_images_path = Dir.pwd + '/train-images-idx3-ubyte.gz' train_labels_path = Dir.pwd + '/train-labels-idx1-ubyte.gz' - + test_images_path = Dir.pwd + '/t10k-images-idx3-ubyte.gz' + test_labels_path = Dir.pwd + '/t10k-labels-idx1-ubyte.gz' + unless File.exist?(train_images_path) puts 'downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz ...' download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', train_images_path) end unless File.exist?(train_labels_path) - puts 'downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz' + puts 'downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz ...' download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', train_labels_path) end + unless File.exist?(test_images_path) + puts 'downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz ...' + download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', test_images_path) + end + + unless File.exist?(test_labels_path) + puts 'downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz ...' + download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', test_labels_path) + end + train_images = Mnist.load_images(train_images_path) train_labels = Mnist.load_labels(train_labels_path) + test_images = Mnist.load_images(test_images_path) + test_labels = Mnist.load_labels(test_labels_path) input_training_set = train_images[2].map do |image| image.unpack('C*').map {|e| e / 255.0} end @@ -36,10 +50,20 @@ one_hot_vector = Array.new(10, 0) one_hot_vector[label] = 1 one_hot_vector end + input_test_set = test_images[2].map do |image| + image.unpack('C*').map {|e| e / 255.0} + end + + output_test_set = test_labels.map do |label| + one_hot_vector = Array.new(10, 0) + one_hot_vector[label] = 1 + one_hot_vector + end + # puts train_images[0].class # puts train_images[1].class # puts train_images[2].size # puts train_images[2][0].size # puts train_images[2][59999][783].class @@ -53,10 +77,10 @@ # end # puts # puts train_labels[j] # end - {input: input_training_set, output: output_training_set} + [{input: input_training_set, output: output_training_set}, {input: input_test_set, output: output_test_set}] end module_function :data, :download_file end