Sha256: abde1b3c83f2f2e2a8605c546fe844ba5cbb00da59697d3001fd8841992ca25f
Contents?: true
Size: 1.7 KB
Versions: 5
Compression:
Stored size: 1.7 KB
Contents
module DNN
  # This class manages input data and output data together and serves them in mini-batches.
  class Iterator
    # @param [Numo::SFloat | Array<Numo::SFloat>] x_datas Input data, or an Array of inputs (multi-input models).
    # @param [Numo::SFloat | Array<Numo::SFloat>] y_datas Output data, or an Array of outputs.
    # @param [Boolean] random Set true to return batches in random order. Setting false returns batches in index order.
    def initialize(x_datas, y_datas, random: true)
      @x_datas = x_datas
      @y_datas = y_datas
      @random = random
      # The sample count is read from the first axis of x (of its first element when x is an Array).
      @num_datas = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
      reset
    end

    # Return the next batch.
    # @param [Integer] batch_size Required batch size. The final batch may be smaller.
    # @return [Array] A two-element array [x_batch, y_batch].
    # @raise [DNN_Error] If the iterator is exhausted (call #reset first).
    # @raise [ArgumentError] If batch_size is not positive.
    def next_batch(batch_size)
      raise DNN_Error.new("This iterator has not next batch. Please call reset.") unless has_next?
      # A zero batch size would never consume indexes, making #foreach loop forever.
      raise ArgumentError, "batch_size must be positive, got #{batch_size}" unless batch_size > 0

      batch_indexes = @indexes.shift(batch_size)
      # Once every index has been handed out, the iterator stays exhausted until #reset.
      @has_next = false if @indexes.empty?
      [slice_datas(@x_datas, batch_indexes), slice_datas(@y_datas, batch_indexes)]
    end

    # Reset input data and output data so iteration starts again from the first batch
    # (reshuffling the order when random is enabled).
    def reset
      @has_next = true
      @indexes = @num_datas.times.to_a
      @indexes.shuffle! if @random
    end

    # Return true if there is a next batch.
    def has_next?
      @has_next
    end

    # Iterate over all remaining batches, then reset the iterator.
    # The iterator is reset even if the block raises or breaks out early.
    # @param [Integer] batch_size Required batch size.
    # @yieldparam x_batch Input batch.
    # @yieldparam y_batch Output batch.
    # @yieldparam [Integer] step 0-based batch counter.
    # @return [Enumerator, nil] An Enumerator when no block is given, otherwise nil.
    def foreach(batch_size, &block)
      return enum_for(:foreach, batch_size) unless block_given?

      begin
        step = 0
        while has_next?
          x_batch, y_batch = next_batch(batch_size)
          block.call(x_batch, y_batch, step)
          step += 1
        end
      ensure
        reset
      end
      nil
    end

    private

    # Select the rows at +indexes+ along the first axis, for a single array or for each array of an Array.
    def slice_datas(datas, indexes)
      return datas.map { |d| d[indexes, false] } if datas.is_a?(Array)

      datas[indexes, false]
    end
  end
end
Version data entries
5 entries across 5 versions & 1 rubygems