Sha256: 0122eb6d4f524f8807f24c5f789ae50dd7c7087dc17844537e4b4430e751cec4

Contents?: true

Size: 705 Bytes

Versions: 5

Compression:

Stored size: 705 Bytes

Contents

require "nn"
require "nn/cifar10"

# Gather the five CIFAR-10 training batches (indices 1..5) into flat sample
# and label lists. Labels are passed through CIFAR10.categorical (presumably
# one-hot encoding; confirm against nn/cifar10) before being appended.
x_train = []
y_train = []
1.upto(5) do |batch_no|
  images, labels = CIFAR10.load_train(batch_no)
  x_train.concat(images)
  y_train.concat(CIFAR10.categorical(labels))
end
# Force a collection so per-batch temporaries can be reclaimed before training.
GC.start

# Load the CIFAR-10 test split and convert its labels the same way as the
# training labels (via CIFAR10.categorical).
x_test, test_labels = CIFAR10.load_test
y_test = CIFAR10.categorical(test_labels)
# Drop the reference to the raw labels so GC.start can reclaim them.
test_labels = nil
GC.start

puts "load cifar10"

# Fully connected network: 3072 inputs (= 32 * 32 * 3, presumably flattened
# RGB CIFAR-10 images), two hidden layers of 100 units, and 10 outputs.
# NOTE(review): the activation pair is presumably [hidden, output]; confirm in NN.
layer_sizes = [3072, 100, 100, 10]
nn = NN.new(
  layer_sizes,
  learning_rate: 0.1,
  batch_size: 32,
  activation: [:relu, :softmax],
  momentum: 0.9,
  use_dropout: true,
  dropout_ratio: 0.2,
  use_batch_norm: true
)

# Per-batch preprocessing hook handed to nn.train and nn.test. It divides the
# inputs by 255 (presumably scaling 8-bit pixel values into [0, 1]) and returns
# the labels unchanged, as an [inputs, labels] pair.
func = ->(x, y) { [x / 255, y] }

# Train for 20 epochs, preprocessing each batch with `func`. The block is
# presumably invoked once per epoch: it evaluates on the test split and then
# decays the learning rate by 1%.
nn.train(x_train, y_train, 20, func) do |_epoch|
  nn.test(x_test, y_test, &func)
  nn.learning_rate = nn.learning_rate * 0.99
end

Version data entries

5 entries across 5 versions & 1 rubygem

Version Path
nn-2.4.0 sample/cifar10_program.rb
nn-2.3.0 sample/cifar10_program.rb
nn-2.2.0 sample/cifar10_program.rb
nn-2.1.0 sample/cifar10_program.rb
nn-2.0.0 sample/cifar10_program.rb