lib/mnist-learn.rb in mnist-learn-0.1.1 vs lib/mnist-learn.rb in mnist-learn-0.1.2

- old
+ new

@@ -67,11 +67,11 @@ def labels @all_labels ||= (@one_hot ? load_labels.map { |label_data| one_hot_transform(label_data) } : load_labels) end - def next_batch(batch_size) + def next(batch_size) if @index == 0 @rows, @columns, @images = load_images @labels = load_labels end images = [] @@ -84,9 +84,26 @@ @index += 1 images << image_data labels << (@one_hot ? one_hot_transform(label_data) : label_data.to_f) end [images, labels] + end + + def next_batch(batch_size, rnd: Random.new) + @data_set ||= begin + rows, columns, images = load_images + labels = load_labels + Array.new(images.size) do + image_data = images[@index] + label_data = labels[@index] + image_data.map! { |b| b.to_f / 255.0 } + @index += 1 + [image_data, (@one_hot ? one_hot_transform(label_data) : label_data.to_f)] + end + end + @data_set.shuffle!(random: rnd) + batch = @data_set[0...batch_size] + [batch.map { |v| v[0]}, batch.map { |v| v[1]}] end private def one_hot_transform(label)