Sha256: 65b7599e9770a1fe47c0d5bdd59f07324cefd617ee83e2623ddae569a1695b48
Size: 786 Bytes
Versions: 2
Contents
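```ruby
require "dnn"
require "dnn/datasets/mnist"
require "numo/linalg/autoloader"
require_relative "dcgan" # defines Generator, Discriminator and DCGAN

include DNN::Optimizers
include DNN::Losses
include DNN::Callbacks

MNIST = DNN::MNIST

# Seed Numo's random number generator with a random value
Numo::SFloat.srand(rand(1 << 31))

epochs = 20
batch_size = 128

gen = Generator.new
dis = Discriminator.new
dcgan = DCGAN.new(gen, dis)

# The discriminator and the combined DCGAN model get separate optimizers
dis.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
dcgan.setup(Adam.new(alpha: 0.0002, beta1: 0.5), SigmoidCrossEntropy.new)
# Save model checkpoints using the base name "trained/dcgan_model"
dcgan.add_callback(CheckPoint.new("trained/dcgan_model"))

# Load only the training images; the labels are discarded
x_train, * = MNIST.load_train
x_train = Numo::SFloat.cast(x_train)
# Scale pixel values from [0, 255] to [-1, 1]
x_train = x_train / 127.5 - 1

# Images serve as both inputs and targets; the last partial batch is dropped
iter = DNN::Iterator.new(x_train, x_train, last_round_down: true)
dcgan.fit_by_iterator(iter, epochs, batch_size: batch_size)
```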
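A minimal, dependency-free sketch of the arithmetic behind two data-preparation choices in the script above. The helper names `scale_to_signed_unit`, `scale_to_pixels` and `batch_stats` are hypothetical and not part of ruby-dnn, and plain Ruby arrays stand in for `Numo::SFloat`. DCGAN generators conventionally end in a tanh layer, whose outputs lie in [-1, 1], which is presumably why the images are rescaled this way; whether `dcgan.rb` does so is not shown here. The batch arithmetic assumes `last_round_down: true` drops the final incomplete batch, as its name suggests, and uses the 60,000-image size of the standard MNIST training set.

```ruby
# Hypothetical helpers (not part of ruby-dnn); plain arrays stand in for Numo::SFloat.

# Map 8-bit pixel values in [0, 255] to [-1, 1], as `x_train / 127.5 - 1` does.
def scale_to_signed_unit(pixels)
  pixels.map { |p| p / 127.5 - 1 }
end

# Inverse mapping, e.g. for turning generated samples back into 8-bit pixels.
def scale_to_pixels(values)
  values.map { |v| ((v + 1) * 127.5).round.clamp(0, 255) }
end

# Assuming last_round_down: true keeps only full batches, returns
# [number of full batches, samples left over per epoch].
def batch_stats(num_samples, batch_size)
  [num_samples / batch_size, num_samples % batch_size]
end

p scale_to_signed_unit([0, 127.5, 255]) # => [-1.0, 0.0, 1.0]
p scale_to_pixels([-1.0, 0.0, 1.0])     # => [0, 128, 255]
p batch_stats(60_000, 128)              # => [468, 96]
```

With `batch_size = 128`, each epoch would therefore run 468 full batches and leave 96 images unused under that assumption.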
Version data entries
2 entries across 2 versions of 1 gem
Version | Path |
---|---|
ruby-dnn-0.15.1 | examples/dcgan/train.rb |
ruby-dnn-0.15.0 | examples/dcgan/train.rb |