Sha256: 7c490018e0acadfd64fa73009c7de5daef6ad2109aafad1a27b0f94c1e56250d

Contents?: true

Size: 1.41 KB

Versions: 15

Compression:

Stored size: 1.41 KB

Contents

require File.join(File.expand_path(File.dirname(__FILE__)), '../../..', 'test_helper.rb')
require 'rbbt/vector/model/tensorflow'

# Integration test for TensorFlowModel: builds a small Keras network, trains it
# on the MNIST training split and checks accuracy on the MNIST test split.
class TestTensorflowModel < Test::Unit::TestCase

  # Trains a Flatten -> Dense(128, relu) -> Dropout(0.2) -> Dense(10, softmax)
  # network on MNIST (samples are fed one at a time through TensorFlowModel#add),
  # then requires more than 70% accuracy on the test split.
  #
  # Log.severity is set to 0 while the test runs and restored afterwards so the
  # global logging level does not leak into other tests.
  def test_keras
    previous_severity = Log.severity
    Log.severity = 0

    TmpFile.with_file do |dir|
      FileUtils.mkdir_p dir

      model = TensorFlowModel.new(
        dir,
        optimizer: 'adam',
        loss: 'sparse_categorical_crossentropy',
        metrics: ['accuracy']
      )

      # NOTE(review): the block's return value is presumably the Keras model that
      # TensorFlowModel compiles and trains, and `tensorflow` presumably exposes
      # the PyCall TensorFlow module in the block's context — confirm.
      model.keras_graph do
        tf = tensorflow
        tf.keras.models.Sequential.new([
          tf.keras.layers.Flatten.new(input_shape: [28, 28]),
          tf.keras.layers.Dense.new(128, activation: 'relu'),
          tf.keras.layers.Dropout.new(0.2),
          tf.keras.layers.Dense.new(10, activation: 'softmax')
        ])
      end

      # Declared outside the block so the results are still visible after it returns.
      correct = predictions = nil
      model.tensorflow do
        tf = tensorflow
        mnist_db = tf.keras.datasets.mnist

        (x_train, y_train), (x_test, y_test) = mnist_db.load_data()
        # Scale pixel intensities down by 255.
        x_train, x_test = x_train / 255.0, x_test / 255.0

        PyCall.len(x_train).times do |i|
          model.add x_train[i], y_train[i]
        end

        model.train

        predictions = model.eval_list x_test.tolist()
        correct = predictions.zip(y_test.tolist()).count { |pred, label| label.to_i == pred }
      end

      # Fail with a clear message rather than a silent NaN comparison when
      # nothing was predicted.
      assert predictions && !predictions.empty?, 'model produced no predictions'
      accuracy = correct.to_f / predictions.length
      assert accuracy > 0.7, "expected test accuracy above 0.7, got #{accuracy}"
    end
  ensure
    Log.severity = previous_severity
  end
end

Version data entries

15 entries across 15 versions & 1 rubygem

Version Path
rbbt-dm-1.2.7 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.2.6 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.2.4 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.2.3 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.2.1 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.63 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.62 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.61 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.60 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.59 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.58 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.57 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.56 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.55 test/rbbt/vector/model/test_tensorflow.rb
rbbt-dm-1.1.54 test/rbbt/vector/model/test_tensorflow.rb