lib/dnn/core/iterator.rb in ruby-dnn-0.12.4 vs lib/dnn/core/iterator.rb in ruby-dnn-0.13.0

- old
+ new

@@ -6,11 +6,11 @@ # @param [Boolean] random Set true to return batches randomly. Setting false returns batches in order of index. def initialize(x_datas, y_datas, random: true) @x_datas = x_datas @y_datas = y_datas @random = random - @num_datas = x_datas.shape[0] + @num_datas = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0] reset end # Return the next batch. # @param [Integer] batch_size Required batch size. @@ -20,11 +20,19 @@ batch_indexes = @indexes @has_next = false else batch_indexes = @indexes.shift(batch_size) end - x_batch = @x_datas[batch_indexes, false] - y_batch = @y_datas[batch_indexes, false] + x_batch = if @x_datas.is_a?(Array) + @x_datas.map { |datas| datas[batch_indexes, false] } + else + @x_datas[batch_indexes, false] + end + y_batch = if @y_datas.is_a?(Array) + @y_datas.map { |datas| datas[batch_indexes, false] } + else + @y_datas[batch_indexes, false] + end [x_batch, y_batch] end # Reset input datas and output datas. def reset