examples/dcgan/dcgan.rb in ruby-dnn-0.16.0 vs examples/dcgan/dcgan.rb in ruby-dnn-0.16.1

- old
+ new

@@ -61,18 +61,11 @@ @l4 = Conv2D.new(64, 4, padding: true) @l5 = Dense.new(1024) @l6 = Dense.new(1) end - def forward(x, trainable = true) - @l1.trainable = trainable - @l2.trainable = trainable - @l3.trainable = trainable - @l4.trainable = trainable - @l5.trainable = trainable - @l6.trainable = trainable - + def forward(x) x = InputLayer.new([28, 28, 1]).(x) x = @l1.(x) x = LeakyReLU.(x, 0.2) x = @l2.(x) @@ -89,10 +82,22 @@ x = LeakyReLU.(x, 0.2) x = @l6.(x) x end + + def enable_training + trainable_layers.each do |layer| + layer.trainable = true + end + end + + def disable_training + trainable_layers.each do |layer| + layer.trainable = false + end + end end class DCGAN < Model attr_accessor :gen attr_accessor :dis @@ -103,19 +108,21 @@ @dis = dis end def forward(x) x = @gen.(x) - x = @dis.(x, false) + @dis.disable_training + x = @dis.(x) x end def train_step(x_batch, y_batch) batch_size = x_batch.shape[0] noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1) images = @gen.predict(noise) x = x_batch.concatenate(images) y = Numo::SFloat.cast([1] * batch_size + [0] * batch_size).reshape(batch_size * 2, 1) + @dis.enable_training dis_loss = @dis.train_on_batch(x, y) noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1) label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1) dcgan_loss = train_on_batch(noise, label)