Sha256: c331711e6847d832d03ef42de9186517605f8c2175cfddad243f20c73335dccc

Contents?: true

Size: 1.49 KB

Versions: 6

Compression:

Stored size: 1.49 KB

Contents

require "dnn"
require "numo/linalg/autoloader"

include DNN::Models
include DNN::Layers
include DNN::Optimizers
include DNN::Losses

# Convolutional classifier used by the judge-number example.
#
# Layout: three stages of two 3x3 same-padded convolutions. The filter count is
# base_filter_size, then 2x, then 4x. The second convolution of each stage is
# followed by batch normalization, ReLU and 2x2 max pooling. The network ends
# with a 512-unit dense layer (with batch normalization) and a linear output
# layer producing one logit per class.
class ConvNet < Model
  # Builds a ConvNet and sets it up with Adam and softmax cross-entropy.
  #
  # @param input_shape [Array<Integer>] shape of one input sample, passed to
  #   InputLayer (presumably [height, width, channels] — confirm against caller)
  # @param base_filter_size [Integer] filter count of the first stage (default 32)
  # @param num_classes [Integer] number of output logits (default 10)
  # @return [ConvNet] an instance of the receiving class with optimizer and loss set
  def self.create(input_shape, base_filter_size: 32, num_classes: 10)
    # Call `new` on the receiver rather than `ConvNet.new`, so a subclass that
    # inherits `create` gets an instance of itself.
    convnet = new(input_shape, base_filter_size, num_classes)
    convnet.setup(Adam.new, SoftmaxCrossEntropy.new)
    convnet
  end

  # @param input_shape [Array<Integer>] shape of one input sample
  # @param base_filter_size [Integer] filters in stage 1; stages 2 and 3 use 2x and 4x
  # @param num_classes [Integer] width of the final Dense layer (default 10
  #   keeps the original digit-classification output)
  def initialize(input_shape, base_filter_size, num_classes = 10)
    super()
    @input_shape = input_shape
    # NOTE(review): layer ivar names are kept exactly as before; ruby-dnn may
    # discover layers by reflecting over instance variables, so renaming them
    # could change weight save/load ordering — confirm before renaming.
    @cv1 = Conv2D.new(base_filter_size, 3, padding: true)
    @cv2 = Conv2D.new(base_filter_size, 3, padding: true)
    @cv3 = Conv2D.new(base_filter_size * 2, 3, padding: true)
    @cv4 = Conv2D.new(base_filter_size * 2, 3, padding: true)
    @cv5 = Conv2D.new(base_filter_size * 4, 3, padding: true)
    @cv6 = Conv2D.new(base_filter_size * 4, 3, padding: true)
    @bn1 = BatchNormalization.new
    @bn2 = BatchNormalization.new
    @bn3 = BatchNormalization.new
    @bn4 = BatchNormalization.new
    @d1 = Dense.new(512)
    @d2 = Dense.new(num_classes)
  end

  # Runs the network on a batch and returns the output-layer logits.
  # Softmax is not applied here; SoftmaxCrossEntropy (set in `create`) works on logits.
  def forward(x)
    x = InputLayer.new(@input_shape).(x)

    # Stage 1: base_filter_size filters
    x = @cv1.(x)
    x = ReLU.(x)
    x = Dropout.(x, 0.25)

    x = @cv2.(x)
    x = @bn1.(x)
    x = ReLU.(x)
    x = MaxPool2D.(x, 2)

    # Stage 2: 2 * base_filter_size filters
    x = @cv3.(x)
    x = ReLU.(x)
    x = Dropout.(x, 0.25)

    x = @cv4.(x)
    x = @bn2.(x)
    x = ReLU.(x)
    x = MaxPool2D.(x, 2)

    # Stage 3: 4 * base_filter_size filters
    x = @cv5.(x)
    x = ReLU.(x)
    x = Dropout.(x, 0.25)

    x = @cv6.(x)
    x = @bn3.(x)
    x = ReLU.(x)
    x = MaxPool2D.(x, 2)

    # Classifier head
    x = Flatten.(x)
    x = @d1.(x)
    x = @bn4.(x)
    x = ReLU.(x)
    @d2.(x)
  end
end

Version data entries

6 entries across 6 versions & 1 rubygem

Version Path
ruby-dnn-1.3.0 examples/judge-number/convnet8.rb
ruby-dnn-1.2.3 examples/judge-number/convnet8.rb
ruby-dnn-1.2.2 examples/judge-number/convnet8.rb
ruby-dnn-1.2.1 examples/judge-number/convnet8.rb
ruby-dnn-1.2.0 examples/judge-number/convnet8.rb
ruby-dnn-1.1.6 examples/judge-number/convnet8.rb