lib/dnn/datasets/stl-10.rb in ruby-dnn-0.15.3 vs lib/dnn/datasets/stl-10.rb in ruby-dnn-0.16.0
- old
+ new
@@ -23,39 +23,37 @@
end
def self.load_train
downloads
x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/train_X.bin"
- raise DNN_STL10_LoadError.new(%`file "#{x_fname}" is not found.`) unless File.exist?(x_fname)
+ raise DNN_STL10_LoadError, %`file "#{x_fname}" is not found.` unless File.exist?(x_fname)
y_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/train_y.bin"
- raise DNN_STL10_LoadError.new(%`file "#{y_fname}" is not found.`) unless File.exist?(y_fname)
+ raise DNN_STL10_LoadError, %`file "#{y_fname}" is not found.` unless File.exist?(y_fname)
x_bin = File.binread(x_fname)
y_bin = File.binread(y_fname)
x_train = Numo::UInt8.from_binary(x_bin).reshape(5000, 3, 96, 96).transpose(0, 3, 2, 1).clone
y_train = Numo::UInt8.from_binary(y_bin)
[x_train, y_train]
end
def self.load_test
downloads
x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/test_X.bin"
- raise DNN_STL10_LoadError.new(%`file "#{x_fname}" is not found.`) unless File.exist?(x_fname)
+ raise DNN_STL10_LoadError, %`file "#{x_fname}" is not found.` unless File.exist?(x_fname)
y_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/test_y.bin"
- raise DNN_STL10_LoadError.new(%`file "#{y_fname}" is not found.`) unless File.exist?(y_fname)
+ raise DNN_STL10_LoadError, %`file "#{y_fname}" is not found.` unless File.exist?(y_fname)
x_bin = File.binread(x_fname)
y_bin = File.binread(y_fname)
x_test = Numo::UInt8.from_binary(x_bin).reshape(8000, 3, 96, 96).transpose(0, 3, 2, 1).clone
y_test = Numo::UInt8.from_binary(y_bin)
[x_test, y_test]
end
def self.load_unlabeled(range = 0...100000)
- unless 0 <= range.begin && range.end <= 100000
- raise DNN_Error, "Range must between 0 and 100000. (But the end is excluded)"
- end
+ raise DNN_Error, "Range must between 0 and 100000. (But the end is excluded)" unless range.begin >= 0 && range.end <= 100000
downloads
x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/unlabeled_X.bin"
- raise DNN_STL10_LoadError.new(%`file "#{x_fname}" is not found.`) unless File.exist?(x_fname)
+ raise DNN_STL10_LoadError, %`file "#{x_fname}" is not found.` unless File.exist?(x_fname)
num_datas = range.end - range.begin
length = num_datas * 3 * 96 * 96
ofs = range.begin * 3 * 96 * 96
x_bin = File.binread(x_fname, length, ofs)
Numo::UInt8.from_binary(x_bin).reshape(num_datas, 3, 96, 96).transpose(0, 3, 2, 1).clone