lib/torch/utils/data/data_loader.rb in torch-rb-0.3.1 vs lib/torch/utils/data/data_loader.rb in torch-rb-0.3.2

- old (torch-rb 0.3.1)
+ new (torch-rb 0.3.2)

@@ -4,14 +4,26 @@
       class DataLoader
         include Enumerable
 
         attr_reader :dataset
 
-        def initialize(dataset, batch_size: 1, shuffle: false)
+        def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
           @dataset = dataset
           @batch_size = batch_size
           @shuffle = shuffle
+
+          @batch_sampler = nil
+
+          if collate_fn.nil?
+            if auto_collation?
+              collate_fn = method(:default_collate)
+            else
+              collate_fn = method(:default_convert)
+            end
+          end
+
+          @collate_fn = collate_fn
         end
 
         def each
           # try to keep the random number generator in sync with Python
           # this makes it easy to compare results
@@ -23,32 +35,36 @@
             else
               @dataset.size.times
             end
           indexes.each_slice(@batch_size) do |idx|
-            batch = idx.map { |i| @dataset[i] }
-            yield collate(batch)
+            # TODO improve performance
+            yield @collate_fn.call(idx.map { |i| @dataset[i] })
           end
         end
 
         def size
           (@dataset.size / @batch_size.to_f).ceil
         end
 
         private
 
-        def collate(batch)
+        def default_convert(batch)
           elem = batch[0]
           case elem
           when Tensor
             Torch.stack(batch, 0)
           when Integer
             Torch.tensor(batch)
           when Array
-            batch.transpose.map { |v| default_convert(v) }
+            batch.transpose.map { |v| default_convert(v) }
           else
-            raise NotImpelmentYet
+            raise NotImplementedYet
           end
+        end
+
+        def auto_collation?
+          !@batch_sampler.nil?
         end
       end
     end
   end
 end
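
Worth noting: in the new version `@batch_sampler` is assigned nil unconditionally and never set anywhere else in this file, so `auto_collation?` always returns false and the `method(:default_collate)` branch is dead code for now. It reads like groundwork for batch-sampler support in a later release.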
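
To make the new keyword concrete, here is a minimal usage sketch in Ruby. `PairDataset`, its sample data, and the passthrough lambda are hypothetical illustrations, not part of torch-rb; per the code above, the only interface the loader requires of a dataset is #size and #[].

require "torch"

# Hypothetical dataset yielding [tensor, label] pairs. DataLoader only
# calls #size and #[] on the dataset, so any such object works.
class PairDataset
  def initialize(features, labels)
    @features = features
    @labels = labels
  end

  def size
    @features.size
  end

  def [](i)
    [@features[i], @labels[i]]
  end
end

dataset = PairDataset.new(
  [Torch.tensor([1.0, 2.0]), Torch.tensor([3.0, 4.0])],
  [0, 1]
)

# Default: each batch of [tensor, label] pairs is transposed and collated
# by default_convert into [stacked_tensor, label_tensor].
loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 2)
loader.each { |features, labels| p features.shape } # => [2, 2]

# New in 0.3.2: a custom collate_fn receives the raw array of samples.
raw = Torch::Utils::Data::DataLoader.new(
  dataset,
  batch_size: 2,
  collate_fn: ->(batch) { batch } # leave samples uncollated
)
raw.each { |batch| p batch.size } # => 2 (two [tensor, label] pairs)

A `Method` object (as the loader itself passes with `method(:default_convert)`) works equally well as a lambda, since the loader only invokes #call on whatever it is given.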