lib/torch/utils/data/data_loader.rb in torch-rb-0.1.2 vs lib/torch/utils/data/data_loader.rb in torch-rb-0.1.3

- old
+ new

@@ -1,20 +1,26 @@ module Torch module Utils module Data class DataLoader + include Enumerable + + attr_reader :dataset + def initialize(dataset, batch_size: 1) @dataset = dataset @batch_size = batch_size end def each - size = @dataset.size - start_index = 0 - while start_index < size + size.times do |i| + start_index = i * @batch_size yield @dataset[start_index...(start_index + @batch_size)] - start_index += @batch_size end + end + + def size + (@dataset.size / @batch_size.to_f).ceil end end end end end