Sha256: ba31de1270d5eff98ea74f4035cdc5147ce3aff85728404f6768a32bf9996d40

Contents?: true

Size: 1.2 KB

Versions: 1

Compression:

Stored size: 1.2 KB

Contents

module Torch
  module Utils
    module Data
      # Iterates over a dataset in fixed-size, in-order batches and collates
      # each batch into tensors.
      #
      # The dataset only needs to respond to +size+ and +[]+ with an integer
      # index; samples are visited in order (no shuffling) and the final batch
      # may be smaller than +batch_size+.
      class DataLoader
        include Enumerable

        attr_reader :dataset, :batch_size

        # @param dataset [#size, #[]] indexable collection of samples
        # @param batch_size [Integer] number of samples per batch; must be positive
        # @raise [ArgumentError] if batch_size is not a positive Integer
        def initialize(dataset, batch_size: 1)
          unless batch_size.is_a?(Integer) && batch_size.positive?
            raise ArgumentError,
                  "batch_size should be a positive integer value, but got batch_size=#{batch_size.inspect}"
          end

          @dataset = dataset
          @batch_size = batch_size
        end

        # Yields one collated batch at a time.
        #
        # @yieldparam batch [Object] the collated batch (see #collate)
        # @return [Enumerator] a sized enumerator when no block is given
        # @return [Integer] the number of batches yielded when a block is given
        def each
          # Standard Enumerable contract: without a block, hand back a lazy,
          # sized enumerator instead of drawing from the RNG and failing on yield.
          return enum_for(:each) { size } unless block_given?

          # try to keep the random number generator in sync with Python
          # this makes it easy to compare results; the drawn value itself is
          # intentionally discarded, only the RNG draw matters
          Torch.empty([], dtype: :int64).random!.item

          max_size = @dataset.size
          size.times do |i|
            start_index = i * @batch_size
            end_index = [start_index + @batch_size, max_size].min
            batch = (start_index...end_index).map { |idx| @dataset[idx] }
            yield collate(batch)
          end
        end

        # @return [Integer] number of batches, counting a trailing partial batch
        def size
          # integer ceiling division (batch_size is validated as a positive Integer)
          (@dataset.size + @batch_size - 1) / @batch_size
        end

        private

        # Recursively combines a batch of samples.
        #
        # Tensors are stacked along dimension 0, Integers become one tensor, and
        # Array samples (tuples) are transposed so each position is collated
        # across the batch.
        #
        # @param batch [Array] non-empty list of samples of the same kind
        # @return [Object] result of Torch.stack / Torch.tensor, or an Array of
        #   collated fields for tuple samples
        # @raise [NotImplementedYet] for unsupported sample types
        def collate(batch)
          elem = batch.first
          case elem
          when Tensor
            Torch.stack(batch, 0)
          when Integer
            Torch.tensor(batch)
          when Array
            batch.transpose.map { |field| collate(field) }
          else
            # NOTE(review): assumes Torch::NotImplementedYet is defined elsewhere
            # in the gem (resolved lexically through `module Torch`)
            raise NotImplementedYet
          end
        end
      end
    end
  end
end

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
torch-rb-0.2.3 lib/torch/utils/data/data_loader.rb