lib/dnn/datasets/cifar100.rb in ruby-dnn-0.15.3 vs lib/dnn/datasets/cifar100.rb in ruby-dnn-0.16.0
- old
+ new
@@ -1,8 +1,7 @@
require "zlib"
require "archive/tar/minitar"
-require_relative "../../../ext/cifar_loader/cifar_loader"
require_relative "downloader"
URL_CIFAR100 = "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
DIR_CIFAR100 = "cifar-100-binary"
@@ -25,25 +24,27 @@
def self.load_train
downloads
bin = ""
fname = DOWNLOADS_PATH + "/downloads/#{DIR_CIFAR100}/train.bin"
- raise DNN_CIFAR100_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
+ raise DNN_CIFAR100_LoadError, %`file "#{fname}" is not found.` unless File.exist?(fname)
bin << File.binread(fname)
- x_bin, y_bin = CIFAR100.load_binary(bin, 50000)
- x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
- y_train = Numo::UInt8.from_binary(y_bin).reshape(50000, 2)
+ datas = Numo::UInt8.from_binary(bin).reshape(50000, 3074)
+ x_train = datas[true, 2...3074]
+ x_train = x_train.reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
+ y_train = datas[true, 0...2]
[x_train, y_train]
end
def self.load_test
downloads
fname = DOWNLOADS_PATH + "/downloads/#{DIR_CIFAR100}/test.bin"
- raise DNN_CIFAR100_LoadError.new(%`file "#{fname}" is not found.`) unless File.exist?(fname)
+ raise DNN_CIFAR100_LoadError, %`file "#{fname}" is not found.` unless File.exist?(fname)
bin = File.binread(fname)
- x_bin, y_bin = CIFAR100.load_binary(bin, 10000)
- x_test = Numo::UInt8.from_binary(x_bin).reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
- y_test = Numo::UInt8.from_binary(y_bin).reshape(10000, 2)
+ datas = Numo::UInt8.from_binary(bin).reshape(10000, 3074)
+ x_test = datas[true, 2...3074]
+ x_test = x_test.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
+ y_test = datas[true, 0...2]
[x_test, y_test]
end
end
end