Sha256: 1f75eecfa8eba96928a82a049f4e7f87d6e919e24266e36fc14c9178660bd56c
Contents?: true
Size: 1.24 KB
Versions: 2
Compression:
Stored size: 1.24 KB
Contents
module TensorFlow
  module Data
    # Enumerable wrapper around a dataset "variant tensor".
    #
    # Iteration is implemented on top of RawOps iterator operations (see #each).
    # The transformation helpers (#batch, #shuffle) return new dataset objects
    # that wrap this one.
    class Dataset
      include Enumerable

      # TODO remove
      # NOTE(review): @output_types / @output_shapes are never assigned in this
      # class; presumably subclasses set them before #each is called - confirm.
      attr_reader :output_types, :output_shapes

      # @param variant_tensor [Object] object backing this dataset; it must
      #   respond to #to_ptr and is passed to RawOps.make_iterator as `dataset:`.
      def initialize(variant_tensor)
        @variant_tensor = variant_tensor
      end

      # Wraps this dataset in a BatchDataset.
      #
      # @param batch_size [Object] forwarded unchanged to BatchDataset.new
      # @param drop_remainder [Boolean] forwarded unchanged to BatchDataset.new
      #   (as its third positional argument)
      # @return [BatchDataset]
      def batch(batch_size, drop_remainder: false)
        BatchDataset.new(self, batch_size, drop_remainder)
      end

      # Wraps this dataset in a ShuffleDataset.
      #
      # @param buffer_size [Object] forwarded unchanged to ShuffleDataset.new
      # @return [ShuffleDataset]
      def shuffle(buffer_size)
        ShuffleDataset.new(self, buffer_size)
      end

      # Builds a TensorSliceDataset from the given tensors.
      #
      # @param tensors [Object] forwarded unchanged to TensorSliceDataset.new
      # @return [TensorSliceDataset]
      def self.from_tensor_slices(tensors)
        TensorSliceDataset.new(tensors)
      end

      # @return [Object] whatever the underlying variant tensor's #to_ptr returns
      def to_ptr
        @variant_tensor.to_ptr
      end

      # Yields each element of the dataset in turn.
      #
      # Creates an anonymous iterator, binds it to this dataset, and pulls
      # values with RawOps.iterator_get_next_sync until that call raises an
      # Error whose message is "End of sequence". The iterator is always
      # released through RawOps.delete_iterator, even when iteration aborts.
      #
      # @yieldparam values [Object] the result of one iterator_get_next_sync call
      # @return [Enumerator] when called without a block
      # @return [nil] when iteration runs to the end of the sequence
      # @raise [Error] any Error other than the end-of-sequence signal
      def each
        # Standard Enumerable contract: a blockless call returns a lazy
        # enumerator instead of allocating an iterator and failing at `yield`.
        return enum_for(:each) unless block_given?

        iterator, deleter = RawOps.anonymous_iterator_v2(output_types: @output_types, output_shapes: @output_shapes)
        RawOps.make_iterator(dataset: @variant_tensor, iterator: iterator)

        loop do
          begin
            values = RawOps.iterator_get_next_sync(iterator: iterator, output_types: @output_types, output_shapes: @output_shapes)
          rescue Error => e
            # iterate until end of sequence error
            break if e.message == "End of sequence"
            raise
          end

          # Yield outside the rescue so an Error raised by the caller's block
          # is never mistaken for the end-of-sequence signal.
          yield values
        end
      ensure
        # `iterator` is nil when no iterator was created (blockless call, or
        # anonymous_iterator_v2 itself raised), so there is nothing to release.
        RawOps.delete_iterator(handle: iterator, deleter: deleter) if iterator
      end
    end
  end
end
Version data entries
2 entries across 2 versions & 1 rubygems
Version | Path |
---|---|
tensorflow-0.2.0 | lib/tensorflow/data/dataset.rb |
tensorflow-0.1.2 | lib/tensorflow/data/dataset.rb |