Sha256: ca823a9df4a3617add8420c80b302230fe9e90cf68511293d81519abd0219db5
Contents?: true
Size: 610 Bytes
Versions: 41
Compression:
Stored size: 610 Bytes
Contents
module Torch
  module Utils
    module Data
      class << self
        def random_split(dataset, lengths)
          if lengths.sum != dataset.length
            raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
          end

          indices = Torch.randperm(lengths.sum).to_a
          _accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
        end

        private

        def _accumulate(iterable)
          sum = 0
          iterable.map { |x| sum += x }
        end
      end
    end
  end
end
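The offset arithmetic in random_split can be sketched without the Torch dependency. The following is a minimal illustration, not the gem's implementation: it assumes plain arrays stand in for both the dataset and Subset, and that Array#shuffle stands in for Torch.randperm. The names accumulate and split_sketch are hypothetical helpers introduced only for this example.

```ruby
# Standalone sketch of the slicing logic used by random_split.
# Assumptions: plain arrays replace the dataset and Subset, and
# Array#shuffle replaces Torch.randperm.

# Running totals of the requested lengths, e.g. [3, 2, 5] -> [3, 5, 10].
def accumulate(lengths)
  sum = 0
  lengths.map { |x| sum += x }
end

# Splits dataset into non-overlapping random chunks of the given lengths.
def split_sketch(dataset, lengths, random: Random.new)
  if lengths.sum != dataset.length
    raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
  end

  # A random permutation of all valid indices.
  indices = (0...dataset.length).to_a.shuffle(random: random)

  # Each running total is the exclusive end of a chunk; subtracting the
  # chunk's own length gives its start.
  accumulate(lengths).zip(lengths).map do |offset, length|
    indices[(offset - length)...offset].map { |i| dataset[i] }
  end
end

parts = split_sketch((1..10).to_a, [3, 2, 5], random: Random.new(42))
puts parts.map(&:length).inspect
```

Because the chunks are consecutive slices of a single permutation, every element lands in exactly one chunk, and the chunk sizes match the requested lengths.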
Version data entries
41 entries across 41 versions & 1 rubygems