lib/dnn/lib/mnist.rb in ruby-dnn-0.1.0 vs lib/dnn/lib/mnist.rb in ruby-dnn-0.1.1

- old
+ new

@@ -1,46 +1,75 @@
+require "open-uri"
 require "zlib"
+require "dnn/core/error"
 require "dnn/ext/mnist/mnist_ext"

 module DNN
   module MNIST
-    class MNISTLoadError < StandardError
-    end
+    class DNN_MNIST_LoadError < DNN_Error; end

+    class DNN_MNIST_DownloadError < DNN_Error; end
+
+    URL_TRAIN_IMAGES = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
+    URL_TRAIN_LABELS = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
+    URL_TEST_IMAGES = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
+    URL_TEST_LABELS = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
+
     private_class_method :_mnist_load_images
     private_class_method :_mnist_load_labels

+    def self.downloads
+      return if Dir.exist?(mnist_dir)
+      Dir.mkdir(mnist_dir)
+      puts "Now downloading..."
+      download(URL_TRAIN_IMAGES)
+      download(URL_TRAIN_LABELS)
+      download(URL_TEST_IMAGES)
+      download(URL_TEST_LABELS)
+      puts "The download has ended."
+    end
+
     def self.load_train
-      train_images_file_name = "mnist/train-images-idx3-ubyte.gz"
-      train_labels_file_name = "mnist/train-labels-idx1-ubyte.gz"
+      downloads
+      train_images_file_name = url_to_file_name(URL_TRAIN_IMAGES)
+      train_labels_file_name = url_to_file_name(URL_TRAIN_LABELS)
       unless File.exist?(train_images_file_name)
-        raise MNISTLoadError.new(%`file "#{train_images_file_name}" is not found.`)
+        raise DNN_MNIST_LoadError.new(%`file "#{train_images_file_name}" is not found.`)
       end
       unless File.exist?(train_labels_file_name)
-        raise MNISTLoadError.new(%`file "#{train_labels_file_name}" is not found.`)
+        raise DNN_MNIST_LoadError.new(%`file "#{train_labels_file_name}" is not found.`)
       end
       images = load_images(train_images_file_name)
       labels = load_labels(train_labels_file_name)
       [images, labels]
     end

     def self.load_test
-      test_images_file_name = "mnist/t10k-images-idx3-ubyte.gz"
-      test_labels_file_name = "mnist/t10k-labels-idx1-ubyte.gz"
+      downloads
+      test_images_file_name = url_to_file_name(URL_TEST_IMAGES)
+      test_labels_file_name = url_to_file_name(URL_TEST_LABELS)
       unless File.exist?(test_images_file_name)
-        raise MNISTLoadError.new(%`file "#{train_images_file_name}" is not found.`)
+        raise DNN_MNIST_LoadError.new(%`file "#{train_images_file_name}" is not found.`)
       end
       unless File.exist?(test_labels_file_name)
-        raise MNISTLoadError.new(%`file "#{train_labels_file_name}" is not found.`)
+        raise DNN_MNIST_LoadError.new(%`file "#{train_labels_file_name}" is not found.`)
       end
       images = load_images(test_images_file_name)
       labels = load_labels(test_labels_file_name)
       [images, labels]
     end

     private_class_method

+    def self.download(url)
+      open(url, "rb") do |f|
+        File.binwrite(url_to_file_name(url), f.read)
+      end
+    rescue => ex
+      raise DNN_MNIST_DownloadError.new(ex.message)
+    end
+
     def self.load_images(file_name)
       images = nil
       Zlib::GzipReader.open(file_name) do |f|
         magic, num_images = f.read(8).unpack("N2")
         rows, cols = f.read(8).unpack("N2")
@@ -54,8 +83,16 @@
       Zlib::GzipReader.open(file_name) do |f|
         magic, num_labels = f.read(8).unpack("N2")
         labels = _mnist_load_labels(f.read, num_labels)
       end
       labels
+    end
+
+    def self.mnist_dir
+      __dir__ + "/mnist"
+    end
+
+    def self.url_to_file_name(url)
+      mnist_dir + "/" + url.match(%r`.+/(.+)$`)[1]
     end
   end
 end
\ No newline at end of file