Sha256: adcd9a381aeda20292ee6cbe2ae2af50157d5bc9614ede7dcfd0f2c993c88306

Contents?: true

Size: 1.65 KB

Versions: 4

Compression:

Stored size: 1.65 KB

Contents

module Torch
  module Utils
    module Data
      # Iterates over a dataset in fixed-size batches, optionally in a random
      # order, and passes every batch of samples through a collate function.
      #
      # The dataset only has to respond to +size+ and to +[]+ with an integer
      # index. Because the class mixes in Enumerable, all of its helpers
      # (+map+, +first+, +to_a+, ...) are built on top of #each.
      class DataLoader
        include Enumerable

        attr_reader :dataset

        # @param dataset [#size, #[]] indexable collection of samples
        # @param batch_size [Integer] number of samples per batch; the last
        #   batch may be smaller
        # @param shuffle [Boolean] when true, a new random permutation of the
        #   indexes is drawn on every iteration
        # @param collate_fn [#call, nil] receives the Array of samples that
        #   make up one batch and returns the value yielded by #each; when nil
        #   a built-in conversion is used
        def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
          @dataset = dataset
          @batch_size = batch_size
          @shuffle = shuffle

          # batch samplers are not supported yet, so auto_collation? is
          # currently always false
          @batch_sampler = nil

          if collate_fn.nil?
            collate_fn = method(auto_collation? ? :default_collate : :default_convert)
          end

          @collate_fn = collate_fn
        end

        # Yields one collated batch at a time.
        #
        # @yieldparam batch [Object] result of the collate function
        # @return [Enumerator] when no block is given
        def each
          # without a block behave like the core collections do, instead of
          # failing with LocalJumpError on the first yield
          return to_enum(:each) { size } unless block_given?

          # try to keep the random number generator in sync with Python
          # this makes it easy to compare results
          # (the drawn value is not used here, only the draw itself matters)
          Torch.empty([], dtype: :int64).random!.item

          indexes =
            if @shuffle
              Torch.randperm(@dataset.size).to_a
            else
              @dataset.size.times
            end

          indexes.each_slice(@batch_size) do |idx|
            # TODO improve performance
            yield @collate_fn.call(idx.map { |i| @dataset[i] })
          end
        end

        # @return [Integer] number of batches produced by one full iteration
        def size
          (@dataset.size / @batch_size.to_f).ceil
        end

        private

        # Collate function selected when auto collation is enabled.
        # default_convert already combines the samples of a batch, so this
        # simply delegates to it.
        def default_collate(batch)
          default_convert(batch)
        end

        # Combines the samples of one batch, dispatching on the type of the
        # first sample:
        # - Tensor: samples are stacked along a new first dimension
        # - Integer / Float: samples are packed into a single tensor
        # - Array: samples are transposed and every column is converted
        #   recursively, giving one entry per field
        #
        # @param batch [Array] samples of one batch
        # @raise [NotImplementedYet] for any other sample type
        def default_convert(batch)
          case batch.first
          when Tensor
            Torch.stack(batch, 0)
          when Integer, Float
            Torch.tensor(batch)
          when Array
            batch.transpose.map { |column| default_convert(column) }
          else
            raise NotImplementedYet
          end
        end

        # @return [Boolean] true when a batch sampler has been configured
        def auto_collation?
          !@batch_sampler.nil?
        end
      end
    end
  end
end

Version data entries

4 entries across 4 versions & 1 rubygem

Version Path
torch-rb-0.3.5 lib/torch/utils/data/data_loader.rb
torch-rb-0.3.4 lib/torch/utils/data/data_loader.rb
torch-rb-0.3.3 lib/torch/utils/data/data_loader.rb
torch-rb-0.3.2 lib/torch/utils/data/data_loader.rb