lib/dnn/datasets/cifar100.rb in ruby-dnn-0.15.3 vs lib/dnn/datasets/cifar100.rb in ruby-dnn-0.16.0

- old
+ new

@@ -1,8 +1,7 @@ require "zlib" require "archive/tar/minitar" -require_relative "../../../ext/cifar_loader/cifar_loader" require_relative "downloader" URL_CIFAR100 = "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz" DIR_CIFAR100 = "cifar-100-binary" @@ -25,25 +24,27 @@ def self.load_train downloads bin = "" fname = DOWNLOADS_PATH + "/downloads/#{DIR_CIFAR100}/train.bin" - raise DNN_CIFAR100_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname) + raise DNN_CIFAR100_LoadError, %`file "#{fname}" is not found.` unless File.exist?(fname) bin << File.binread(fname) - x_bin, y_bin = CIFAR100.load_binary(bin, 50000) - x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone - y_train = Numo::UInt8.from_binary(y_bin).reshape(50000, 2) + datas = Numo::UInt8.from_binary(bin).reshape(50000, 3074) + x_train = datas[true, 2...3074] + x_train = x_train.reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone + y_train = datas[true, 0...2] [x_train, y_train] end def self.load_test downloads fname = DOWNLOADS_PATH + "/downloads/#{DIR_CIFAR100}/test.bin" - raise DNN_CIFAR100_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname) + raise DNN_CIFAR100_LoadError, %`file "#{fname}" is not found.` unless File.exist?(fname) bin = File.binread(fname) - x_bin, y_bin = CIFAR100.load_binary(bin, 10000) - x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone - y_test = Numo::UInt8.from_binary(y_bin).reshape(10000, 2) + datas = Numo::UInt8.from_binary(bin).reshape(10000, 3074) + x_test = datas[true, 2...3074] + x_test = x_test.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone + y_test = datas[true, 0...2] [x_test, y_test] end end end