lib/torch/utils/data/data_loader.rb in torch-rb-0.2.2 vs lib/torch/utils/data/data_loader.rb in torch-rb-0.2.3
- old
+ new
@@ -10,17 +10,40 @@
@dataset = dataset
@batch_size = batch_size
end
def each
+ # try to keep the random number generator in sync with Python
+ # this makes it easy to compare results
+ base_seed = Torch.empty([], dtype: :int64).random!.item
+
+ max_size = @dataset.size
size.times do |i|
start_index = i * @batch_size
- yield @dataset[start_index...(start_index + @batch_size)]
+ end_index = [start_index + @batch_size, max_size].min
+ batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
+ yield collate(batch)
end
end
def size
(@dataset.size / @batch_size.to_f).ceil
+ end
+
+ private
+
+ def collate(batch)
+ elem = batch[0]
+ case elem
+ when Tensor
+ Torch.stack(batch, 0)
+ when Integer
+ Torch.tensor(batch)
+ when Array
+ batch.transpose.map { |v| collate(v) }
+ else
+ raise NotImplementedYet
+ end
end
end
end
end
end
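
A minimal usage sketch of the 0.2.3 loader, assuming the gem's Torch::Utils::Data::TensorDataset (whose elements are arrays of tensors, so the Array branch of collate applies). Each yielded batch is now a collated pair of stacked tensors, and the final batch is truncated to the dataset size instead of running past it:

require "torch"

x = Torch.rand(10, 3)
y = Torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])

dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader  = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 4)

loader.each do |batch_x, batch_y|
  # collate stacks each component, so shapes are [4, 3] / [4] for the
  # first two batches and [2, 3] / [2] for the truncated final batch
  p [batch_x.shape, batch_y.shape]
end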