Sha256: 41db24aba7706c6347c76d796152f9c769aba11f5a3920df6877da687adf9e29

Contents?: true

Size: 806 Bytes

Versions: 1

Compression:

Stored size: 806 Bytes

Contents

module DNN
  # Single-input link in a chain used for the backward pass.
  #
  # Holds a reference to the preceding link (+prev+) and to the layer node
  # (+layer_node+) whose +backward_node+ is invoked when walking the chain.
  class Link
    attr_accessor :prev, :layer_node

    # @param prev [Object, nil] preceding link; must respond to +backward+ when set.
    # @param layer_node [Object, nil] object responding to +backward_node+.
    def initialize(prev = nil, layer_node = nil)
      @layer_node = layer_node
      @prev = prev
    end

    # Passes +dy+ through this link's layer node, then hands the result to the
    # preceding link, if there is one.
    #
    # @param dy incoming value for the backward pass (defaults to Numo::SFloat[1]).
    # @return whatever the preceding link's +backward+ returns, or nil when
    #   there is no preceding link.
    def backward(dy = Numo::SFloat[1])
      upstream = @layer_node.backward_node(dy)
      # End of the chain: nothing further to propagate to.
      return nil if @prev.nil?

      @prev.backward(upstream)
    end
  end

  # Two-input link in a chain used for the backward pass.
  #
  # Holds references to two preceding links (+prev1+, +prev2+) and to the
  # layer node (+layer_node+) whose +backward_node+ is invoked when walking
  # the chain.
  class TwoInputLink
    attr_accessor :prev1, :prev2, :layer_node

    # @param prev1 [Object, nil] first preceding link; must respond to +backward+ when set.
    # @param prev2 [Object, nil] second preceding link; must respond to +backward+ when set.
    # @param layer_node [Object, nil] object responding to +backward_node+.
    def initialize(prev1 = nil, prev2 = nil, layer_node = nil)
      @layer_node = layer_node
      @prev1 = prev1
      @prev2 = prev2
    end

    # Passes +dy+ through this link's layer node and forwards the results to
    # the preceding links.
    #
    # When +backward_node+ returns an Array, its first two elements go to
    # +prev1+ and +prev2+ respectively. Any other return value is forwarded
    # to +prev1+ only.
    #
    # @param dy incoming value for the backward pass (defaults to Numo::SFloat[1]).
    # @return whatever +prev2.backward+ returns, or nil when there is no
    #   second value to forward or no +prev2+.
    def backward(dy = Numo::SFloat[1])
      result = @layer_node.backward_node(dy)
      # Only a genuine Array is unpacked; anything else counts as one value.
      first_grad, second_grad = result.is_a?(Array) ? result : [result, nil]

      @prev1.backward(first_grad) unless @prev1.nil?
      return nil unless second_grad

      @prev2.backward(second_grad) unless @prev2.nil?
    end
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
ruby-dnn-1.0.0 lib/dnn/core/link.rb