Sha256: ca823a9df4a3617add8420c80b302230fe9e90cf68511293d81519abd0219db5
Size: 610 Bytes
Versions: 43
Compression:
Stored size: 610 Bytes
Contents
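```ruby
module Torch
  module Utils
    module Data
      class << self
        # Randomly partitions `dataset` into non-overlapping Subsets whose
        # sizes are given by `lengths` (which must sum to dataset.length).
        def random_split(dataset, lengths)
          if lengths.sum != dataset.length
            raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
          end

          # Random permutation of 0...n; each split takes a contiguous slice of it
          indices = Torch.randperm(lengths.sum).to_a
          # _accumulate yields each split's end offset; the slice is [offset - length, offset)
          _accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
        end

        private

        # Running totals, e.g. [3, 2, 5] -> [3, 5, 10]
        def _accumulate(iterable)
          sum = 0
          iterable.map { |x| sum += x }
        end
      end
    end
  end
end
```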
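The slicing logic in `random_split` can be tried without torch-rb installed. The sketch below is a hypothetical pure-Ruby stand-in, not the gem's API: it assumes `Array#shuffle` with a seeded `Random` in place of `Torch.randperm`, and returns plain arrays instead of `Subset` objects. The names `accumulate` and `random_split_sketch` are illustrative only.

```ruby
# Hypothetical pure-Ruby sketch of the random_split logic (no torch-rb needed).
# Assumption: Array#shuffle with a seeded Random stands in for Torch.randperm,
# and plain arrays stand in for Torch::Utils::Data::Subset.

# Running totals: [3, 2, 5] -> [3, 5, 10] (the end offset of each split)
def accumulate(lengths)
  sum = 0
  lengths.map { |x| sum += x }
end

def random_split_sketch(items, lengths, rng: Random.new(42))
  if lengths.sum != items.length
    raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
  end

  # A random permutation of 0...n, playing the role of Torch.randperm(n).to_a
  indices = (0...lengths.sum).to_a.shuffle(random: rng)

  # Each split takes the half-open index range [offset - length, offset)
  accumulate(lengths).zip(lengths).map do |offset, length|
    indices[(offset - length)...offset].map { |i| items[i] }
  end
end

data = ("a".."j").to_a
splits = random_split_sketch(data, [3, 2, 5])
splits.each { |s| p s }
```

Because each end offset is a running total of the lengths, consecutive slices of the shuffled index list are contiguous and disjoint, so every element lands in exactly one split.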
Version data entries
43 entries across 43 versions and 1 rubygem
Version | Path |
---|---|
torch-rb-0.3.4 | lib/torch/utils/data.rb |
torch-rb-0.3.3 | lib/torch/utils/data.rb |
torch-rb-0.3.2 | lib/torch/utils/data.rb |