Sha256: a8e292f59f3a5f957c16950299de4379ef7baa003ffaa16c32eecfaef193dea5

Size: 688 Bytes

Versions: 2

Contents

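module TensorFlow
  module Data
    # Wraps the BatchDatasetV2 op: combines consecutive elements of
    # input_dataset into batches of batch_size elements.
    class BatchDataset < Dataset
      def initialize(input_dataset, batch_size, drop_remainder)
        # Keep a reference to the input dataset so it stays alive (memory
        # management) for as long as this dataset depends on it.
        @input_dataset = input_dataset
        # Batching leaves the element types unchanged...
        @output_types = input_dataset.output_types
        # ...but prepends a batch dimension of size batch_size to each
        # component's shape. When drop_remainder is false, the final batch
        # may contain fewer than batch_size elements.
        @output_shapes = input_dataset.output_shapes.map { |s| [batch_size] + s }

        variant_tensor = RawOps.batch_dataset_v2(
          input_dataset: input_dataset,
          # batch_size is passed to the op as an int64 tensor.
          batch_size: TensorFlow.convert_to_tensor(batch_size, dtype: :int64),
          drop_remainder: drop_remainder,
          output_types: @output_types,
          output_shapes: @output_shapes
        )
        super(variant_tensor)
      end
    end
  end
end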
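The batching behaviour and the shape bookkeeping in BatchDataset can be modelled in plain Ruby. This is a minimal sketch, assuming BatchDatasetV2 groups consecutive elements in order and, when drop_remainder is true, discards a final partial batch. The helpers `batch_elements` and `batched_shapes` are hypothetical illustrations, not part of the tensorflow gem, and no TensorFlow calls are made.

```ruby
# Hypothetical pure-Ruby model of dataset batching; not part of the tensorflow gem.

# Split elements into consecutive groups of batch_size.
# When drop_remainder is true, a final group smaller than batch_size is discarded.
def batch_elements(elements, batch_size, drop_remainder)
  raise ArgumentError, "batch_size must be positive" unless batch_size.positive?

  batches = elements.each_slice(batch_size).to_a
  batches.pop if drop_remainder && !batches.empty? && batches.last.size < batch_size
  batches
end

# Mirror the shape rule used by BatchDataset: prepend batch_size to each
# component shape (a scalar component with shape [] becomes [batch_size]).
def batched_shapes(component_shapes, batch_size)
  component_shapes.map { |s| [batch_size] + s }
end

p batch_elements((1..7).to_a, 3, false) # prints [[1, 2, 3], [4, 5, 6], [7]]
p batch_elements((1..7).to_a, 3, true)  # prints [[1, 2, 3], [4, 5, 6]]
p batched_shapes([[28, 28], []], 32)    # prints [[32, 28, 28], [32]]
```

The first two calls show why the recorded shape is only a static upper bound when drop_remainder is false: the last batch holds one element even though the shape rule still reports a leading dimension of batch_size.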

Version data entries

2 entries across 2 versions of 1 gem

Version            Path
tensorflow-0.2.0   lib/tensorflow/data/batch_dataset.rb
tensorflow-0.1.2   lib/tensorflow/data/batch_dataset.rb