examples/dcgan/dcgan.rb in ruby-dnn-0.16.0 vs examples/dcgan/dcgan.rb in ruby-dnn-0.16.1
- old
+ new
@@ -61,18 +61,11 @@
@l4 = Conv2D.new(64, 4, padding: true)
@l5 = Dense.new(1024)
@l6 = Dense.new(1)
end
- def forward(x, trainable = true)
- @l1.trainable = trainable
- @l2.trainable = trainable
- @l3.trainable = trainable
- @l4.trainable = trainable
- @l5.trainable = trainable
- @l6.trainable = trainable
-
+ def forward(x)
x = InputLayer.new([28, 28, 1]).(x)
x = @l1.(x)
x = LeakyReLU.(x, 0.2)
x = @l2.(x)
@@ -89,10 +82,22 @@
x = LeakyReLU.(x, 0.2)
x = @l6.(x)
x
end
+
+ def enable_training
+ trainable_layers.each do |layer|
+ layer.trainable = true
+ end
+ end
+
+ def disable_training
+ trainable_layers.each do |layer|
+ layer.trainable = false
+ end
+ end
end
class DCGAN < Model
attr_accessor :gen
attr_accessor :dis
@@ -103,19 +108,21 @@
@dis = dis
end
def forward(x)
x = @gen.(x)
- x = @dis.(x, false)
+ @dis.disable_training
+ x = @dis.(x)
x
end
def train_step(x_batch, y_batch)
batch_size = x_batch.shape[0]
noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
images = @gen.predict(noise)
x = x_batch.concatenate(images)
y = Numo::SFloat.cast([1] * batch_size + [0] * batch_size).reshape(batch_size * 2, 1)
+ @dis.enable_training
dis_loss = @dis.train_on_batch(x, y)
noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
dcgan_loss = train_on_batch(noise, label)