lib/mnist-learn.rb in mnist-learn-0.1.2 vs lib/mnist-learn.rb in mnist-learn-0.1.3
- old
+ new
@@ -1,205 +1,205 @@
-require "mnist-learn/version"
-require 'fileutils'
-require 'zlib'
-require 'net/http'
-require 'ostruct'
-
-module Mnist
- class Error < StandardError; end
-
- class LoadError < Error; end
-
- class InvalidMagic < LoadError; end
-
- class MnistReader
- def initialize(base_path, one_hot = false)
- @base_path = base_path
- @one_hot = one_hot
- end
-
- def train
- load_pair('train-images-idx3-ubyte', 'train-labels-idx1-ubyte')
- end
-
- def test
- load_pair('t10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
- end
-
- private
-
- def load_pair(images, labels)
- Loader.new(File.join(@base_path, images), File.join(@base_path, labels), @one_hot)
- end
- end
-
- class Loader
- IMAGE_FILE_MAGIC = 2051
- LABEL_FILE_MAGIC = 2049
-
- def initialize(filename_image, filename_label, one_hot)
- @filename_image = filename_image
- @filename_label = filename_label
- @one_hot = one_hot
- @index = 0
- end
-
- attr_reader :filename_image, :filename_label
-
- def load_images
- check_magic(input_images, IMAGE_FILE_MAGIC)
- @total_count = read_total_count(input_images)
- nrows, ncols = read_image_size(input_images)
- images = @total_count.times.map do
- read_image(nrows, ncols)
- end
- [nrows, ncols, images]
- end
-
- def load_labels
- check_magic(input_labels, LABEL_FILE_MAGIC)
- @total_count = read_total_count(input_labels)
- read_labels(input_labels, @total_count)
- end
-
- def images
- @all_images ||= load_images[2]
- end
-
- def labels
- @all_labels ||= (@one_hot ? load_labels.map { |label_data| one_hot_transform(label_data) } : load_labels)
- end
-
- def next(batch_size)
- if @index == 0
- @rows, @columns, @images = load_images
- @labels = load_labels
- end
- images = []
- labels = []
- batch_size.times.each do
- next if @index >= @total_count
- image_data = @images[@index]
- label_data = @labels[@index]
- image_data.map! { |b| b.to_f / 255.0 }
- @index += 1
- images << image_data
- labels << (@one_hot ? one_hot_transform(label_data) : label_data.to_f)
- end
- [images, labels]
- end
-
- def next_batch(batch_size, rnd: Random.new)
- @data_set ||= begin
- rows, columns, images = load_images
- labels = load_labels
- Array.new(images.size) do
- image_data = images[@index]
- label_data = labels[@index]
- image_data.map! { |b| b.to_f / 255.0 }
- @index += 1
- [image_data, (@one_hot ? one_hot_transform(label_data) : label_data.to_f)]
- end
- end
- @data_set.shuffle!(random: rnd)
- batch = @data_set[0...batch_size]
- [batch.map { |v| v[0]}, batch.map { |v| v[1]}]
- end
-
- private
-
- def one_hot_transform(label)
- arr = Array.new(10) { 0.0 }
- arr[label] = 1.0
- arr
- end
-
- def check_magic(input_file, expected_magic)
- actual_magic = read_magic(input_file)
- unless actual_magic == expected_magic
- raise InvalidMagic, "Expected #{expected_magic}, but #{actual_magic} is given"
- end
- end
-
- def read_uint8(input_file, n=1)
- input_file.read(n).unpack('C*')
- end
-
- def read_uint32(input_file, n=1)
- input_file.read(4 * n).unpack('N*')
- end
-
- def read_magic(input_file)
- read_uint32(input_file).first
- end
-
- def read_total_count(input_file)
- read_uint32(input_file).first
- end
-
- def read_image_size(input_file)
- read_uint32(input_file, 2)
- end
-
- alias read_labels read_uint8
-
- def read_image(nrows, ncols)
- input_images.read(nrows * ncols).unpack("C*")
- end
-
- def input_images
- @input_images ||= File.open(filename_image)
- end
-
- def input_labels
- @input_labels ||= File.open(filename_label)
- end
- end
-
- def self.load_images(filename)
- Loader.new(filename).load_images
- end
-
- def self.load_labels(filename)
- Loader.new(filename).load_labels
- end
-
- def self.read_data_sets(path, one_hot: false)
- unless Dir.exist?(path)
- FileUtils.mkdir_p path
- end
-
- base_url = "yann.lecun.com"
- filenames = [
- "train-images-idx3-ubyte.gz",
- "train-labels-idx1-ubyte.gz",
- "t10k-images-idx3-ubyte.gz",
- "t10k-labels-idx1-ubyte.gz"
- ]
- Net::HTTP.start(base_url) do |http|
- filenames.each do |name|
- unless File.exists?(File.join(path, name))
- f = File.open(File.join(path, name), "w")
- begin
- http.request_get('/exdb/mnist/' + name) do |resp|
- resp.read_body do |segment|
- f.write(segment)
- end
- end
- ensure
- f.close
- end
- end
- end
- end
-
- filenames.each do |name|
- next if File.exists?(File.join(path, File.basename(name, '.gz')))
- puts "extracting #{name} ..."
- Zlib::GzipReader.open(File.join(path, name)) do |zipfile|
- outfile = File.open(File.join(path, File.basename(name, '.gz')), 'w')
- outfile.write(zipfile.read)
- end
- end
- MnistReader.new(path, one_hot)
- end
-end
+require "mnist-learn/version"
+require 'fileutils'
+require 'zlib'
+require 'net/http'
+require 'ostruct'
+
+module Mnist
+ class Error < StandardError; end
+
+ class LoadError < Error; end
+
+ class InvalidMagic < LoadError; end
+
+ class MnistReader
+ def initialize(base_path, one_hot = false)
+ @base_path = base_path
+ @one_hot = one_hot
+ end
+
+ def train
+ load_pair('train-images-idx3-ubyte', 'train-labels-idx1-ubyte')
+ end
+
+ def test
+ load_pair('t10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
+ end
+
+ private
+
+ def load_pair(images, labels)
+ Loader.new(File.join(@base_path, images), File.join(@base_path, labels), @one_hot)
+ end
+ end
+
+ class Loader
+ IMAGE_FILE_MAGIC = 2051
+ LABEL_FILE_MAGIC = 2049
+
+ def initialize(filename_image, filename_label, one_hot)
+ @filename_image = filename_image
+ @filename_label = filename_label
+ @one_hot = one_hot
+ @index = 0
+ end
+
+ attr_reader :filename_image, :filename_label
+
+ def load_images
+ check_magic(input_images, IMAGE_FILE_MAGIC)
+ @total_count = read_total_count(input_images)
+ nrows, ncols = read_image_size(input_images)
+ images = @total_count.times.map do
+ read_image(nrows, ncols)
+ end
+ [nrows, ncols, images]
+ end
+
+ def load_labels
+ check_magic(input_labels, LABEL_FILE_MAGIC)
+ @total_count = read_total_count(input_labels)
+ read_labels(input_labels, @total_count)
+ end
+
+ def images
+ @all_images ||= load_images[2]
+ end
+
+ def labels
+ @all_labels ||= (@one_hot ? load_labels.map { |label_data| one_hot_transform(label_data) } : load_labels)
+ end
+
+ def next(batch_size)
+ if @index == 0
+ @rows, @columns, @images = load_images
+ @labels = load_labels
+ end
+ images = []
+ labels = []
+ batch_size.times.each do
+ next if @index >= @total_count
+ image_data = @images[@index]
+ label_data = @labels[@index]
+ image_data.map! { |b| b.to_f / 255.0 }
+ @index += 1
+ images << image_data
+ labels << (@one_hot ? one_hot_transform(label_data) : label_data.to_f)
+ end
+ [images, labels]
+ end
+
+ def next_batch(batch_size, rnd: Random.new)
+ @data_set ||= begin
+ rows, columns, images = load_images
+ labels = load_labels
+ Array.new(images.size) do
+ image_data = images[@index]
+ label_data = labels[@index]
+ image_data.map! { |b| b.to_f / 255.0 }
+ @index += 1
+ [image_data, (@one_hot ? one_hot_transform(label_data) : label_data.to_f)]
+ end
+ end
+ @data_set.shuffle!(random: rnd)
+ batch = @data_set[0...batch_size]
+ [batch.map { |v| v[0]}, batch.map { |v| v[1]}]
+ end
+
+ private
+
+ def one_hot_transform(label)
+ arr = Array.new(10) { 0.0 }
+ arr[label] = 1.0
+ arr
+ end
+
+ def check_magic(input_file, expected_magic)
+ actual_magic = read_magic(input_file)
+ unless actual_magic == expected_magic
+ raise InvalidMagic, "Expected #{expected_magic}, but #{actual_magic} is given"
+ end
+ end
+
+ def read_uint8(input_file, n=1)
+ input_file.read(n).unpack('C*')
+ end
+
+ def read_uint32(input_file, n=1)
+ input_file.read(4 * n).unpack('N*')
+ end
+
+ def read_magic(input_file)
+ read_uint32(input_file).first
+ end
+
+ def read_total_count(input_file)
+ read_uint32(input_file).first
+ end
+
+ def read_image_size(input_file)
+ read_uint32(input_file, 2)
+ end
+
+ alias read_labels read_uint8
+
+ def read_image(nrows, ncols)
+ input_images.read(nrows * ncols).unpack("C*")
+ end
+
+ def input_images
+ @input_images ||= File.open(filename_image)
+ end
+
+ def input_labels
+ @input_labels ||= File.open(filename_label)
+ end
+ end
+
+ def self.load_images(filename)
+ Loader.new(filename).load_images
+ end
+
+ def self.load_labels(filename)
+ Loader.new(filename).load_labels
+ end
+
+ def self.read_data_sets(path, one_hot: false)
+ unless Dir.exist?(path)
+ FileUtils.mkdir_p path
+ end
+
+ base_url = "yann.lecun.com"
+ filenames = [
+ "train-images-idx3-ubyte.gz",
+ "train-labels-idx1-ubyte.gz",
+ "t10k-images-idx3-ubyte.gz",
+ "t10k-labels-idx1-ubyte.gz"
+ ]
+ Net::HTTP.start(base_url) do |http|
+ filenames.each do |name|
+ unless File.exists?(File.join(path, name))
+ f = File.open(File.join(path, name), "wb")
+ begin
+ http.request_get('/exdb/mnist/' + name) do |resp|
+ resp.read_body do |segment|
+ f.write(segment)
+ end
+ end
+ ensure
+ f.close
+ end
+ end
+ end
+ end
+
+ filenames.each do |name|
+ next if File.exists?(File.join(path, File.basename(name, '.gz')))
+ puts "extracting #{name} ..."
+ Zlib::GzipReader.open(File.join(path, name)) do |zipfile|
+ outfile = File.open(File.join(path, File.basename(name, '.gz')), 'wb')
+ outfile.write(zipfile.read)
+ end
+ end
+ MnistReader.new(path, one_hot)
+ end
+end