lib/dnn/core/iterator.rb in ruby-dnn-0.14.3 vs lib/dnn/core/iterator.rb in ruby-dnn-0.15.0

- old
+ new

@@ -2,12 +2,12 @@ # This class manages input datas and output datas together. class Iterator attr_reader :num_datas attr_reader :last_round_down - # @param [Numo::SFloat] x_datas input datas. - # @param [Numo::SFloat] y_datas output datas. + # @param [Numo::SFloat | Array] x_datas input datas. + # @param [Numo::SFloat | Array] y_datas output datas. # @param [Boolean] random Set true to return batches randomly. Setting false returns batches in order of index. # @param [Boolean] last_round_down Set true to round down for last batch data when call foreach. def initialize(x_datas, y_datas, random: true, last_round_down: false) @x_datas = x_datas @y_datas = y_datas @@ -17,17 +17,25 @@ reset end # Return the next batch. # @param [Integer] batch_size Required batch size. + # @return [Array] Returns the mini batch in the form [x_batch, y_batch]. def next_batch(batch_size) raise DNN_Error, "This iterator has not next batch. Please call reset." unless has_next? if @indexes.length <= batch_size batch_indexes = @indexes @has_next = false else batch_indexes = @indexes.shift(batch_size) end + get_batch(batch_indexes) + end + + # Implement a process to get mini batch. + # @param [Array] batch_indexes Index of batch to get. + # @return [Array] Returns the mini batch in the form [x_batch, y_batch]. + private def get_batch(batch_indexes) x_batch = if @x_datas.is_a?(Array) @x_datas.map { |datas| datas[batch_indexes, false] } else @x_datas[batch_indexes, false] end