Sha256: 7a28c548b9bcfcd6d05382a0a0b90d7bce09988f9a2dfb03848b9c1f3c649323
Size: 537 Bytes
Versions: 37
Contents
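module Torch
  module Utils
    module Data
      # Wraps one or more tensors that share the same size along dim 0;
      # sample i is the i-th slice of every tensor.
      class TensorDataset < Dataset
        def initialize(*tensors)
          # Reject inputs whose first dimensions differ.
          unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
            raise Error, "Tensors must all have same dim 0 size"
          end
          @tensors = tensors
        end

        # Returns an array holding the index-th slice of each tensor.
        def [](index)
          @tensors.map { |t| t[index] }
        end

        # Number of samples (the shared dim 0 size).
        def size
          @tensors[0].size(0)
        end
        alias_method :length, :size
        alias_method :count, :size
      end
    end
  end
end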
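To make the behavior of `TensorDataset` concrete without requiring torch-rb, the sketch below re-creates the same pattern with plain Ruby arrays. The class name `MiniTensorDataset` is hypothetical, and an array's `length` stands in for `tensor.size(0)`. It also adds one assumption not present in the original: an explicit guard for the no-tensor case, since with zero arguments the original's `all?` check passes trivially and `size` would then fail on `nil`.

```ruby
# Hypothetical, dependency-free sketch of the TensorDataset pattern.
# Nested Ruby arrays stand in for torch-rb tensors: `length` plays the
# role of `tensor.size(0)` (the dim 0 size).
class MiniTensorDataset
  def initialize(*tensors)
    # Addition (not in the original): require at least one input.
    raise ArgumentError, "At least one tensor is required" if tensors.empty?
    # Same check as TensorDataset: every input must share its first dimension.
    unless tensors.all? { |t| t.length == tensors[0].length }
      raise ArgumentError, "Tensors must all have same dim 0 size"
    end
    @tensors = tensors
  end

  # Sample `index` is the index-th row of every wrapped input.
  def [](index)
    @tensors.map { |t| t[index] }
  end

  # Number of samples, i.e. the shared first-dimension size.
  def size
    @tensors[0].length
  end
  alias_method :length, :size
  alias_method :count, :size
end

features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
labels = [0, 1, 0]
ds = MiniTensorDataset.new(features, labels)
p ds[1]    # prints [[3.0, 4.0], 1]
p ds.size  # prints 3

begin
  MiniTensorDataset.new(features, [0, 1])
rescue ArgumentError => e
  puts e.message  # prints "Tensors must all have same dim 0 size"
end
```

Indexing returns one element per wrapped input, in the order the inputs were passed, which is why a features/labels pair comes back as `[features_row, label]`.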