lib/dnn/core/iterator.rb in ruby-dnn-0.13.4 vs lib/dnn/core/iterator.rb in ruby-dnn-0.14.0

- old
+ new

@@ -1,39 +1,44 @@ module DNN # This class manages input datas and output datas together. class Iterator + attr_reader :num_datas + attr_reader :last_round_down + # @param [Numo::SFloat] x_datas input datas. # @param [Numo::SFloat] y_datas output datas. # @param [Boolean] random Set true to return batches randomly. Setting false returns batches in order of index. - def initialize(x_datas, y_datas, random: true) + # @param [Boolean] last_round_down Set true to round down for last batch data when call foreach. + def initialize(x_datas, y_datas, random: true, last_round_down: false) @x_datas = x_datas @y_datas = y_datas @random = random + @last_round_down = last_round_down @num_datas = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0] reset end # Return the next batch. # @param [Integer] batch_size Required batch size. def next_batch(batch_size) - raise DNN_Error.new("This iterator has not next batch. Please call reset.") unless has_next? + raise DNN_Error, "This iterator has not next batch. Please call reset." unless has_next? if @indexes.length <= batch_size batch_indexes = @indexes @has_next = false else batch_indexes = @indexes.shift(batch_size) end x_batch = if @x_datas.is_a?(Array) - @x_datas.map { |datas| datas[batch_indexes, false] } - else - @x_datas[batch_indexes, false] - end + @x_datas.map { |datas| datas[batch_indexes, false] } + else + @x_datas[batch_indexes, false] + end y_batch = if @y_datas.is_a?(Array) - @y_datas.map { |datas| datas[batch_indexes, false] } - else - @y_datas[batch_indexes, false] - end + @y_datas.map { |datas| datas[batch_indexes, false] } + else + @y_datas[batch_indexes, false] + end [x_batch, y_batch] end # Reset input datas and output datas. def reset @@ -46,14 +51,13 @@ def has_next? @has_next end def foreach(batch_size, &block) - step = 0 - while has_next? + steps = @last_round_down ? @num_datas / batch_size : (@num_datas.to_f / batch_size).ceil + steps.times do |step| x_batch, y_batch = next_batch(batch_size) block.call(x_batch, y_batch, step) - step += 1 end reset end end end