lib/torch/utils/data/data_loader.rb in torch-rb-0.2.2 vs lib/torch/utils/data/data_loader.rb in torch-rb-0.2.3

- old
+ new

@@ -10,17 +10,40 @@
           @dataset = dataset
           @batch_size = batch_size
         end
 
         def each
+          # try to keep the random number generator in sync with Python
+          # this makes it easy to compare results
+          base_seed = Torch.empty([], dtype: :int64).random!.item
+
+          max_size = @dataset.size
           size.times do |i|
             start_index = i * @batch_size
-            yield @dataset[start_index...(start_index + @batch_size)]
+            end_index = [start_index + @batch_size, max_size].min
+            batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
+            yield collate(batch)
           end
         end
 
         def size
           (@dataset.size / @batch_size.to_f).ceil
+        end
+
+        private
+
+        def collate(batch)
+          elem = batch[0]
+          case elem
+          when Tensor
+            Torch.stack(batch, 0)
+          when Integer
+            Torch.tensor(batch)
+          when Array
+            batch.transpose.map { |v| collate(v) }
+          else
+            raise NotImplementedYet
+          end
         end
       end
     end
   end
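
For context, a minimal usage sketch (not part of the diff) of what the 0.2.3 loader yields: with collate in place, each iteration produces stacked tensors for the whole batch instead of a raw dataset slice. This assumes Torch::Utils::Data::TensorDataset, which ships with torch-rb; the sample data, shapes, and batch size are made up for illustration.

# a minimal sketch; data and batch size are illustrative
require "torch"

x = Torch.rand(10, 3)                              # 10 samples, 3 features each
y = Torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])   # 10 integer labels

dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader  = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 4)

loader.each do |xb, yb|
  # each dataset element is an [x_i, y_i] Array, so collate transposes the
  # batch and stacks each column: xb is 4x3 (2x3 for the final short batch),
  # yb holds the corresponding labels
  p [xb.shape, yb.shape]
end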