Sha256: fc2d83f035cf1ef1e04f79714e83193fd31ff164e5982787246eae0d80fae850

Contents?: true

Size: 734 Bytes

Versions: 2

Compression:

Stored size: 734 Bytes

Contents

module TensorFlow
  module Data
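    # Dataset that randomly shuffles the elements of input_dataset,
    # drawing from a buffer that holds buffer_size elements at a time.
    class ShuffleDataset < Dataset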
      def initialize(input_dataset, buffer_size)
        # Hold a reference so the input dataset is not garbage-collected
        # while this dataset still depends on it.
        @input_dataset = input_dataset
        @output_types = input_dataset.output_types
        @output_shapes = input_dataset.output_shapes

        variant_tensor = RawOps.shuffle_dataset(
          input_dataset: input_dataset,
          buffer_size: TensorFlow.convert_to_tensor(buffer_size, dtype: :int64),
          # Both seeds are fixed at 0, which tells the op to pick a random seed.
          seed: TensorFlow.convert_to_tensor(0, dtype: :int64),
          seed2: TensorFlow.convert_to_tensor(0, dtype: :int64),
          output_types: @output_types,
          output_shapes: @output_shapes
        )
        super(variant_tensor)
      end
    end
  end
end
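
The class above passes buffer_size straight through to the underlying shuffle op, so it may help to see what a shuffle buffer does. The following is a minimal pure-Ruby sketch of buffered shuffling, not the gem's or TensorFlow's implementation. The method name buffered_shuffle and its rng: parameter are hypothetical and exist only for illustration.

```ruby
# Hypothetical helper, not part of the tensorflow gem: a plain-Ruby
# illustration of shuffling through a fixed-size buffer.
def buffered_shuffle(elements, buffer_size, rng: Random.new)
  raise ArgumentError, "buffer_size must be positive" unless buffer_size.positive?

  buffer = []
  output = []

  elements.each do |element|
    if buffer.size < buffer_size
      # Fill the buffer before emitting anything.
      buffer << element
    else
      # Emit a random buffered element and put the new one in its slot.
      index = rng.rand(buffer.size)
      output << buffer[index]
      buffer[index] = element
    end
  end

  # Input is exhausted: drain what is left in random order.
  until buffer.empty?
    output << buffer.delete_at(rng.rand(buffer.size))
  end

  output
end

shuffled = buffered_shuffle((1..10).to_a, 3)
puts shuffled.inspect
```

In this sketch a buffer_size of 1 leaves the order unchanged, and only a buffer_size at least as large as the input gives a uniform shuffle over the whole sequence. Smaller buffers trade shuffle quality for lower memory use.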

Version data entries

2 entries across 2 versions & 1 rubygem

Version Path
tensorflow-0.2.0 lib/tensorflow/data/shuffle_dataset.rb
tensorflow-0.1.2 lib/tensorflow/data/shuffle_dataset.rb