Sha256: 1949d81dae32185908b45e1f70122f4cdcd9b24911bd17e4d37336e0d9dfee21
Contents?: true
Size: 474 Bytes
Versions: 2
Compression:
Stored size: 474 Bytes
Contents
module TensorFlow
  module Data
    # Dataset whose elements are slices of the input tensors along their first dimension.
    class TensorSliceDataset < Dataset
      def initialize(element)
        tensors = Utils.to_tensor_array(element)
        @tensors = tensors # keep reference for memory
        @output_types = tensors.map(&:dtype)
        # Each element drops the leading (sliced) axis of its component's shape.
        @output_shapes = tensors.map { |t| t.shape[1..-1] }
        variant_tensor = RawOps.tensor_slice_dataset(
          components: tensors,
          output_shapes: @output_shapes
        )
        super(variant_tensor)
      end
    end
  end
end
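The class above delegates the actual slicing to `RawOps.tensor_slice_dataset`, but the shape bookkeeping (`t.shape[1..-1]`) implies the semantics: every component is cut along its first dimension and element `i` collects the `i`-th slice of each component. The following is a minimal pure-Ruby sketch of that behavior, assuming plain rectangular nested Arrays stand in for tensors. The helpers `array_shape` and `tensor_slices` are hypothetical illustrations and are not part of the tensorflow gem.

```ruby
# Hypothetical pure-Ruby illustration of tensor-slice semantics.
# Not part of the tensorflow gem; nested Arrays stand in for tensors.

# Returns the shape of a rectangular nested Array, e.g. [[1, 2], [3, 4]] -> [2, 2].
def array_shape(value)
  shape = []
  while value.is_a?(Array)
    shape << value.length
    value = value.first
  end
  shape
end

# Slices every component along its first dimension and zips the slices,
# so element i is [component_0[i], component_1[i], ...].
def tensor_slices(components)
  lengths = components.map(&:length).uniq
  raise ArgumentError, "components must share the first dimension" unless lengths.size == 1

  # Per-element shapes drop the leading axis, mirroring t.shape[1..-1].
  output_shapes = components.map { |c| array_shape(c)[1..-1] }
  elements = (0...lengths.first).map { |i| components.map { |c| c[i] } }
  [elements, output_shapes]
end

features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] # shape [3, 2]
labels = [0, 1, 0]                              # shape [3]
elements, shapes = tensor_slices([features, labels])
p elements # => [[[1.0, 2.0], 0], [[3.0, 4.0], 1], [[5.0, 6.0], 0]]
p shapes   # => [[2], []]
```

A scalar-per-row component such as `labels` ends up with an empty element shape `[]`, which is the same result `t.shape[1..-1]` gives for a rank-1 tensor.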
Version data entries
2 entries across 2 versions & 1 rubygem
Version | Path |
---|---|
tensorflow-0.2.0 | lib/tensorflow/data/tensor_slice_dataset.rb |
tensorflow-0.1.2 | lib/tensorflow/data/tensor_slice_dataset.rb |