examples/pix2pix/train.rb in ruby-dnn-1.1.1 vs examples/pix2pix/train.rb in ruby-dnn-1.1.2

- old
+ new

@@ -15,41 +15,48 @@ x_in = (x_in / 127.5) - 1 x_out = (x_out / 127.5) - 1 [x_in, x_out] end +initial_epoch = 1 + epochs = 20 batch_size = 128 -gen = Generator.new([32, 32, 1]) -dis = Discriminator.new([32, 32, 1], [32, 32, 3]) -dcgan = DCGAN.new(gen, dis) +if initial_epoch == 1 + gen = Generator.new([32, 32, 1]) + dis = Discriminator.new([32, 32, 1], [32, 32, 3]) + dcgan = DCGAN.new(gen, dis) + gen.setup(Adam.new(alpha: 0.0002, beta1: 0.5), MeanAbsoluteError.new) + dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new) + dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), + [MeanAbsoluteError.new, SigmoidCrossEntropy.new], loss_weights: [10, 1]) +else + dcgan = DCGAN.load("trained/dcgan_model_epoch#{initial_epoch - 1}.marshal") + gen = dcgan.gen + dis = dcgan.dis +end -gen.setup(Adam.new(alpha: 0.0002, beta1: 0.5), MeanAbsoluteError.new) -dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new) -dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new) - x_in, x_out = load_dataset iter1 = DNN::Iterator.new(x_in, x_out) iter2 = DNN::Iterator.new(x_in, x_out) num_batchs = x_in.shape[0] / batch_size -(1..epochs).each do |epoch| +(initial_epoch..epochs).each do |epoch| num_batchs.times do |index| x_in, x_out = iter1.next_batch(batch_size) - gen_loss = gen.train_on_batch(x_in, x_out) - images = gen.generate_images + images = gen.predict(x_in) y_real = Numo::SFloat.ones(batch_size, 1) y_fake = Numo::SFloat.zeros(batch_size, 1) dis.enable_training dis_loss = dis.train_on_batch([x_in, x_out], y_real) dis_loss += dis.train_on_batch([x_in, images], y_fake) x_in, x_out = iter2.next_batch(batch_size) - dcgan_loss = dcgan.train_on_batch(x_in, y_real) + dcgan_loss = dcgan.train_on_batch(x_in, [x_out, y_real]) - puts "epoch: #{epoch}, index: #{index}, gen_loss: #{gen_loss}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}" + puts "epoch: #{epoch}, index: #{index}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}" end iter1.reset iter2.reset dcgan.save("trained/dcgan_model_epoch#{epoch}.marshal") end