Sha256: 1f75eecfa8eba96928a82a049f4e7f87d6e919e24266e36fc14c9178660bd56c

Contents?: true

Size: 1.24 KB

Versions: 2

Compression:

Stored size: 1.24 KB

Contents

module TensorFlow
  module Data
    # Dataset backed by a variant tensor handle. Elements are produced by the
    # iterator ops in RawOps and exposed through Enumerable via #each.
    class Dataset
      include Enumerable

      # Message of the Error raised by RawOps.iterator_get_next_sync once the
      # iterator is exhausted; #each treats only this error as normal
      # termination.
      END_OF_SEQUENCE = "End of sequence".freeze
      private_constant :END_OF_SEQUENCE

      # TODO: remove
      # NOTE(review): neither ivar is assigned in this class; presumably
      # subclasses set @output_types / @output_shapes before #each runs —
      # confirm. #each forwards them unchanged to the iterator ops.
      attr_reader :output_types, :output_shapes

      # variant_tensor - handle backing this dataset. It is handed to
      # RawOps.make_iterator by #each, and #to_ptr delegates to it.
      def initialize(variant_tensor)
        @variant_tensor = variant_tensor
      end

      # Returns a BatchDataset built over this dataset with +batch_size+;
      # +drop_remainder+ is forwarded to BatchDataset unchanged.
      def batch(batch_size, drop_remainder: false)
        BatchDataset.new(self, batch_size, drop_remainder)
      end

      # Returns a ShuffleDataset built over this dataset with +buffer_size+.
      def shuffle(buffer_size)
        ShuffleDataset.new(self, buffer_size)
      end

      # Builds a TensorSliceDataset from +tensors+.
      def self.from_tensor_slices(tensors)
        TensorSliceDataset.new(tensors)
      end

      # Delegates to the wrapped variant tensor's #to_ptr.
      def to_ptr
        @variant_tensor.to_ptr
      end

      # Yields every element of the dataset, in iterator order.
      #
      # Without a block, returns an Enumerator (the Enumerable contract) and
      # does not create a TF iterator. With a block, creates an anonymous
      # iterator, pulls values until the ops raise the end-of-sequence Error,
      # and returns self.
      #
      # Raises any Error from the get-next op other than end-of-sequence;
      # exceptions raised by the caller's block propagate untouched.
      def each
        return enum_for(:each) unless block_given?

        iterator, deleter = RawOps.anonymous_iterator_v2(output_types: @output_types, output_shapes: @output_shapes)
        RawOps.make_iterator(dataset: @variant_tensor, iterator: iterator)
        loop do
          # Only the fetch is guarded: the ops signal exhaustion by raising,
          # and errors from the caller's block must not be mistaken for it.
          begin
            values = RawOps.iterator_get_next_sync(iterator: iterator, output_types: @output_types, output_shapes: @output_shapes)
          rescue Error => e
            break if e.message == END_OF_SEQUENCE

            raise
          end
          yield values
        end
        self
      ensure
        # Runs on exhaustion, early exit from the block (break/return), and
        # errors, so the iterator is always released. iterator is nil when
        # returning the Enumerator or if anonymous_iterator_v2 raised.
        RawOps.delete_iterator(handle: iterator, deleter: deleter) if iterator
      end
    end
  end
end

Version data entries

2 entries across 2 versions & 1 rubygems

Version Path
tensorflow-0.2.0 lib/tensorflow/data/dataset.rb
tensorflow-0.1.2 lib/tensorflow/data/dataset.rb