Sha256: ca823a9df4a3617add8420c80b302230fe9e90cf68511293d81519abd0219db5

Size: 610 Bytes

Versions: 41

Stored size: 610 Bytes

Contents

module Torch
  module Utils
    module Data
      class << self
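        # Randomly partitions `dataset` into non-overlapping Subsets whose
        # sizes are given by `lengths` (which must sum to dataset.length).
        def random_split(dataset, lengths)
          if lengths.sum != dataset.length
            raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
          end

          # Random permutation of the indices 0...dataset.length
          indices = Torch.randperm(lengths.sum).to_a
          # Each running total is the exclusive end of one split; subtracting
          # that split's length gives its start.
          _accumulate(lengths).zip(lengths).map do |offset, length|
            Subset.new(dataset, indices[(offset - length)...offset])
          end
        end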

        private

        # Running totals, e.g. [3, 2, 5] -> [3, 5, 10]
        def _accumulate(iterable)
          sum = 0
          iterable.map { |x| sum += x }
        end
      end
    end
  end
end
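The offset arithmetic in random_split can be traced with a small pure-Ruby sketch that runs without the gem. It is an illustration of the technique under stated assumptions, not torch-rb's implementation: Array#shuffle stands in for Torch.randperm, and SimpleSubset and sketch_random_split are hypothetical names standing in for torch-rb's Subset and random_split.

```ruby
# Pure-Ruby sketch of the random_split technique; no torch-rb required.
# Assumptions: Array#shuffle stands in for Torch.randperm, and SimpleSubset
# is a hypothetical stand-in for torch-rb's Subset class.

# Hypothetical view of `dataset` restricted to the given indices.
class SimpleSubset
  attr_reader :dataset, :indices

  def initialize(dataset, indices)
    @dataset = dataset
    @indices = indices
  end

  def length
    indices.length
  end

  # Element i of the subset is element indices[i] of the parent dataset.
  def [](i)
    dataset[indices[i]]
  end

  def to_a
    indices.map { |i| dataset[i] }
  end
end

# Running totals, mirroring _accumulate: [5, 3, 2] -> [5, 8, 10]
def accumulate(lengths)
  sum = 0
  lengths.map { |x| sum += x }
end

# Hypothetical name; follows the same steps as random_split above.
def sketch_random_split(dataset, lengths, rng: Random.new)
  if lengths.sum != dataset.length
    raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
  end

  # Random permutation of 0...n (Torch.randperm(n).to_a in torch-rb)
  indices = (0...lengths.sum).to_a.shuffle(random: rng)

  # Each running total is the exclusive end of one split; subtracting that
  # split's length gives its start, so the ranges tile 0...n with no overlap.
  accumulate(lengths).zip(lengths).map do |offset, length|
    SimpleSubset.new(dataset, indices[(offset - length)...offset])
  end
end

data = ("a".."j").to_a
train, val, test = sketch_random_split(data, [5, 3, 2], rng: Random.new(42))
p(accumulate([5, 3, 2]))                             # prints [5, 8, 10]
p([train.length, val.length, test.length])           # prints [5, 3, 2]
p((train.to_a + val.to_a + test.to_a).sort == data)  # prints true
```

With lengths [5, 3, 2], the running totals are [5, 8, 10], so the three splits take permuted positions 0...5, 5...8 and 8...10. Because these ranges tile 0...10 exactly, every element of the dataset lands in exactly one split; the randomness comes entirely from the permutation, not from the split boundaries.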

Version data entries

41 entries across 41 versions & 1 rubygem

Version Path
torch-rb-0.18.0 lib/torch/utils/data.rb
torch-rb-0.17.1 lib/torch/utils/data.rb
torch-rb-0.17.0 lib/torch/utils/data.rb
torch-rb-0.16.0 lib/torch/utils/data.rb
torch-rb-0.15.0 lib/torch/utils/data.rb
torch-rb-0.14.1 lib/torch/utils/data.rb
torch-rb-0.14.0 lib/torch/utils/data.rb
torch-rb-0.13.2 lib/torch/utils/data.rb
torch-rb-0.13.1 lib/torch/utils/data.rb
torch-rb-0.13.0 lib/torch/utils/data.rb
torch-rb-0.12.2 lib/torch/utils/data.rb
torch-rb-0.12.1 lib/torch/utils/data.rb
torch-rb-0.12.0 lib/torch/utils/data.rb
torch-rb-0.11.2 lib/torch/utils/data.rb
torch-rb-0.11.1 lib/torch/utils/data.rb
torch-rb-0.11.0 lib/torch/utils/data.rb
torch-rb-0.10.2 lib/torch/utils/data.rb
torch-rb-0.10.1 lib/torch/utils/data.rb
torch-rb-0.10.0 lib/torch/utils/data.rb
torch-rb-0.9.2 lib/torch/utils/data.rb