lib/dnn/cifar10.rb in ruby-dnn-0.10.1 vs lib/dnn/cifar10.rb in ruby-dnn-0.10.2
- old
+ new
@@ -1,51 +1,51 @@
-require "zlib"
-require "archive/tar/minitar"
-require_relative "../../ext/cifar_loader/cifar_loader"
-require_relative "downloader"
-
-URL_CIFAR10 = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
-DIR_CIFAR10 = "cifar-10-batches-bin"
-
-module DNN
- module CIFAR10
- class DNN_CIFAR10_LoadError < DNN_Error; end
-
- def self.downloads
- return if Dir.exist?(__dir__ + "/downloads/" + DIR_CIFAR10)
- Downloader.download(URL_CIFAR10)
- cifar10_binary_file_name = __dir__ + "/downloads/" + URL_CIFAR10.match(%r`.+/(.+)`)[1]
- begin
- Zlib::GzipReader.open(cifar10_binary_file_name) do |gz|
- Archive::Tar::Minitar::unpack(gz, __dir__ + "/downloads")
- end
- ensure
- File.unlink(cifar10_binary_file_name)
- end
- end
-
- def self.load_train
- downloads
- bin = ""
- (1..5).each do |i|
- fname = __dir__ + "/downloads/#{DIR_CIFAR10}/data_batch_#{i}.bin"
- raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
- bin << File.binread(fname)
- end
- x_bin, y_bin = CIFAR10.load_binary(bin, 50000)
- x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
- y_train = Numo::UInt8.from_binary(y_bin)
- [x_train, y_train]
- end
-
- def self.load_test
- downloads
- fname = __dir__ + "/downloads/#{DIR_CIFAR10}/test_batch.bin"
- raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
- bin = File.binread(fname)
- x_bin, y_bin = CIFAR10.load_binary(bin, 10000)
- x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
- y_test = Numo::UInt8.from_binary(y_bin)
- [x_test, y_test]
- end
- end
-end
+require "zlib"
+require "archive/tar/minitar"
+require_relative "../../ext/cifar_loader/cifar_loader"
+require_relative "downloader"
+
+URL_CIFAR10 = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
+DIR_CIFAR10 = "cifar-10-batches-bin"
+
+module DNN
+ module CIFAR10
+ class DNN_CIFAR10_LoadError < DNN_Error; end
+
+ def self.downloads
+ return if Dir.exist?(__dir__ + "/downloads/" + DIR_CIFAR10)
+ Downloader.download(URL_CIFAR10)
+ cifar10_binary_file_name = __dir__ + "/downloads/" + URL_CIFAR10.match(%r`.+/(.+)`)[1]
+ begin
+ Zlib::GzipReader.open(cifar10_binary_file_name) do |gz|
+ Archive::Tar::Minitar::unpack(gz, __dir__ + "/downloads")
+ end
+ ensure
+ File.unlink(cifar10_binary_file_name)
+ end
+ end
+
+ def self.load_train
+ downloads
+ bin = ""
+ (1..5).each do |i|
+ fname = __dir__ + "/downloads/#{DIR_CIFAR10}/data_batch_#{i}.bin"
+ raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
+ bin << File.binread(fname)
+ end
+ x_bin, y_bin = CIFAR10.load_binary(bin, 50000)
+ x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
+ y_train = Numo::UInt8.from_binary(y_bin)
+ [x_train, y_train]
+ end
+
+ def self.load_test
+ downloads
+ fname = __dir__ + "/downloads/#{DIR_CIFAR10}/test_batch.bin"
+ raise DNN_CIFAR10_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
+ bin = File.binread(fname)
+ x_bin, y_bin = CIFAR10.load_binary(bin, 10000)
+ x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
+ y_test = Numo::UInt8.from_binary(y_bin)
+ [x_test, y_test]
+ end
+ end
+end