Sha256: 27f95a0019d6069fdbd8815a304536df1558601cd9c1f35bd42055c1d8d630dc

Contents?: true

Size: 849 Bytes

Versions: 5

Compression:

Stored size: 849 Bytes

Contents

module DNN
  # A node in the computation graph. It wires a layer node to its upstream
  # links (+prevs+) and to a single downstream link (+next+).
  #
  # Both passes buffer incoming values until every expected value has
  # arrived, then invoke the wrapped layer node once with all of them:
  # * #forward waits for one input per entry in +prevs+;
  # * #backward waits for +num_outputs+ gradients.
  #
  # NOTE(review): forward and backward share the same buffer, so a backward
  # pass must not start while a forward pass is still partially buffered.
  class Link
    # @return [Array<Link, nil>, nil] upstream links. A nil entry means that
    #   input has no link to send its gradient back to. A nil list is
    #   treated as empty.
    attr_accessor :prevs
    # @return [Link, nil] downstream link that receives this link's forward output.
    attr_accessor :next
    # @return [Object] wrapped node; must respond to +call+ (forward) and
    #   +backward_node+ (backward).
    attr_accessor :layer_node
    # @return [Integer] number of gradients #backward collects before firing.
    attr_reader :num_outputs

    # @param prevs [Array<Link, nil>, nil] upstream links, one per forward input.
    # @param layer_node [Object] node invoked once all inputs/gradients arrive.
    # @param num_outputs [Integer] gradients to collect before backward fires.
    def initialize(prevs: nil, layer_node: nil, num_outputs: 1)
      @prevs = prevs
      @layer_node = layer_node
      @num_outputs = num_outputs
      @next = nil
      @hold = []
    end

    # Buffers one forward input. Once one input per upstream link has been
    # received, the layer node is called with all of them in arrival order.
    #
    # @param x [Object] one forward input.
    # @return [Object, nil] nil while still waiting for inputs. Otherwise
    #   returns the layer output, or whatever +next.forward+ returns when a
    #   downstream link is attached.
    def forward(x)
      @hold << x
      return if @hold.length < prev_links.length

      y = @layer_node.(*flush_hold)
      @next ? @next.forward(y) : y
    end

    # Buffers one gradient. Once +num_outputs+ gradients have been received,
    # they are passed to +layer_node.backward_node+ and the resulting
    # gradient(s) are sent to the upstream links.
    #
    # When +backward_node+ returns an Array, element i goes to +prevs[i]+.
    # Otherwise the single result goes to the first upstream link. A nil
    # upstream entry silently receives nothing.
    #
    # @param dy [Object] incoming gradient; defaults to +Xumo::SFloat[1]+
    #   (presumably a one-element Numo/Cumo float array used as the seed
    #   gradient — confirm against the Xumo alias definition).
    # @return [Object, nil] nil while still waiting for gradients.
    def backward(dy = Xumo::SFloat[1])
      @hold << dy
      return if @hold.length < @num_outputs

      dys = @layer_node.backward_node(*flush_hold)
      links = prev_links
      # Numo arrays are not Ruby Arrays, so an explicit class check is needed
      # to tell "one gradient per input" apart from "a single gradient tensor".
      if dys.is_a?(Array)
        dys.each_with_index { |grad, i| links[i]&.backward(grad) }
      else
        links.first&.backward(dys)
      end
    end

    private

    # Upstream links as an Array. A nil +prevs+ is treated as no links.
    def prev_links
      Array(@prevs)
    end

    # Returns the buffered values and empties the buffer *before* the layer
    # node runs. That way an exception raised by the node cannot leave stale
    # values behind to be mixed into the next pass.
    def flush_hold
      held = @hold
      @hold = []
      held
    end
  end
end

Version data entries

5 entries across 5 versions & 1 rubygem

Version Path
ruby-dnn-1.3.0 lib/dnn/core/link.rb
ruby-dnn-1.2.3 lib/dnn/core/link.rb
ruby-dnn-1.2.2 lib/dnn/core/link.rb
ruby-dnn-1.2.1 lib/dnn/core/link.rb
ruby-dnn-1.2.0 lib/dnn/core/link.rb