Sha256: 738e001aa66c607a962e43018947ea3250990fd15c68d12af264cf07e2421a44
Contents?: true
Size: 1003 Bytes
Versions: 15
Compression:
Stored size: 1003 Bytes
Contents
require 'rbbt/util/python'

# Small Ruby wrapper that drives TensorFlow through the RbbtPython bridge.
module RbbtTensorflow
  # Imports the Python `tensorflow` package under the alias `tf` inside an
  # RbbtPython.run block.
  #
  # NOTE(review): `test` refers to `tf` without importing it itself, so this
  # presumably relies on the pyimport binding persisting across separate
  # RbbtPython.run calls — confirm against RbbtPython.
  def self.init
    RbbtPython.run do
      pyimport "tensorflow", as: "tf"
    end
  end

  # Smoke test: builds a small Keras Sequential classifier, trains it on the
  # Keras MNIST dataset and evaluates it on the test split returned by
  # `load_data`.
  #
  # Expects `init` to have been called first so that `tf` is available.
  #
  # @param epochs [Integer] number of training epochs passed to `fit`
  #   (default 3).
  # @return the value of the final RbbtPython.run block — presumably the
  #   result of Keras `evaluate`; TODO confirm RbbtPython.run returns the
  #   block's value.
  def self.test(epochs = 3)
    # Declared outside the blocks so that assignments made inside the first
    # RbbtPython.run block are visible to the second one.
    mod = x_test = y_test = nil

    RbbtPython.run do
      mnist_db = tf.keras.datasets.mnist
      (x_train, y_train), (x_test, y_test) = mnist_db.load_data()

      # Divide by 255.0 (presumably rescaling 8-bit pixel intensities to [0, 1]).
      x_train, x_test = x_train / 255.0, x_test / 255.0

      mod = tf.keras.models.Sequential.new([
        tf.keras.layers.Flatten.new(input_shape: [28, 28]),
        tf.keras.layers.Dense.new(128, activation: 'relu'),
        tf.keras.layers.Dropout.new(0.2),
        tf.keras.layers.Dense.new(10, activation: 'softmax')
      ])

      # Keras options are passed as Ruby keyword arguments, like the other
      # calls here. Python-style `optimizer='adam'` would instead be a Ruby
      # local-variable assignment whose value is passed positionally.
      mod.compile(optimizer: 'adam',
                  loss: 'sparse_categorical_crossentropy',
                  metrics: ['accuracy'])
      mod.fit(x_train, y_train, epochs: epochs)
      mod
    end

    RbbtPython.run do
      mod.evaluate(x_test, y_test, verbose: 2)
    end
  end
end

# Running this file directly imports TensorFlow and runs the smoke test.
if __FILE__ == $0
  RbbtTensorflow.init
  RbbtTensorflow.test
end
Version data entries
15 entries across 15 versions & 1 rubygem