Sha256: 82a1ee82725a77cf29da4fcfe109cbce40befaaf3759a36fee81a8503fb4ec64
Contents?: true
Size: 1.94 KB
Versions: 1
Compression:
Stored size: 1.94 KB
Contents
require "zlib"
require "dnn/ext/mnist/mnist_ext"

module DNN
  # Loads the MNIST dataset from gzipped IDX files located in the relative
  # directory "mnist/". Decoding of the raw image/label payloads is delegated
  # to _mnist_load_images / _mnist_load_labels, which are expected to be
  # defined on this module by the native "dnn/ext/mnist/mnist_ext" extension
  # required above (they must exist before private_class_method runs below).
  module MNIST
    # Raised when a dataset file is missing or is not a well-formed IDX file.
    class MNISTLoadError < StandardError
    end

    # IDX magic numbers (big-endian uint32 at offset 0): 0x00000803 marks a
    # 3-D unsigned-byte tensor (images), 0x00000801 a 1-D one (labels).
    IMAGES_MAGIC = 0x00000803
    LABELS_MAGIC = 0x00000801
    private_constant :IMAGES_MAGIC, :LABELS_MAGIC

    private_class_method :_mnist_load_images
    private_class_method :_mnist_load_labels

    # Loads the training split from mnist/train-{images-idx3,labels-idx1}-ubyte.gz.
    #
    # @return [Array] a two-element array [images, labels] whose element types
    #   are whatever the native loaders return.
    # @raise [MNISTLoadError] if either file is missing or has a bad header.
    def self.load_train
      load_split("mnist/train-images-idx3-ubyte.gz",
                 "mnist/train-labels-idx1-ubyte.gz")
    end

    # Loads the test split from mnist/t10k-{images-idx3,labels-idx1}-ubyte.gz.
    #
    # @return [Array] a two-element array [images, labels] whose element types
    #   are whatever the native loaders return.
    # @raise [MNISTLoadError] if either file is missing or has a bad header.
    def self.load_test
      load_split("mnist/t10k-images-idx3-ubyte.gz",
                 "mnist/t10k-labels-idx1-ubyte.gz")
    end

    # Shared implementation of load_train / load_test. Both files are checked
    # for existence before any decompression starts, so a missing label file
    # is reported without first decoding the (much larger) image file.
    def self.load_split(images_file_name, labels_file_name)
      ensure_file_exist(images_file_name)
      ensure_file_exist(labels_file_name)
      [load_images(images_file_name), load_labels(labels_file_name)]
    end

    # @raise [MNISTLoadError] unless +file_name+ exists.
    def self.ensure_file_exist(file_name)
      return if File.exist?(file_name)

      raise MNISTLoadError, %`file "#{file_name}" is not found.`
    end

    # Decodes a gzipped IDX3 image file.
    # Header: magic, image count, rows, cols (four big-endian uint32s).
    def self.load_images(file_name)
      Zlib::GzipReader.open(file_name) do |f|
        magic, num_images, rows, cols = read_header(f, file_name, 4)
        check_magic(magic, IMAGES_MAGIC, file_name)
        # The native loader takes (data, count, cols, rows); this argument
        # order is kept exactly as in the original implementation.
        _mnist_load_images(f.read, num_images, cols, rows)
      end
    end

    # Decodes a gzipped IDX1 label file.
    # Header: magic, label count (two big-endian uint32s).
    # NOTE: left public (unlike load_images) to stay compatible with any
    # external callers of the original public method.
    def self.load_labels(file_name)
      Zlib::GzipReader.open(file_name) do |f|
        magic, num_labels = read_header(f, file_name, 2)
        check_magic(magic, LABELS_MAGIC, file_name)
        _mnist_load_labels(f.read, num_labels)
      end
    end

    # Reads +count+ big-endian uint32 header fields from +io+.
    # @raise [MNISTLoadError] if the stream ends before the header is complete.
    def self.read_header(io, file_name, count)
      size = count * 4
      bytes = io.read(size)
      unless bytes && bytes.bytesize == size
        raise MNISTLoadError, %`file "#{file_name}" has a truncated header.`
      end

      bytes.unpack("N#{count}")
    end

    # @raise [MNISTLoadError] if +actual+ differs from the expected IDX magic.
    def self.check_magic(actual, expected, file_name)
      return if actual == expected

      raise MNISTLoadError,
            %`file "#{file_name}" has unexpected magic number #{actual} (expected #{expected}).`
    end

    private_class_method :load_split, :ensure_file_exist, :load_images,
                         :read_header, :check_magic
  end
end
Version data entries
1 entries across 1 versions & 1 rubygems
Version | Path |
---|---|
ruby-dnn-0.1.0 | lib/dnn/lib/mnist.rb |