Sha256: 0122eb6d4f524f8807f24c5f789ae50dd7c7087dc17844537e4b4430e751cec4
Contents?: true
Size: 705 Bytes
Versions: 5
Compression:
Stored size: 705 Bytes
Contents
require "nn"
require "nn/cifar10"

x_train = []
y_train = []
(1..5).each do |i|
  x_train2, y_train2 = CIFAR10.load_train(i)
  x_train.concat(x_train2)
  y_train.concat(CIFAR10.categorical(y_train2))
end
GC.start

x_test, y_test = CIFAR10.load_test
y_test = CIFAR10.categorical(y_test)
GC.start

puts "load cifar10"

nn = NN.new([3072, 100, 100, 10],
  learning_rate: 0.1,
  batch_size: 32,
  activation: [:relu, :softmax],
  momentum: 0.9,
  use_dropout: true,
  dropout_ratio: 0.2,
  use_batch_norm: true,
)

func = -> x, y do
  x /= 255
  [x, y]
end

nn.train(x_train, y_train, 20, func) do |epoch|
  nn.test(x_test, y_test, &func)
  nn.learning_rate *= 0.99
end
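
The sample multiplies `nn.learning_rate` by 0.99 inside the block passed to `nn.train`, starting from 0.1 and running for 20 epochs. As a rough illustration of the schedule this produces, the sketch below computes the rate in effect during each epoch in plain Ruby. It assumes the block runs once per epoch, after that epoch's training, which the listing does not confirm. `learning_rate_schedule` is a hypothetical helper and not part of the nn gem.

```ruby
# Hypothetical helper, not part of the nn gem.
# Returns the learning rate in effect during each epoch, assuming the
# rate is multiplied by `decay` once after every epoch.
def learning_rate_schedule(initial, decay, epochs)
  (0...epochs).map { |epoch| initial * decay**epoch }
end

# Values taken from the sample: learning_rate 0.1, decay 0.99, 20 epochs.
schedule = learning_rate_schedule(0.1, 0.99, 20)

schedule.each_with_index do |rate, epoch|
  puts format("epoch %2d: learning_rate = %.5f", epoch + 1, rate)
end
```

Under that assumption the decay is gentle: by the final epoch the rate is still above 80% of its starting value.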
Version data entries
5 entries across 5 versions & 1 rubygems
Version | Path |
---|---|
nn-2.4.0 | sample/cifar10_program.rb |
nn-2.3.0 | sample/cifar10_program.rb |
nn-2.2.0 | sample/cifar10_program.rb |
nn-2.1.0 | sample/cifar10_program.rb |
nn-2.0.0 | sample/cifar10_program.rb |