Sha256: e13f9ab95328ad4ca92d77d2d1f72fc553ee8313d43f1c263744e9f1b1699fef

Contents?: true

Size: 981 Bytes

Versions: 2

Compression:

Stored size: 981 Bytes

Contents

module DNN
  # Pairs a value (+data+) with its accumulated gradient (+grad+) and a
  # +trainable+ flag that decides whether #backward accumulates or resets.
  # The arithmetic operators do no math themselves: each one delegates to the
  # matching operation under DNN::Layers and returns whatever that produces.
  class Param
    attr_accessor :trainable, :data, :grad

    # @param data [Object, nil] the held value. #shape and #backward call
    #   +data.shape+, so it presumably is a Xumo array — NOTE(review): confirm.
    # @param grad [Object, nil] initial gradient; nil until the first #backward.
    def initialize(data = nil, grad = nil)
      @data = data
      @grad = grad
      @trainable = true
    end

    # Folds an incoming gradient into #grad.
    #
    # - When the param is not trainable, #grad is replaced by Xumo::SFloat[0].
    # - When +grad+ has exactly #data's shape, it is added as-is.
    # - When +grad+ has one extra leading axis (presumably a batch axis), it is
    #   summed over axis 0 before being added.
    #
    # @param grad incoming gradient; must respond to #shape (and #sum for the
    #   leading-axis case)
    # @return the new value of #grad
    # @raise [DNN_Error] when +grad+'s shape fits neither case
    def backward(grad)
      unless @trainable
        @grad = Xumo::SFloat[0]
        return @grad
      end

      # First call starts from a single zero. Presumably this relies on Xumo
      # broadcasting it against the first gradient.
      @grad ||= Xumo::SFloat[0]
      contribution =
        case @data.shape
        when grad.shape then grad
        when grad.shape[1..-1] then grad.sum(0)
        else raise DNN_Error, "Shape is missmatch."
        end
      @grad += contribution
    end

    # @return the shape of the held +data+
    def shape
      @data.shape
    end

    # Unary plus: the receiver itself.
    def +@
      self
    end

    # Unary minus, expressed as multiplication by -1 (goes through #*).
    def -@
      self * -1
    end

    # @return the result of DNN::Layers::Add called with self and +other+
    def +(other)
      Layers::Add.call(self, other)
    end

    # @return the result of DNN::Layers::Sub called with self and +other+
    def -(other)
      Layers::Sub.call(self, other)
    end

    # @return the result of DNN::Layers::Mul called with self and +other+
    def *(other)
      Layers::Mul.call(self, other)
    end

    # @return the result of DNN::Layers::Div called with self and +other+
    def /(other)
      Layers::Div.call(self, other)
    end

    # @param index exponent handed to DNN::Layers::Pow's constructor
    # @return the result of applying that Pow operation to self
    def **(index)
      Layers::Pow.new(index).call(self)
    end
  end
end

Version data entries

2 entries across 2 versions & 1 rubygems

Version Path
ruby-dnn-0.16.2 lib/dnn/core/param.rb
ruby-dnn-0.16.1 lib/dnn/core/param.rb