Sha256: 1bba6e5431f7e01cac7ec698b8684e476999cea9157dc9fd2f174a5aa7115181

Contents?: true

Size: 1.75 KB

Versions: 1

Compression:

Stored size: 1.75 KB

Contents

module TensorFlow
  module Keras
    module Layers
      # Fully connected layer: multiplies its inputs by a kernel matrix,
      # optionally adds a bias vector, then applies an optional activation.
      #
      # Weights are created by #build, which #call runs automatically the
      # first time the layer is invoked.
      class Dense
        # Shape recorded by #build; nil until the layer has been built.
        attr_reader :output_shape

        # @param units [Integer] size of the kernel's second dimension (and of the bias)
        # @param activation [String, Symbol, nil] "relu", "softmax" (or the
        #   equivalent symbols), or nil to return the raw outputs
        # @param use_bias [Boolean] whether a bias weight is created and added
        # @param kernel_initializer [String] forwarded to Utils.add_weight for the kernel
        # @param bias_initializer [String] forwarded to Utils.add_weight for the bias
        # @param dtype [Symbol] dtype of the weights; inputs are cast to it in #call
        def initialize(units, activation: nil, use_bias: true, kernel_initializer: "glorot_uniform", bias_initializer: "zeros", dtype: :float)
          @units = units
          @activation = activation
          @use_bias = use_bias
          @kernel_initializer = kernel_initializer
          @bias_initializer = bias_initializer
          @dtype = dtype
          @built = false
        end

        # Creates the layer's weights from the last dimension of +input_shape+.
        #
        # @param input_shape [Array<Integer>] shape of the inputs; only the last entry is used
        # @return [true]
        def build(input_shape)
          last_dim = input_shape.last
          @kernel = Utils.add_weight(name: "kernel", shape: [last_dim, @units], initializer: @kernel_initializer, dtype: @dtype)
          @bias = (Utils.add_weight(name: "bias", shape: [@units], initializer: @bias_initializer, dtype: @dtype) if @use_bias)

          # NOTE(review): this is the same shape as the kernel, not
          # [batch, units] -- kept as-is because callers may rely on it; confirm.
          @output_shape = [last_dim, @units]

          @built = true
        end

        # Number of weights held by the layer: every kernel entry, plus one
        # per unit for the bias when the layer was created with a bias.
        # Requires #build to have run (the kernel must exist).
        #
        # @return [Integer]
        def count_params
          kernel_params = @kernel.shape.inject(&:*)
          # Without a bias there is no bias weight, so nothing to add.
          @use_bias ? kernel_params + @units : kernel_params
        end

        # Applies the layer to +inputs+, building the weights on first use.
        #
        # @param inputs [#shape] tensor-like input of rank 2 or lower
        # @return the activated outputs (whatever the NN / matmul ops return)
        # @raise [Error] when the input rank is greater than 2
        # @raise [RuntimeError] when the configured activation is not recognized
        def call(inputs)
          build(inputs.shape) unless @built

          raise Error, "Rank > 2 not supported yet" if inputs.shape.size > 2

          inputs = TensorFlow.cast(inputs, @dtype)
          outputs = TensorFlow.matmul(inputs, @kernel)
          outputs = NN.bias_add(outputs, @bias) if @use_bias

          apply_activation(outputs)
        end

        private

        # Dispatches on the configured activation; symbols are treated like
        # their string names so `activation: :relu` behaves like "relu".
        def apply_activation(outputs)
          activation = @activation.is_a?(Symbol) ? @activation.to_s : @activation

          case activation
          when "relu"
            NN.relu(outputs)
          when "softmax"
            NN.softmax(outputs)
          when nil
            outputs
          else
            raise "Unknown activation: #{@activation}"
          end
        end
      end
    end
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
tensorflow-0.2.0 lib/tensorflow/keras/layers/dense.rb