lib/dnn/core/iterator.rb in ruby-dnn-0.14.3 vs lib/dnn/core/iterator.rb in ruby-dnn-0.15.0
- old
+ new
@@ -2,12 +2,12 @@
# This class manages input datas and output datas together.
class Iterator
attr_reader :num_datas
attr_reader :last_round_down
- # @param [Numo::SFloat] x_datas input datas.
- # @param [Numo::SFloat] y_datas output datas.
+ # @param [Numo::SFloat | Array] x_datas input datas.
+ # @param [Numo::SFloat | Array] y_datas output datas.
# @param [Boolean] random Set true to return batches randomly. Setting false returns batches in order of index.
# @param [Boolean] last_round_down Set true to round down for last batch data when call foreach.
def initialize(x_datas, y_datas, random: true, last_round_down: false)
@x_datas = x_datas
@y_datas = y_datas
@@ -17,17 +17,25 @@
reset
end
# Return the next batch.
# @param [Integer] batch_size Required batch size.
+ # @return [Array] Returns the mini batch in the form [x_batch, y_batch].
def next_batch(batch_size)
raise DNN_Error, "This iterator has not next batch. Please call reset." unless has_next?
if @indexes.length <= batch_size
batch_indexes = @indexes
@has_next = false
else
batch_indexes = @indexes.shift(batch_size)
end
+ get_batch(batch_indexes)
+ end
+
+ # Implement a process to get mini batch.
+ # @param [Array] batch_indexes Index of batch to get.
+ # @return [Array] Returns the mini batch in the form [x_batch, y_batch].
+ private def get_batch(batch_indexes)
x_batch = if @x_datas.is_a?(Array)
@x_datas.map { |datas| datas[batch_indexes, false] }
else
@x_datas[batch_indexes, false]
end