Sha256: a8e292f59f3a5f957c16950299de4379ef7baa003ffaa16c32eecfaef193dea5
Contents?: true
Size: 688 Bytes
Versions: 2
Compression:
Stored size: 688 Bytes
Contents
module TensorFlow
  module Data
    class BatchDataset < Dataset
      def initialize(input_dataset, batch_size, drop_remainder)
        @input_dataset = input_dataset # keep reference for memory
        @output_types = input_dataset.output_types
        @output_shapes = input_dataset.output_shapes.map { |s| [batch_size] + s }
        variant_tensor = RawOps.batch_dataset_v2(
          input_dataset: input_dataset,
          batch_size: TensorFlow.convert_to_tensor(batch_size, dtype: :int64),
          drop_remainder: drop_remainder,
          output_types: @output_types,
          output_shapes: @output_shapes
        )
        super(variant_tensor)
      end
    end
  end
end
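The `BatchDataset` constructor delegates the actual batching to the `batch_dataset_v2` op, so its effect is not visible in the Ruby source. The following is a minimal plain-Ruby sketch of the intended semantics, not the gem's implementation. The helpers `batch_elements` and `batched_shapes` are hypothetical names introduced only for illustration, and the sketch assumes the usual TensorFlow meaning of `drop_remainder` (discard a final batch that has fewer than `batch_size` elements).

```ruby
# Plain-Ruby illustration of batching; it does not load the tensorflow gem.
# Both helper names below are hypothetical and exist only in this sketch.

# Group consecutive elements into batches of batch_size.
def batch_elements(elements, batch_size, drop_remainder)
  batches = elements.each_slice(batch_size).to_a
  # With drop_remainder, discard a trailing batch shorter than batch_size.
  batches.pop if drop_remainder && !batches.empty? && batches.last.size < batch_size
  batches
end

# Same shape computation as the constructor: prepend batch_size to each shape.
def batched_shapes(output_shapes, batch_size)
  output_shapes.map { |s| [batch_size] + s }
end

p batch_elements([1, 2, 3, 4, 5], 2, false) # => [[1, 2], [3, 4], [5]]
p batch_elements([1, 2, 3, 4, 5], 2, true)  # => [[1, 2], [3, 4]]
p batched_shapes([[28, 28], []], 32)        # => [[32, 28, 28], [32]]
```

Note that the constructor records `batch_size` as the leading dimension of every output shape whether or not `drop_remainder` is set, so with `drop_remainder` false the final batch may be smaller than the recorded shape suggests.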
Version data entries
2 entries across 2 versions and 1 rubygem
Version | Path |
---|---|
tensorflow-0.2.0 | lib/tensorflow/data/batch_dataset.rb |
tensorflow-0.1.2 | lib/tensorflow/data/batch_dataset.rb |