Sha256: b9707d50f1c8096b6e475f46a7bfcbe0db217e678e8068a4eadd3fce5dac34d4

Contents?: true

Size: 1.33 KB

Versions: 11

Compression:

Stored size: 1.33 KB

Contents

module DNN
  # A value container pairing a data array with the gradient accumulated for it.
  #
  # Instances can be used directly in arithmetic expressions: each operator
  # hands its operands to the corresponding class under +Layers+ and returns
  # whatever that call returns.
  class Param
    # @return [Boolean] when false, #backward discards the incoming gradient.
    attr_accessor :trainable
    # @return [Object, nil] the parameter values; must respond to +shape+ for
    #   #shape and #backward to work.
    attr_accessor :data
    # @return [Object, nil] the gradient accumulated so far (nil until one is
    #   supplied to the constructor or produced by #backward).
    attr_accessor :grad

    # @param data [Object, nil] initial parameter values.
    # @param grad [Object, nil] initial gradient.
    # New params are always created trainable.
    def initialize(data = nil, grad = nil)
      @data = data
      @grad = grad
      @trainable = true
    end

    # Accumulates +grad+ into the stored gradient.
    #
    # Two layouts are accepted: +grad+ has exactly the shape of the data, or it
    # has one extra leading axis (presumably the batch axis — confirm against
    # callers), in which case it is summed over axis 0 before being added.
    #
    # A non-trainable param does not accumulate; its gradient is reset to
    # +Xumo::SFloat[0]+ instead.
    #
    # @param grad [Object] incoming gradient; must respond to +shape+ and +sum+.
    # @return [Object] the updated value of the stored gradient.
    # @raise [DNNError] if +grad+ matches neither accepted layout.
    def backward(grad)
      unless @trainable
        @grad = Xumo::SFloat[0]
        return @grad
      end

      # Start from a one-element zero so the first += takes the shape of the
      # incoming gradient.
      @grad ||= Xumo::SFloat[0]
      if @data.shape == grad.shape
        @grad += grad
      elsif @data.shape == grad.shape[1..-1]
        @grad += grad.sum(0)
      else
        raise DNNError, "Shape is mismatch. (data: #{@data.shape}, grad: #{grad.shape})"
      end
    end

    # @return [Object] the shape reported by the underlying data.
    def shape
      @data.shape
    end

    # Unary plus is the identity.
    # @return [Param] self
    def +@
      self
    end

    # Unary minus, delegated to the Neg layer.
    # NOTE(review): Neg is referenced through Layers, like Add/Sub/Mul/Div/Pow
    # below; this assumes it is defined in DNN::Layers — confirm.
    def -@
      Layers::Neg.(self)
    end

    # @param other [Tensor, Param, Object] right-hand operand; anything that is
    #   not a Tensor or Param is first passed through Tensor.convert.
    # @return [Object] the result of Layers::Add.call(self, other).
    def +(other)
      Layers::Add.(self, to_operand(other))
    end

    # @param other [Tensor, Param, Object] see #+.
    # @return [Object] the result of Layers::Sub.call(self, other).
    def -(other)
      Layers::Sub.(self, to_operand(other))
    end

    # @param other [Tensor, Param, Object] see #+.
    # @return [Object] the result of Layers::Mul.call(self, other).
    def *(other)
      Layers::Mul.(self, to_operand(other))
    end

    # @param other [Tensor, Param, Object] see #+.
    # @return [Object] the result of Layers::Div.call(self, other).
    def /(other)
      Layers::Div.(self, to_operand(other))
    end

    # @param index [Object] the exponent, handed to Layers::Pow.new as-is.
    # @return [Object] the result of calling the Pow layer on self.
    def **(index)
      Layers::Pow.new(index).(self)
    end

    private

    # Returns +other+ untouched when it is already a Tensor or Param; any other
    # value is wrapped via Tensor.convert so the binary layers receive a
    # uniform operand type.
    def to_operand(other)
      case other
      when Tensor, Param then other
      else Tensor.convert(other)
      end
    end
  end
end

Version data entries

11 entries across 11 versions & 1 rubygem

Version Path
ruby-dnn-1.3.0 lib/dnn/core/param.rb
ruby-dnn-1.2.3 lib/dnn/core/param.rb
ruby-dnn-1.2.2 lib/dnn/core/param.rb
ruby-dnn-1.2.1 lib/dnn/core/param.rb
ruby-dnn-1.2.0 lib/dnn/core/param.rb
ruby-dnn-1.1.6 lib/dnn/core/param.rb
ruby-dnn-1.1.5 lib/dnn/core/param.rb
ruby-dnn-1.1.4 lib/dnn/core/param.rb
ruby-dnn-1.1.3 lib/dnn/core/param.rb
ruby-dnn-1.1.2 lib/dnn/core/param.rb
ruby-dnn-1.1.1 lib/dnn/core/param.rb