lib/mnist-learn.rb in mnist-learn-0.1.0 vs lib/mnist-learn.rb in mnist-learn-0.1.1
- old
+ new
@@ -64,11 +64,11 @@
def images
@all_images ||= load_images[2]
end
def labels
- @all_labels ||= load_labels
+ @all_labels ||= (@one_hot ? load_labels.map { |label_data| one_hot_transform(label_data) } : load_labels)
end
def next_batch(batch_size)
if @index == 0
@rows, @columns, @images = load_images
@@ -76,15 +76,15 @@
end
images = []
labels = []
batch_size.times.each do
next if @index >= @total_count
- image_data = @images[@index]
- label_data = @labels[@index]
- image_data.map! { |b| b.to_f / 255.0 }
- @index += 1
- images << image_data
- labels << (@one_hot ? one_hot_transform(label_data) : label_data.to_f)
+ image_data = @images[@index]
+ label_data = @labels[@index]
+ image_data.map! { |b| b.to_f / 255.0 }
+ @index += 1
+ images << image_data
+ labels << (@one_hot ? one_hot_transform(label_data) : label_data.to_f)
end
[images, labels]
end
private