lib/torch/utils/data/data_loader.rb in torch-rb-0.2.3 vs lib/torch/utils/data/data_loader.rb in torch-rb-0.2.4
- old
+ new
@@ -4,24 +4,29 @@
class DataLoader
include Enumerable
attr_reader :dataset
- def initialize(dataset, batch_size: 1)
+ def initialize(dataset, batch_size: 1, shuffle: false)
@dataset = dataset
@batch_size = batch_size
+ @shuffle = shuffle
end
def each
# try to keep the random number generator in sync with Python
# this makes it easy to compare results
base_seed = Torch.empty([], dtype: :int64).random!.item
- max_size = @dataset.size
- size.times do |i|
- start_index = i * @batch_size
- end_index = [start_index + @batch_size, max_size].min
- batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
+ indexes =
+ if @shuffle
+ Torch.randperm(@dataset.size).to_a
+ else
+ @dataset.size.times
+ end
+
+ indexes.each_slice(@batch_size) do |idx|
+ batch = idx.map { |i| @dataset[i] }
yield collate(batch)
end
end
def size