Sha256: bd2f4ed7f20bc43eaad6c08f632f4cb1689c7929c2ace67d16ae6ac2d6b94378
Contents?: true
Size: 1.75 KB
Versions: 3
Compression:
Stored size: 1.75 KB
Contents
module Torch
  module Utils
    module Data
      # Iterates over a Dataset in batches, optionally shuffling, mirroring
      # PyTorch's torch.utils.data.DataLoader.
      class DataLoader
        include Enumerable

        attr_reader :dataset

        # dataset    - an object responding to #size and #[] (index access)
        # batch_size - number of samples per yielded batch (default 1)
        # shuffle    - when true, iterate the dataset in random order
        # collate_fn - optional callable combining an array of samples into a
        #              batch; defaults to default_convert because no batch
        #              sampler can be configured here (see auto_collation?)
        def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
          @dataset = dataset
          @batch_size = batch_size
          @shuffle = shuffle
          @batch_sampler = nil

          if collate_fn.nil?
            collate_fn =
              if auto_collation?
                # NOTE(review): default_collate is not defined in this class,
                # so this branch would raise NameError if reached. It is
                # currently unreachable because @batch_sampler is always nil.
                method(:default_collate)
              else
                method(:default_convert)
              end
          end
          @collate_fn = collate_fn
        end

        # Yields one collated batch at a time. Returns an Enumerator when no
        # block is given.
        def each
          return to_enum(:each) unless block_given?

          # Draw (and discard) a seed purely for its side effect of advancing
          # the random number generator, keeping it in sync with Python so
          # results are easy to compare.
          _base_seed = Torch.empty([], dtype: :int64).random!.item

          indexes =
            if @shuffle
              Torch.randperm(@dataset.size).to_a
            else
              @dataset.size.times
            end

          indexes.each_slice(@batch_size) do |idx|
            # TODO improve performance
            yield @collate_fn.call(idx.map { |i| @dataset[i] })
          end
        end

        # Number of batches (the last batch may be smaller than batch_size).
        def size
          (@dataset.size / @batch_size.to_f).ceil
        end
        alias_method :length, :size
        alias_method :count, :size

        private

        # Recursively converts a batch (array of samples) into tensors:
        # stacks Tensors along a new dim 0, tensorizes Integer arrays,
        # transposes nested Arrays and recurses, and passes anything else
        # through unchanged.
        def default_convert(batch)
          elem = batch[0]
          case elem
          when Tensor
            Torch.stack(batch, 0)
          when Integer
            Torch.tensor(batch)
          when Array
            batch.transpose.map { |v| default_convert(v) }
          else
            batch
          end
        end

        # True when a batch sampler is set — never, in this implementation,
        # since initialize hard-codes @batch_sampler to nil.
        def auto_collation?
          !@batch_sampler.nil?
        end
      end
    end
  end
end
Version data entries
3 entries across 3 versions & 1 rubygems
Version | Path |
---|---|
torch-rb-0.9.0 | lib/torch/utils/data/data_loader.rb |
torch-rb-0.8.3 | lib/torch/utils/data/data_loader.rb |
torch-rb-0.8.2 | lib/torch/utils/data/data_loader.rb |