Sha256: 89034bef56a4e99e3ad4e76f43c7a2e57a3ea95e27d49b158a6e8b4ed0651fdc

Contents?: true

Size: 981 Bytes

Versions: 2

Compression:

Stored size: 981 Bytes

Contents

module DNN
  # A parameter value paired with its accumulated gradient.
  #
  # Arithmetic operators delegate to the corresponding classes under
  # DNN::Layers, so a Param can be used directly as an operand in
  # expressions such as `w * x + b`.
  class Param
    # @return [Boolean] when false, #backward discards incoming gradients.
    attr_accessor :trainable
    # @return [Object, nil] the parameter value; must respond to #shape.
    attr_accessor :data
    # @return [Object, nil] the gradient accumulated by #backward.
    attr_accessor :grad

    # @param data [Object, nil] initial parameter value.
    # @param grad [Object, nil] initial gradient (nil means "not yet computed").
    def initialize(data = nil, grad = nil)
      @data = data
      @grad = grad
      @trainable = true
    end

    # Accumulates +grad+ into #grad.
    #
    # - If +grad+ has exactly the same shape as #data, it is added as-is.
    # - If +grad+ has one extra leading axis (presumably a batch axis), it is
    #   summed over axis 0 before being added.
    # - If the parameter is not trainable, +grad+ is ignored and #grad is
    #   reset to Xumo::SFloat[0].
    #
    # @param grad [Object] incoming gradient; must respond to #shape and #sum.
    # @return [Object] the new value of #grad.
    # @raise [DNNError] if +grad+'s shape fits neither accumulation case.
    def backward(grad)
      if @trainable
        # Start from a scalar-like zero when no gradient exists yet. Adding a
        # full-shape gradient to it presumably relies on broadcasting — TODO confirm.
        @grad ||= Xumo::SFloat[0]
        if @data.shape == grad.shape
          @grad += grad
        elsif @data.shape == grad.shape[1..-1]
          @grad += grad.sum(0)
        else
          raise DNNError, "Shape is missmatch."
        end
      else
        @grad = Xumo::SFloat[0]
      end
    end

    # @return [Array<Integer>] the shape of #data.
    def shape
      @data.shape
    end

    # Unary plus is the identity.
    # @return [Param] self.
    def +@
      self
    end

    # Unary minus, delegated to Layers::Neg.
    # The constant must be qualified like the other operators: a bare `Neg` is
    # resolved in DNN::Param, DNN and the top level, never in DNN::Layers.
    def -@
      Layers::Neg.(self)
    end

    # @param other [Object] right-hand operand.
    # @return the result of Layers::Add for (self, other).
    def +(other)
      Layers::Add.(self, other)
    end

    # @param other [Object] right-hand operand.
    # @return the result of Layers::Sub for (self, other).
    def -(other)
      Layers::Sub.(self, other)
    end

    # @param other [Object] right-hand operand.
    # @return the result of Layers::Mul for (self, other).
    def *(other)
      Layers::Mul.(self, other)
    end

    # @param other [Object] right-hand operand.
    # @return the result of Layers::Div for (self, other).
    def /(other)
      Layers::Div.(self, other)
    end

    # @param index [Object] the exponent, passed to Layers::Pow's constructor.
    # @return the result of applying Layers::Pow.new(index) to self.
    def **(index)
      Layers::Pow.new(index).(self)
    end
  end
end

Version data entries

2 entries across 2 versions & 1 rubygems

Version Path
ruby-dnn-1.1.0 lib/dnn/core/param.rb
ruby-dnn-1.0.0 lib/dnn/core/param.rb