examples/pix2pix/train.rb in ruby-dnn-1.1.1 vs examples/pix2pix/train.rb in ruby-dnn-1.1.2
- old
+ new
@@ -15,41 +15,48 @@
x_in = (x_in / 127.5) - 1
x_out = (x_out / 127.5) - 1
[x_in, x_out]
end
+initial_epoch = 1
+
epochs = 20
batch_size = 128
-gen = Generator.new([32, 32, 1])
-dis = Discriminator.new([32, 32, 1], [32, 32, 3])
-dcgan = DCGAN.new(gen, dis)
+if initial_epoch == 1
+ gen = Generator.new([32, 32, 1])
+ dis = Discriminator.new([32, 32, 1], [32, 32, 3])
+ dcgan = DCGAN.new(gen, dis)
+ gen.setup(Adam.new(alpha: 0.0002, beta1: 0.5), MeanAbsoluteError.new)
+ dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
+ dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5),
+ [MeanAbsoluteError.new, SigmoidCrossEntropy.new], loss_weights: [10, 1])
+else
+ dcgan = DCGAN.load("trained/dcgan_model_epoch#{initial_epoch - 1}.marshal")
+ gen = dcgan.gen
+ dis = dcgan.dis
+end
-gen.setup(Adam.new(alpha: 0.0002, beta1: 0.5), MeanAbsoluteError.new)
-dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
-dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)
-
x_in, x_out = load_dataset
iter1 = DNN::Iterator.new(x_in, x_out)
iter2 = DNN::Iterator.new(x_in, x_out)
num_batchs = x_in.shape[0] / batch_size
-(1..epochs).each do |epoch|
+(initial_epoch..epochs).each do |epoch|
num_batchs.times do |index|
x_in, x_out = iter1.next_batch(batch_size)
- gen_loss = gen.train_on_batch(x_in, x_out)
- images = gen.generate_images
+ images = gen.predict(x_in)
y_real = Numo::SFloat.ones(batch_size, 1)
y_fake = Numo::SFloat.zeros(batch_size, 1)
dis.enable_training
dis_loss = dis.train_on_batch([x_in, x_out], y_real)
dis_loss += dis.train_on_batch([x_in, images], y_fake)
x_in, x_out = iter2.next_batch(batch_size)
- dcgan_loss = dcgan.train_on_batch(x_in, y_real)
+ dcgan_loss = dcgan.train_on_batch(x_in, [x_out, y_real])
- puts "epoch: #{epoch}, index: #{index}, gen_loss: #{gen_loss}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
+ puts "epoch: #{epoch}, index: #{index}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
end
iter1.reset
iter2.reset
dcgan.save("trained/dcgan_model_epoch#{epoch}.marshal")
end