Sha256: be51e430b3a7ee553d3685ef09bb52c8179880b14dcad519cd95fff56e91abce

Contents?: true

Size: 661 Bytes

Versions: 2

Compression:

Stored size: 661 Bytes

Contents

module DNN

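  # L1 (lasso) regularizer: adds l1_lambda * sum(|w|) to the loss.
  class Lasso
    def initialize(l1_lambda, param)
      @l1_lambda = l1_lambda
      @param = param
    end

    # Adds the L1 penalty for @param to the running loss value x.
    def forward(x)
      x + @l1_lambda * @param.data.abs.sum
    end

    # Accumulates the L1 subgradient l1_lambda * sign(w) into @param.grad.
    # Entries equal to zero take +1 here, since only negatives are flipped.
    def backward
      dlasso = Xumo::SFloat.ones(*@param.data.shape)
      dlasso[@param.data < 0] = -1
      @param.grad += @l1_lambda * dlasso
    end
  end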


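  # L2 (ridge) regularizer: adds 0.5 * l2_lambda * sum(w**2) to the loss.
  class Ridge
    def initialize(l2_lambda, param)
      @l2_lambda = l2_lambda
      @param = param
    end

    # Adds the L2 penalty for @param to the running loss value x.
    def forward(x)
      x + 0.5 * @l2_lambda * (@param.data**2).sum
    end

    # Accumulates the L2 gradient l2_lambda * w into @param.grad.
    # The 0.5 factor in forward cancels the 2 from differentiating w**2.
    def backward
      @param.grad += @l2_lambda * @param.data
    end
  end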

end
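
The following is a minimal plain-Ruby sketch, not part of ruby-dnn, that checks the gradients used in `Lasso#backward` and `Ridge#backward` against a finite-difference estimate. It uses ordinary arrays instead of `Xumo` so it runs without dependencies; the helper names (`l1_penalty`, `l1_grad`, `l2_penalty`, `l2_grad`, `numeric_grad`) and the sample values are hypothetical and chosen only for illustration.

```ruby
# Hypothetical helpers mirroring the formulas in the listing above.

# L1 penalty: lambda * sum(|w|), as in Lasso#forward (without the running loss x).
def l1_penalty(lam, w)
  lam * w.sum { |v| v.abs }
end

# L1 subgradient: -lambda where w < 0, +lambda elsewhere (including w == 0),
# matching the ones / -1 masking in Lasso#backward.
def l1_grad(lam, w)
  w.map { |v| v < 0 ? -lam : lam }
end

# L2 penalty: 0.5 * lambda * sum(w**2), as in Ridge#forward.
def l2_penalty(lam, w)
  0.5 * lam * w.sum { |v| v * v }
end

# L2 gradient: lambda * w, as in Ridge#backward.
def l2_grad(lam, w)
  w.map { |v| lam * v }
end

# Central finite difference of the block's value with respect to w[i].
def numeric_grad(w, i, eps = 1e-6)
  plus = w.dup
  plus[i] += eps
  minus = w.dup
  minus[i] -= eps
  (yield(plus) - yield(minus)) / (2 * eps)
end

w = [0.5, -2.0, 3.0]   # sample weights, all away from zero
lam = 0.1              # sample regularization strength

w.each_index do |i|
  l1_num = numeric_grad(w, i) { |v| l1_penalty(lam, v) }
  l2_num = numeric_grad(w, i) { |v| l2_penalty(lam, v) }
  puts format("i=%d  l1: analytic=%.4f numeric=%.4f  l2: analytic=%.4f numeric=%.4f",
              i, l1_grad(lam, w)[i], l1_num, l2_grad(lam, w)[i], l2_num)
end
```

The sample weights avoid zero on purpose: the L1 penalty is not differentiable there, so a finite-difference estimate would not agree with the +1 convention the listing uses for zero entries.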

Version data entries

2 entries across 2 versions & 1 rubygems

Version Path
ruby-dnn-0.9.4 lib/dnn/core/regularizers.rb
ruby-dnn-0.9.3 lib/dnn/core/regularizers.rb