Sha256: f2191878f841765e3dd03095868a9963e64ec213d2e0010310dee55aa3212b8f

Size: 456 Bytes

Versions: 10

Stored size: 456 Bytes

Contents

module Torch
  module Utils
    module Data
      class TensorDataset
        def initialize(*tensors)
          unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
            raise Error, "Tensors must all have same dim 0 size"
          end
          @tensors = tensors
        end

        def [](index)
          @tensors.map { |t| t[index] }
        end

        def size
          @tensors[0].size(0)
        end
      end
    end
  end
end
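The class above uses only each tensor's `size(0)` and `[]`, so you can illustrate its contract without libtorch. The sketch below is a plain-Ruby approximation and assumes arrays of rows stand in for tensors. `ArrayDataset` and `each_batch` are hypothetical names, not torch-rb API. `ArgumentError` replaces `Torch::Error`, which exists only inside the gem.

```ruby
# Plain-Ruby sketch of the TensorDataset contract. Assumption: arrays of rows
# stand in for tensors. ArrayDataset and each_batch are hypothetical names,
# not part of torch-rb.
class ArrayDataset
  def initialize(*columns)
    # Mirror the dim-0 check: every column must have the same number of rows.
    unless columns.all? { |c| c.length == columns[0].length }
      raise ArgumentError, "Columns must all have same dim 0 size"
    end
    @columns = columns
  end

  # One sample: the index-th row of every column, in argument order.
  def [](index)
    @columns.map { |c| c[index] }
  end

  # Number of samples (rows along dim 0).
  def size
    @columns[0].length
  end
end

# Yield consecutive samples in groups of batch_size, in the spirit of a data
# loader (no shuffling, no stacking; the last batch may be smaller).
def each_batch(dataset, batch_size)
  (0...dataset.size).each_slice(batch_size) do |indices|
    yield indices.map { |i| dataset[i] }
  end
end

features = [[1, 2], [3, 4], [5, 6]]
labels = [0, 1, 0]
dataset = ArrayDataset.new(features, labels)

p dataset.size # prints 3
p dataset[1]   # prints [[3, 4], 1]
each_batch(dataset, 2) { |batch| p batch }
# prints [[[1, 2], 0], [[3, 4], 1]]
# then   [[[5, 6], 0]]
```

Each sample holds one entry per input column, so features and labels stay paired through batching. A real loader would also combine rows into batched tensors and may shuffle the indices.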

Version data entries

10 entries across 10 versions & 1 rubygem

Version        Path
torch-rb-0.2.3 lib/torch/utils/data/tensor_dataset.rb
torch-rb-0.2.2 lib/torch/utils/data/tensor_dataset.rb
torch-rb-0.2.1 lib/torch/utils/data/tensor_dataset.rb
torch-rb-0.2.0 lib/torch/utils/data/tensor_dataset.rb
torch-rb-0.1.8 lib/torch/utils/data/tensor_dataset.rb
torch-rb-0.1.7 lib/torch/utils/data/tensor_dataset.rb
torch-rb-0.1.6 lib/torch/utils/data/tensor_dataset.rb
torch-rb-0.1.5 lib/torch/utils/data/tensor_dataset.rb
torch-rb-0.1.4 lib/torch/utils/data/tensor_dataset.rb
torch-rb-0.1.3 lib/torch/utils/data/tensor_dataset.rb