examples/dcgan/dcgan.rb in ruby-dnn-0.13.4 vs examples/dcgan/dcgan.rb in ruby-dnn-0.14.0

- old
+ new

@@ -1,10 +1,7 @@
-include DNN::Layers
-include DNN::Activations
-include DNN::Optimizers
-include DNN::Losses
 include DNN::Models
+include DNN::Layers
 
 class Generator < Model
   def initialize
     super
     @l1 = Dense.new(1024)
@@ -53,11 +50,10 @@
     x = Tanh.(x)
     x
   end
 end
 
-
 class Discriminator < Model
   def initialize
     super
     @l1 = Conv2D.new(32, 4, strides: 2, padding: true)
     @l2 = Conv2D.new(32, 4, padding: true)
@@ -95,22 +91,36 @@
     x = @l6.(x)
     x
   end
 end
 
-
 class DCGAN < Model
-  attr_reader :gen
-  attr_reader :dis
+  attr_accessor :gen
+  attr_accessor :dis
 
-  def initialize(gen, dis)
+  def initialize(gen = nil, dis = nil)
     super()
     @gen = gen
     @dis = dis
   end
 
   def call(x)
     x = @gen.(x)
     x = @dis.(x, false)
     x
+  end
+
+  def train_step(x_batch, y_batch)
+    batch_size = x_batch.shape[0]
+    noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
+    images = @gen.predict(noise)
+    x = x_batch.concatenate(images)
+    y = Numo::SFloat.cast([1] * batch_size + [0] * batch_size).reshape(batch_size * 2, 1)
+    dis_loss = @dis.train_on_batch(x, y)
+
+    noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
+    label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
+    dcgan_loss = train_on_batch(noise, label)
+
+    { dis_loss: dis_loss.mean, dcgan_loss: dcgan_loss.mean }
   end
 end
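
Usage sketch (not part of the diff). The 0.14.0 change moves the per-batch GAN logic into a train_step hook on the DCGAN model itself: each step trains the discriminator on a half-real, half-generated batch, then trains the generator through the combined model on fresh noise with "real" labels; @dis.(x, false) in call presumably freezes the discriminator so only the generator is updated by dcgan_loss. Presumably the framework's own train loop calls this hook per batch; the sketch below drives it by hand so that, beyond the methods visible in the diff, only Model#setup(optimizer, loss) plus DNN::Optimizers::Adam and DNN::Losses::SigmoidCrossEntropy are assumed. x_train is assumed to be a Numo::SFloat of real images scaled to [-1, 1], and all hyperparameters are illustrative only.

# Assumption-laden usage sketch for the new train_step hook.
require "dnn"
require "numo/narray"

include DNN::Models
include DNN::Layers
include DNN::Optimizers   # assumed to provide Adam, as in earlier versions
include DNN::Losses       # assumed to provide SigmoidCrossEntropy

gen = Generator.new
dis = Discriminator.new
dcgan = DCGAN.new(gen, dis)

# The discriminator and the combined model are set up separately, because
# train_step calls @dis.train_on_batch as well as train_on_batch on the DCGAN.
dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)

batch_size = 128
epochs = 10

epochs.times do |epoch|
  (x_train.shape[0] / batch_size).times do |step|
    # Slice one batch of real images; `false` fills the remaining dimensions.
    x_batch = x_train[step * batch_size...(step + 1) * batch_size, false]
    # y_batch is ignored by this train_step, so a placeholder is enough here.
    losses = dcgan.train_step(x_batch, nil)
    puts "epoch #{epoch + 1}, step #{step}: #{losses}" if step % 100 == 0
  end
end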