lib/torch/utils/data/tensor_dataset.rb in torch-rb-0.1.2 vs lib/torch/utils/data/tensor_dataset.rb in torch-rb-0.1.3
- old
+ new
@@ -1,9 +1,12 @@
module Torch
module Utils
module Data
class TensorDataset
def initialize(*tensors)
+ unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
+ raise Error, "Tensors must all have same dim 0 size"
+ end
@tensors = tensors
end
def [](index)
@tensors.map { |t| t[index] }