lib/dnn/core/iterator.rb in ruby-dnn-0.12.4 vs lib/dnn/core/iterator.rb in ruby-dnn-0.13.0
- old
+ new
@@ -6,11 +6,11 @@
# @param [Boolean] random Set true to return batches randomly. Setting false returns batches in order of index.
def initialize(x_datas, y_datas, random: true)
@x_datas = x_datas
@y_datas = y_datas
@random = random
- @num_datas = x_datas.shape[0]
+ @num_datas = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
reset
end
# Return the next batch.
# @param [Integer] batch_size Required batch size.
@@ -20,11 +20,19 @@
batch_indexes = @indexes
@has_next = false
else
batch_indexes = @indexes.shift(batch_size)
end
- x_batch = @x_datas[batch_indexes, false]
- y_batch = @y_datas[batch_indexes, false]
+ x_batch = if @x_datas.is_a?(Array)
+ @x_datas.map { |datas| datas[batch_indexes, false] }
+ else
+ @x_datas[batch_indexes, false]
+ end
+ y_batch = if @y_datas.is_a?(Array)
+ @y_datas.map { |datas| datas[batch_indexes, false] }
+ else
+ @y_datas[batch_indexes, false]
+ end
[x_batch, y_batch]
end
# Reset input datas and output datas.
def reset