Sha256: 0002d6ad1a160b36433a433111406bed5e17b9b3a1ede3ec7f2a4e7480ab385c

Contents?: true

Size: 771 Bytes

Versions: 18

Compression:

Stored size: 771 Bytes

Contents

require "dnn"
require "dnn/datasets/mnist"
require "numo/linalg/autoloader"
require_relative "dcgan"

# DCGAN training script: builds the generator/discriminator pair, prepares the
# MNIST training images and runs the training loop.

# Make the optimizer, loss and callback class names available unqualified.
include DNN::Optimizers
include DNN::Losses
include DNN::Callbacks

# Seed Numo's SFloat random generator with a fresh random integer below 2**31.
seed = rand(1 << 31)
Numo::SFloat.srand(seed)

num_epochs = 20
minibatch_size = 128

generator = Generator.new
discriminator = Discriminator.new
gan = DCGAN.new(generator, discriminator)

# The discriminator and the combined model are set up separately, each with
# its own Adam instance (different hyperparameters) and its own loss object.
discriminator_optimizer = Adam.new(alpha: 0.00001, beta1: 0.1)
discriminator.setup(discriminator_optimizer, SigmoidCrossEntropy.new)

gan_optimizer = Adam.new(alpha: 0.0002, beta1: 0.5)
gan.setup(gan_optimizer, SigmoidCrossEntropy.new)

checkpoint = CheckPoint.new("trained/dcgan_model")
gan.add_callback(checkpoint)

# Only the first element returned by the loader (the images) is used.
images, _labels = DNN::MNIST.load_train
# Cast to single-precision floats, then rescale: x / 127.5 - 1.
# Presumably maps 0..255 pixel values onto -1..1 -- confirm against the loader.
images = Numo::SFloat.cast(images) / 127.5 - 1

# The same array is passed as both the inputs and the targets of the iterator.
# NOTE(review): last_round_down presumably drops a trailing partial batch -- confirm.
batches = DNN::Iterator.new(images, images, last_round_down: true)
gan.fit_by_iterator(batches, num_epochs, batch_size: minibatch_size)

Version data entries

18 entries across 18 versions & 1 rubygem

Version Path
ruby-dnn-1.3.0 examples/dcgan/train.rb
ruby-dnn-1.2.3 examples/dcgan/train.rb
ruby-dnn-1.2.2 examples/dcgan/train.rb
ruby-dnn-1.2.1 examples/dcgan/train.rb
ruby-dnn-1.2.0 examples/dcgan/train.rb
ruby-dnn-1.1.6 examples/dcgan/train.rb
ruby-dnn-1.1.5 examples/dcgan/train.rb
ruby-dnn-1.1.4 examples/dcgan/train.rb
ruby-dnn-1.1.3 examples/dcgan/train.rb
ruby-dnn-1.1.2 examples/dcgan/train.rb
ruby-dnn-1.1.1 examples/dcgan/train.rb
ruby-dnn-1.1.0 examples/dcgan/train.rb
ruby-dnn-1.0.0 examples/dcgan/train.rb
ruby-dnn-0.16.2 examples/dcgan/train.rb
ruby-dnn-0.16.1 examples/dcgan/train.rb
ruby-dnn-0.16.0 examples/dcgan/train.rb
ruby-dnn-0.15.3 examples/dcgan/train.rb
ruby-dnn-0.15.2 examples/dcgan/train.rb