lib/torch/utils/data/data_loader.rb in torch-rb-0.2.3 vs lib/torch/utils/data/data_loader.rb in torch-rb-0.2.4

- old
+ new

@@ -4,24 +4,29 @@ class DataLoader
   include Enumerable

   attr_reader :dataset

-  def initialize(dataset, batch_size: 1)
+  def initialize(dataset, batch_size: 1, shuffle: false)
     @dataset = dataset
     @batch_size = batch_size
+    @shuffle = shuffle
   end

   def each
     # try to keep the random number generator in sync with Python
     # this makes it easy to compare results
     base_seed = Torch.empty([], dtype: :int64).random!.item

-    max_size = @dataset.size
-    size.times do |i|
-      start_index = i * @batch_size
-      end_index = [start_index + @batch_size, max_size].min
-      batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
+    indexes =
+      if @shuffle
+        Torch.randperm(@dataset.size).to_a
+      else
+        @dataset.size.times
+      end
+
+    indexes.each_slice(@batch_size) do |idx|
+      batch = idx.map { |i| @dataset[i] }
       yield collate(batch)
     end
   end

   def size