Sha256: 67ef6b02e430036bd9a85915c2a05fd28fcde381b6dbada40c9e37214ab6f659

Contents?: true

Size: 1.04 KB

Versions: 6

Compression:

Stored size: 1.04 KB

Contents

#ライブラリの読み込み
require "nn"
require "nn/mnist"

#MNISTのトレーニング用データを読み込む
x_train, y_train = MNIST.load_train

#y_trainを10クラスに配列でカテゴライズする
y_train = MNIST.categorical(y_train)

#MNISTのテスト用データを読み込む
x_test, y_test = MNIST.load_test

#y_testを10クラスにカテゴライズする
y_test = MNIST.categorical(y_test)

puts "load mnist"

#ニューラルネットワークの初期化
nn = NN.new([784, 100, 100, 10], #ノード数
  learning_rate: 0.1, #学習率
  batch_size: 100, #ミニバッチの数
  activation: [:relu, :softmax], #活性化関数
  momentum: 0.9, #モーメンタム係数
  use_batch_norm: true, #バッチノーマライゼーションを使用する
)

#ミニバッチを0~1の範囲で正規化
func = -> x_batch, y_batch do
  x_batch /= 255
  [x_batch, y_batch]
end

#学習を行う
nn.train(x_train, y_train, 10, func) do
  #学習結果のテストを行う
  nn.test(x_test, y_test, &func)
end

Version data entries

6 entries across 6 versions & 1 rubygems

Version Path
nn-2.4.0 sample/mnist_program.rb
nn-2.3.0 sample/mnist_program.rb
nn-2.2.0 sample/mnist_program.rb
nn-2.1.0 sample/mnist_program.rb
nn-2.0.0 sample/mnist_program.rb
nn-2.0.1 sample/mnist_program.rb