examples/dcgan/dcgan.rb in ruby-dnn-0.13.4 vs examples/dcgan/dcgan.rb in ruby-dnn-0.14.0
- old
+ new
@@ -1,10 +1,7 @@
-include DNN::Layers
-include DNN::Activations
-include DNN::Optimizers
-include DNN::Losses
include DNN::Models
+include DNN::Layers
class Generator < Model
def initialize
super
@l1 = Dense.new(1024)
@@ -53,11 +50,10 @@
x = Tanh.(x)
x
end
end
-
class Discriminator < Model
def initialize
super
@l1 = Conv2D.new(32, 4, strides: 2, padding: true)
@l2 = Conv2D.new(32, 4, padding: true)
@@ -95,22 +91,36 @@
x = @l6.(x)
x
end
end
-
class DCGAN < Model
- attr_reader :gen
- attr_reader :dis
+ attr_accessor :gen
+ attr_accessor :dis
- def initialize(gen, dis)
+ def initialize(gen = nil, dis = nil)
super()
@gen = gen
@dis = dis
end
def call(x)
x = @gen.(x)
x = @dis.(x, false)
x
+ end
+
+ def train_step(x_batch, y_batch)
+ batch_size = x_batch.shape[0]
+ noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
+ images = @gen.predict(noise)
+ x = x_batch.concatenate(images)
+ y = Numo::SFloat.cast([1] * batch_size + [0] * batch_size).reshape(batch_size * 2, 1)
+ dis_loss = @dis.train_on_batch(x, y)
+
+ noise = Numo::SFloat.new(batch_size, 20).rand(-1, 1)
+ label = Numo::SFloat.cast([1] * batch_size).reshape(batch_size, 1)
+ dcgan_loss = train_on_batch(noise, label)
+
+ { dis_loss: dis_loss.mean, dcgan_loss: dcgan_loss.mean }
end
end