Sha256: abde1b3c83f2f2e2a8605c546fe844ba5cbb00da59697d3001fd8841992ca25f

Contents?: true

Size: 1.7 KB

Versions: 5

Compression:

Stored size: 1.7 KB

Contents

module DNN
  # Iterates over paired input (x) and output (y) data in mini-batches.
  #
  # Both +x_datas+ and +y_datas+ may be either a single array or an Array of
  # arrays (e.g. for models with several inputs/outputs); in the latter case
  # every element is sliced with the same sample indexes.
  class Iterator
    # @param [Numo::SFloat, Array<Numo::SFloat>] x_datas Input data.
    # @param [Numo::SFloat, Array<Numo::SFloat>] y_datas Output data.
    # @param [Boolean] random Set true to return batches in random order. Setting false returns batches in index order.
    def initialize(x_datas, y_datas, random: true)
      @x_datas = x_datas
      @y_datas = y_datas
      @random = random
      # The sample count is the size of the first axis of the (first) input array.
      @num_datas = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
      reset
    end

    # Return the next batch.
    # The final batch holds whatever samples remain, so it may be smaller than +batch_size+.
    # @param [Integer] batch_size Required batch size (must be at least 1).
    # @return [Array] A two-element array: [x_batch, y_batch].
    # @raise [ArgumentError] If +batch_size+ is less than 1.
    # @raise [DNN_Error] If there is no next batch (call #reset first).
    def next_batch(batch_size)
      # A batch_size below 1 would never consume an index, making #foreach loop forever.
      raise ArgumentError, "batch_size must be at least 1 (got #{batch_size})" if batch_size < 1
      raise DNN_Error, "This iterator has not next batch. Please call reset." unless has_next?

      # Array#shift(n) takes at most n elements, so the last batch gets the remainder.
      batch_indexes = @indexes.shift(batch_size)
      @has_next = false if @indexes.empty?
      [slice_datas(@x_datas, batch_indexes), slice_datas(@y_datas, batch_indexes)]
    end

    # Rewind the iterator so every sample is served again.
    # The sample order is reshuffled when the iterator was created with random: true.
    def reset
      # An empty dataset has no batches at all.
      @has_next = @num_datas > 0
      @indexes = @num_datas.times.to_a
      @indexes.shuffle! if @random
    end

    # @return [Boolean] true if another batch can be fetched with #next_batch.
    def has_next?
      @has_next
    end

    # Yield every remaining batch in turn, then reset the iterator.
    # @param [Integer] batch_size Required batch size.
    # @yieldparam x_batch Input batch.
    # @yieldparam y_batch Output batch.
    # @yieldparam [Integer] step Zero-based index of the batch within this pass.
    # @return [Enumerator] When called without a block.
    def foreach(batch_size, &block)
      return enum_for(__method__, batch_size) unless block_given?

      step = 0
      while has_next?
        x_batch, y_batch = next_batch(batch_size)
        block.call(x_batch, y_batch, step)
        step += 1
      end
      reset
    end

    private

    # Select the samples at +indexes+ from +datas+ (a single array or an Array of arrays)
    # using Numo-style `[indexes, false]` indexing, where `false` keeps all remaining axes.
    def slice_datas(datas, indexes)
      return datas.map { |data| data[indexes, false] } if datas.is_a?(Array)

      datas[indexes, false]
    end
  end
end

Version data entries

5 entries across 5 versions & 1 rubygem

Version Path
ruby-dnn-0.13.4 lib/dnn/core/iterator.rb
ruby-dnn-0.13.3 lib/dnn/core/iterator.rb
ruby-dnn-0.13.2 lib/dnn/core/iterator.rb
ruby-dnn-0.13.1 lib/dnn/core/iterator.rb
ruby-dnn-0.13.0 lib/dnn/core/iterator.rb