Sha256: 1949d81dae32185908b45e1f70122f4cdcd9b24911bd17e4d37336e0d9dfee21

Size: 474 Bytes

Versions: 2

Stored size: 474 Bytes

Contents

module TensorFlow
  module Data
    class TensorSliceDataset < Dataset
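      def initialize(element)
        # Normalize the input into an array of component tensors
        tensors = Utils.to_tensor_array(element)
        # Keep a Ruby reference so the tensors (and their memory) are not garbage-collected while the dataset uses them
        @tensors = tensors
        @output_types = tensors.map(&:dtype)
        # Each element is a slice along the first dimension, so its shape drops the leading dimension
        @output_shapes = tensors.map { |t| t.shape[1..-1] }

        variant_tensor = RawOps.tensor_slice_dataset(components: tensors, output_shapes: @output_shapes)
        super(variant_tensor)
      end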
    end
  end
end
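To make the slicing behavior concrete, here is a minimal pure-Ruby sketch that needs no TensorFlow install. It assumes nested Arrays stand in for tensors, and `array_shape`, `element_shapes`, and `tensor_slices` are hypothetical helpers, not part of the gem. It mirrors the two things `initialize` relies on: the per-element shape (`t.shape[1..-1]`) and the fact that element `i` of the dataset is the `i`-th row of every component.

```ruby
# Hypothetical helpers (not part of the tensorflow gem): nested Arrays
# stand in for tensors to illustrate tensor-slice semantics.

# Shape of a nested Array, e.g. [[1, 2], [3, 4], [5, 6]] -> [3, 2].
def array_shape(value)
  value.is_a?(Array) ? [value.length] + array_shape(value.first) : []
end

# Per-element shape: drop the leading (slicing) dimension,
# mirroring `t.shape[1..-1]` in TensorSliceDataset.
def element_shapes(components)
  components.map { |c| array_shape(c)[1..-1] }
end

# One tuple per index along the first dimension; every component
# must have the same leading size.
def tensor_slices(components)
  sizes = components.map(&:length)
  unless sizes.uniq.size == 1
    raise ArgumentError, "need at least one component, all with the same leading dimension"
  end
  (0...sizes.first).map { |i| components.map { |c| c[i] } }
end

features = [[1, 2], [3, 4], [5, 6]]
labels = [0, 1, 0]
p element_shapes([features, labels]) # => [[2], []]
p tensor_slices([features, labels])  # => [[[1, 2], 0], [[3, 4], 1], [[5, 6], 0]]
```

A component of shape `[3]` yields scalar elements (shape `[]`), which is why the leading dimension is dropped rather than kept.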

Version data entries

2 entries across 2 versions & 1 rubygem

Version           Path
tensorflow-0.2.0  lib/tensorflow/data/tensor_slice_dataset.rb
tensorflow-0.1.2  lib/tensorflow/data/tensor_slice_dataset.rb