Sha256: ba31de1270d5eff98ea74f4035cdc5147ce3aff85728404f6768a32bf9996d40
Contents?: true
Size: 1.2 KB
Versions: 1
Compression:
Stored size: 1.2 KB
Contents
module Torch
  module Utils
    module Data
      # Iterates over a dataset in consecutive fixed-size batches (dataset
      # order, no shuffling), collating each batch with #collate.
      #
      # The dataset must respond to #size and #[](index) for integer indices
      # in 0...size.
      class DataLoader
        include Enumerable

        attr_reader :dataset

        # @param dataset [#size, #[]] indexable collection of samples
        # @param batch_size [Integer] samples per batch (>= 1); the last batch
        #   is smaller when dataset.size is not a multiple of batch_size
        # @raise [ArgumentError] if batch_size is not a positive Integer
        def initialize(dataset, batch_size: 1)
          unless batch_size.is_a?(Integer) && batch_size.positive?
            raise ArgumentError, "batch_size must be a positive integer, got #{batch_size.inspect}"
          end

          @dataset = dataset
          @batch_size = batch_size
        end

        # Yields each collated batch in dataset order.
        #
        # @yieldparam batch the collated batch (see #collate)
        # @return [Enumerator] when called without a block (no random number
        #   is drawn until the enumerator is iterated)
        # @return [Integer] the number of batches, when called with a block
        def each
          return enum_for(:each) { size } unless block_given?

          # Draw and discard one random int64. This tries to keep the random
          # number generator in sync with Python, which makes it easy to
          # compare results.
          Torch.empty([], dtype: :int64).random!.item

          max_size = @dataset.size
          size.times do |i|
            start_index = i * @batch_size
            end_index = [start_index + @batch_size, max_size].min
            batch = (start_index...end_index).map { |index| @dataset[index] }
            yield collate(batch)
          end
        end

        # Number of batches. This rounds up, so a trailing partial batch counts.
        def size
          # Integer ceiling division, exact for any dataset size (no floats).
          (@dataset.size + @batch_size - 1) / @batch_size
        end

        private

        # Converts a batch (Array of samples) into tensors, dispatching on the
        # class of the first sample:
        # - Tensor: combined with Torch.stack(batch, 0)
        # - Integer / Float: passed to Torch.tensor(batch)
        # - Array (e.g. [input, target] samples): transposed so each field is
        #   collated recursively; returns an Array with one entry per field
        #
        # @raise [NotImplementedYet] for any other sample type
        def collate(batch)
          elem = batch[0]
          case elem
          when Tensor
            Torch.stack(batch, 0)
          when Integer, Float
            Torch.tensor(batch)
          when Array
            batch.transpose.map { |field| collate(field) }
          else
            # NOTE(review): assumes Torch::NotImplementedYet is defined
            # elsewhere in the gem (the original misspelled this constant).
            raise NotImplementedYet
          end
        end
      end
    end
  end
end
Version data entries
1 entry across 1 version & 1 rubygem
Version | Path |
---|---|
torch-rb-0.2.3 | lib/torch/utils/data/data_loader.rb |