ruby-nnは、rubyで書かれたニューラルネットワークライブラリです。 python向けの本格的なディープラーニングライブラリと比べると、性能や機能面で、大きく見劣りしますが、 MNISTで98%以上の精度を出せるぐらいの性能はあります。 なお、ruby-nnはNumo/NArrayを使用しています。 そのため、ruby-nnの使用には、Numo/NArrayのインストールが必要です。 [リファレンス] class NN ニューラルネットワークを扱うクラスです。 <クラスメソッド> load(file_name) : NN Marshal形式で保存された学習結果を読み込みます。 String file_name 読み込むMarshalファイル名 戻り値 NNのインスタンス load_json(file_name) : NN JSON形式で保存された学習結果を読み込みます。 String file_name 読み込むJSONファイル名 戻り値 NNのインスタンス <プロパティ> Array weights ネットワークの重みをSFloat形式で取得します。 Array biases ネットワークのバイアスをSFloat形式で取得します。 Array gammas バッチノーマライゼーションを使用している場合、gammaを取得します。 Array betas バッチノーマライゼーションを使用している場合、betaを取得します。 Float learning_rate 学習率 Integer batch_size ミニバッチの数 Array activation 活性化関数。配列の要素1が中間層の活性化関数で要素2が隠れ層の活性化関数です。 中間層には、:sigmoidまたは:relu、出力層には、:identityまたは:softmaxが使用できます。 Float momentum モーメンタム係数 Float weight_decay L2正則化項の強さ Float dropout_ratio ドロップアウトさせるノードの比率 <インスタンスメソッド> initialize(num_nodes, learning_rate: 0.01, batch_size: 1, activation: [:relu, :identity], momentum: 0, weight_decay: 0, use_dropout: false, dropout_ratio: 0.5, use_batch_norm: false) オブジェクトを初期化します。 Array num_nodes 各層のノード数 Float learning_rate 学習率 Integer batch_size ミニバッチの数 Array activation 活性化関数。配列の要素1が中間層の活性化関数で要素2が隠れ層の活性化関数です。 中間層には、:sigmoidまたは:relu、出力層には、:identityまたは:softmaxが使用できます。 Float momentum モーメンタム係数 Float weight_decay L2正則化項の強さ bool use_dropout ドロップアウトを使用するか否か Float dropout_ratio ドロップアウトさせるノードの比率 bool use_batch_norm バッチノーマライゼーションを使用するか否か train(x_train, y_train, x_test, y_test, epochs, func = nil, &block) : void 学習を行います。 Array> | SFloat x_train トレーニング用入力データ。 Array> | SFloat y_train トレーニング用正解データ。 Integer epochs 学習回数。入力データすべてを見たタイミングを1エポックとします。 Float learning_rate_decay 1エポックごとに学習率を減衰される値。 String save_dir 学習中にセーブを行う場合、セーブするディレクトリを指定します。nilの場合、セーブを行いません。 Integer save_interval 学習中にセーブするタイミングをエポック単位で指定します。 Array> | SFloat> test テストで使用するデータ。[x_test, y_test]の形式で指定してください。 nilを指定すると、エポックごとにテストを行いません。 Float border 学習の早期終了判定に使用するテストデータの正答率。 nilの場合、学習の早期終了を行いません。 Proc func(SFloat x, SFloat y) : Array 入力層のミニバッチを取得します。ブロックの戻り値は、ミニバッチを[x, y]の 形で指定してください。入力層をミニバッチ単位で正規化したい場合に使用します。 Proc block(Integer epoch) : void 1エポックの学習が終わった後で行いたい処理を、ブロックで渡します。 test(x_test, y_test, tolerance = 0.5, &block) : Float テストデータを用いて、テストを行います。 Array> | SFloat x_train テスト用入力データ。 Array> | SFloat y_train テスト用正解データ。 Float tolerance 許容する誤差。出力層の活性化関数が:identityの場合に使用します。 例えば出力が0.7で正解が1.0の場合、toleranceが0.4なら合格となり、0.2なら不合格となります。 Proc &block(SFloat x, SFloat y) : Array 入力層のミニバッチを取得します。ブロックの戻り値は、ミニバッチを[x, y]の 形で指定してください。入力層をミニバッチ単位で正規化したい場合に使用します。 戻り値 テストデータの正答率。 accurate(x_test, y_test, tolera) テストデータを用いて、テストデータの正答率を取得します。 Array> | SFloat x_train テスト用入力データ。 Array> | SFloat y_train テスト用正解データ。 Float tolerance 許容する誤差。出力層の活性化関数が:identityの場合に使用します。 例えば出力が0.7で正解が1.0の場合、toleranceが0.4なら合格となり、0.2なら不合格となります。 Proc &block(SFloat x, SFloat y) : Array 入力層のミニバッチを取得します。ブロックの戻り値は、ミニバッチを[x, y]の 形で指定してください。入力層をミニバッチ単位で正規化したい場合に使用します。 戻り値 テストデータの正答率。 learn(x_train, y_train, &block) : Float 入力データを元に、1回だけ学習を行います。柔軟な学習を行いたい場合に使用します。 Array> | SFloat x_train 入力データ Array> | SFloat y_train 正解データ Proc &block(SFloat x, SFloat y) : Array 入力層のミニバッチを取得します。ブロックの戻り値は、ミニバッチを[x, y]の 形で指定してください。入力層をミニバッチ単位で正規化したい場合に使用します。 戻り値 誤差関数の値。誤差関数は、出力層の活性化関数が:identityの場合、二乗和誤差が、 :softmaxの場合、クロスエントロピー誤差が使用されます。なお、L2正則化を使用する場合、 誤差関数の値には正則化項の値が含まれます。 run(x) : Array> 入力データから出力値を二次元配列で得ます。 Array> x 入力データ 戻り値 出力ノードの値 run(x) : SFloat 入力データから出力値をSFloat形式で得ます。 SFloat x 入力データ 戻り値 出力ノードの値 save(file_name) : void 学習結果をMarshal形式で保存します。 String file_name 書き込むMarshalファイル名 save_json(file_name) : void 学習結果をJSON形式で保存します。 String file_name 書き込むJSONファイル名 [MNISTデータを読み込む] MNISTをRubyでも簡単に試せるよう、MNISTを扱うためのモジュールを用意しました。 次のリンク(http://yann.lecun.com/exdb/mnist/)から、 train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz t10k-images-idx3-ubyte.gz t10k-labels-idx1-ubyte.gz の4つのファイルをダウンロードし、実行するRubyファイルと同じ階層のmnistディレクトリに格納したうえで、使用してください。 MNIST.load_trainで学習用データを読み込み、MNIST.load_testでテスト用データを読み込みます。 また、MNIST.categorycalを使用すると、正解データを10クラスにカテゴライズされた上で、配列形式で返します。 (RubyでのMNISTの読み込みは、以下のリンクを参考にさせていただきました。) http://d.hatena.ne.jp/n_shuyo/20090913/mnist [お断り] 作者は、ニューラルネットワークを勉強し始めたばかりの初心者です。 そのため、バグや実装のミスもあるかと思いますが、温かい目で見守っていただけると、幸いでございます。 [更新履歴] 2018/3/8 バージョン1.0公開 2018/3/11 バージョン1.1公開 2018/3/13 バージョン1.2公開 2018/3/14 バージョン1.3公開 2018/3/18 バージョン1.4公開 2018/3/22 バージョン1.5公開 2018/4/15 バージョン1.6公開 2018/5/4 バージョン1.8公開 2018/5/16 バージョン2.0公開 2018/6/10 バージョン2.0.1公開 2018/6/10 バージョン2.1.0公開 2018/6/10 バージョン2.2.0公開