Sha256: 86878a6f2c524a33a54072e9acb8eee6477af96f2d250788059c7a37819d4789

Contents?: true

Size: 1.5 KB

Versions: 11

Compression:

Stored size: 1.5 KB

Contents

require "dnn"
require "dnn/datasets/iris"
# To use numo/linalg, uncomment the line below.
# require "numo/linalg/autoloader"

include DNN::Layers
include DNN::Optimizers
include DNN::Losses

# Load the Iris features and labels.
# NOTE(review): the `true` argument presumably requests shuffled samples, which
# matters because the split below simply takes the first rows for training —
# confirm against DNN::Iris.load.
# The dataset is bound to x_all/y_all (not x/y) so that the `x`/`y` names used
# as lambda parameters and inside the training/evaluation blocks never alias
# or overwrite the full dataset.
x_all, y_all = DNN::Iris.load(true)

num_samples = x_all.shape[0]
num_features = x_all.shape[1]
num_classes = 3
hidden_units = 16
train_size = 100

# First `train_size` rows train the model; the remaining rows are the test set.
x_train, y_train = x_all[0...train_size, true], y_all[0...train_size]
x_test, y_test = x_all[train_size...num_samples, true], y_all[train_size...num_samples]

# Expand class labels into one-row-per-sample categorical targets for the
# softmax cross-entropy loss.
y_train = DNN::Utils.to_categorical(y_train, num_classes, Numo::SFloat)
y_test = DNN::Utils.to_categorical(y_test, num_classes, Numo::SFloat)

epochs = 1000
batch_size = 32

opt = Adam.new
lf = SoftmaxCrossEntropy.new

train_iter = DNN::Iterator.new(x_train, y_train)
# Test batches are drawn without randomization (random: false).
test_iter = DNN::Iterator.new(x_test, y_test, random: false)

# Two-layer MLP parameters: standard-normal weights, zero biases.
w1 = DNN::Param.new(Numo::SFloat.new(num_features, hidden_units).rand_norm)
b1 = DNN::Param.new(Numo::SFloat.zeros(hidden_units))
w2 = DNN::Param.new(Numo::SFloat.new(hidden_units, num_classes).rand_norm)
b2 = DNN::Param.new(Numo::SFloat.zeros(num_classes))

# Forward pass of the MLP: input -> Dot(w1) + b1 -> Sigmoid -> Dot(w2) + b2.
# Returns raw class scores; the softmax is applied by the loss (lf), not here.
#
# The second argument is accepted only so existing `net.(x, y)` call sites keep
# working — the forward pass never reads it. Parameter names differ from the
# top-level dataset locals to avoid shadowing them.
net = ->(input, _target = nil) do
  hidden = Sigmoid.(Dot.(input, w1) + b1)
  Dot.(hidden, w2) + b2
end

# Train for `epochs` passes over the training set, updating all four
# parameters after every mini-batch and logging the batch loss.
# The parameter list is loop-invariant, so it is built once.
trainable_params = [w1, b1, w2, b2]
(1..epochs).each do |epoch|
  train_iter.foreach(batch_size) do |x_batch, y_batch, step|
    # Block-local names: assigning to `x`/`y` here would overwrite any
    # top-level locals of the same name on every batch.
    x_in = DNN::Tensor.convert(x_batch)
    y_in = DNN::Tensor.convert(y_batch)
    out = net.(x_in, y_in)
    loss = lf.(out, y_in)
    loss.link.backward
    puts "epoch: #{epoch}, step: #{step}, loss = #{loss.data.to_f}"
    opt.update(trainable_params)
  end
end

# Count test samples whose highest-scoring class matches the one-hot target.
correct = 0
test_iter.foreach(batch_size) do |x_batch, y_batch, _step|
  # net ignores its second argument, so the target is passed through as-is
  # instead of being converted to a Tensor. Block-local names avoid
  # overwriting top-level `x`/`y` locals.
  scores = net.(DNN::Tensor.convert(x_batch), y_batch)
  # Numo's max_index(axis: 1) yields indices into the flattened array; both
  # arrays share the same (batch, classes) shape, so equal indices mean the
  # predicted and true classes agree for that row.
  predicted = scores.data.max_index(axis: 1)
  expected = y_batch.max_index(axis: 1)
  correct += predicted.eq(expected).count
end
puts "correct = #{correct}"

Version data entries

11 entries across 11 versions & 1 rubygem

Version Path
ruby-dnn-1.3.0 examples/iris_example_unused_model.rb
ruby-dnn-1.2.3 examples/iris_example.rb
ruby-dnn-1.2.2 examples/iris_example.rb
ruby-dnn-1.2.1 examples/iris_example.rb
ruby-dnn-1.2.0 examples/iris_example.rb
ruby-dnn-1.1.6 examples/iris_example.rb
ruby-dnn-1.1.5 examples/iris_example.rb
ruby-dnn-1.1.4 examples/iris_example.rb
ruby-dnn-1.1.3 examples/iris_example.rb
ruby-dnn-1.1.2 examples/iris_example.rb
ruby-dnn-1.1.1 examples/iris_example.rb