lib/torch/utils/data/tensor_dataset.rb in torch-rb-0.1.2 vs lib/torch/utils/data/tensor_dataset.rb in torch-rb-0.1.3

- old
+ new

@@ -1,9 +1,12 @@ module Torch module Utils module Data class TensorDataset def initialize(*tensors) + unless tensors.all? { |t| t.size(0) == tensors[0].size(0) } + raise Error, "Tensors must all have same dim 0 size" + end @tensors = tensors end def [](index) @tensors.map { |t| t[index] }