examples/dcgan/train.rb in ruby-dnn-0.12.4 vs examples/dcgan/train.rb in ruby-dnn-0.13.0
- old
+ new
@@ -1,7 +1,7 @@
require "dnn"
-require "dnn/mnist"
+require "dnn/datasets/mnist"
require "numo/linalg/autoloader"
require_relative "dcgan"
MNIST = DNN::MNIST
@@ -20,13 +20,15 @@
x_train, y_train = MNIST.load_train
x_train = Numo::SFloat.cast(x_train)
x_train = x_train / 127.5 - 1
iter = DNN::Iterator.new(x_train, y_train)
+num_batchs = x_train.shape[0] / batch_size
(1..epochs).each do |epoch|
puts "epoch: #{epoch}"
- iter.foreach(batch_size) do |x_batch, y_batch, index|
+ num_batchs.times do |index|
+ x_batch, y_batch = iter.next_batch(batch_size)
noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
images = gen.predict(noise)
x = x_batch.concatenate(images)
y = Numo::SFloat.cast([1] * batch_size + [0] * batch_size).reshape(batch_size * 2, 1)
dis_loss = dis.train_on_batch(x, y)
@@ -35,7 +37,8 @@
label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
dcgan_loss = dcgan.train_on_batch(noise, label)
puts "index: #{index}, dis_loss: #{dis_loss.mean}, dcgan_loss: #{dcgan_loss.mean}"
end
+ iter.reset
dcgan.save("trained/dcgan_model_epoch#{epoch}.marshal")
end