lib/torch/utils/data/data_loader.rb in torch-rb-0.1.2 vs lib/torch/utils/data/data_loader.rb in torch-rb-0.1.3
- old
+ new
@@ -1,20 +1,26 @@
module Torch
module Utils
module Data
class DataLoader
+ include Enumerable
+
+ attr_reader :dataset
+
def initialize(dataset, batch_size: 1)
@dataset = dataset
@batch_size = batch_size
end
def each
- size = @dataset.size
- start_index = 0
- while start_index < size
+ size.times do |i|
+ start_index = i * @batch_size
yield @dataset[start_index...(start_index + @batch_size)]
- start_index += @batch_size
end
+ end
+
+ def size
+ (@dataset.size / @batch_size.to_f).ceil
end
end
end
end
end