lib/mnist-learn.rb in mnist-learn-0.1.1 vs lib/mnist-learn.rb in mnist-learn-0.1.2
- old
+ new
@@ -67,11 +67,11 @@
def labels
@all_labels ||= (@one_hot ? load_labels.map { |label_data| one_hot_transform(label_data) } : load_labels)
end
- def next_batch(batch_size)
+ def next(batch_size)
if @index == 0
@rows, @columns, @images = load_images
@labels = load_labels
end
images = []
@@ -84,9 +84,26 @@
@index += 1
images << image_data
labels << (@one_hot ? one_hot_transform(label_data) : label_data.to_f)
end
[images, labels]
+ end
+
+ def next_batch(batch_size, rnd: Random.new)
+ @data_set ||= begin
+ rows, columns, images = load_images
+ labels = load_labels
+ Array.new(images.size) do
+ image_data = images[@index]
+ label_data = labels[@index]
+ image_data.map! { |b| b.to_f / 255.0 }
+ @index += 1
+ [image_data, (@one_hot ? one_hot_transform(label_data) : label_data.to_f)]
+ end
+ end
+ @data_set.shuffle!(random: rnd)
+ batch = @data_set[0...batch_size]
+ [batch.map { |v| v[0]}, batch.map { |v| v[1]}]
end
private
def one_hot_transform(label)