Sha256: a05aac9f1f577ee39ee735005ecf34a32efc548abcd87a8f8f0fbb12b72683ee

Size: 874 Bytes

Versions: 5

Contents

module DNN
  module Layers

    class Split < Layer
      include LayerNode

      attr_reader :axis
      attr_reader :dim

      def initialize(axis: 1, dim: nil)
        super()
        raise DNNError, "dim is nil" if dim.nil?
        @axis = axis
        @dim = dim
      end

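      # Splits x along @axis into y1 (the first @dim entries) and y2 (the rest).
      def forward_node(x)
        x1_dim = @dim
        x2_dim = x.shape[@axis] - @dim
        # Splitting at [dim, dim + rest] yields three sections, the last one
        # empty; the two-target assignment keeps only the first two.
        y1, y2others = x.split([x1_dim, x1_dim + x2_dim], axis: @axis)
        # Defensive: join the remainder if it ever comes back as several pieces.
        y2 = y2others.is_a?(Array) ? y2others[0].concatenate(y2others[1..-1], axis: @axis) : y2others
        [y1, y2]
      end

      # The gradient of a split is the two upstream gradients joined back
      # together along the same axis.
      def backward_node(dy1, dy2)
        dy1.concatenate(dy2, axis: @axis)
      end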

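      # Serializes the layer's configuration (split axis and size of y1).
      def to_hash
        super(axis: @axis, dim: @dim)
      end

      # Rebuilds the layer's configuration from a hash produced by to_hash.
      def load_hash(hash)
        initialize(axis: hash[:axis], dim: hash[:dim])
      end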
    end

  end
end
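The Split layer's behaviour can be illustrated without Numo. The following is a minimal plain-Ruby sketch, assuming a 2-D input stored as an Array of rows and a split along axis 1 (columns). The helper names `split_at`, `split_forward` and `split_backward` are hypothetical and not part of ruby-dnn; `split_at` only mimics how we understand Numo's `split` to treat an explicit index list. It shows why `forward_node` can destructure just two results from a split at `[dim, total]` (the third section is empty), and why `backward_node` is a plain concatenation.

```ruby
# Hypothetical helper: cut a row at explicit indices, the way a split
# with an index list is assumed to work. Cutting a length-n row at
# [d, n] gives three sections, the last of which is empty.
def split_at(row, indices)
  bounds = [0] + indices + [row.size]
  bounds.each_cons(2).map { |a, b| row[a...b] }
end

# Hypothetical forward pass: the first `dim` columns go to y1,
# the remaining columns go to y2.
def split_forward(x, dim)
  raise ArgumentError, "dim is nil" if dim.nil?
  y1 = x.map { |row| row[0...dim] }
  y2 = x.map { |row| row[dim..-1] }
  [y1, y2]
end

# Hypothetical backward pass: the gradient of a split is the two
# upstream gradients concatenated back along the same axis.
def split_backward(dy1, dy2)
  dy1.zip(dy2).map { |a, b| a + b }
end

p split_at([1, 2, 3, 4], [1, 4])  # => [[1], [2, 3, 4], []]
first, rest = split_at([1, 2, 3, 4], [1, 4])
p rest                            # => [2, 3, 4]

x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
y1, y2 = split_forward(x, 1)
p y1                              # => [[1], [5]]
p y2                              # => [[2, 3, 4], [6, 7, 8]]
p split_backward(y1, y2) == x     # => true
```

Because the two output gradients are simply joined back along the split axis, the layer has no trainable parameters and its backward pass does not need the cached input.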

Version data entries

5 entries across 5 versions & 1 rubygem

Version Path
ruby-dnn-1.3.0 lib/dnn/core/layers/split_layers.rb
ruby-dnn-1.2.3 lib/dnn/core/layers/split_layers.rb
ruby-dnn-1.2.2 lib/dnn/core/layers/split_layers.rb
ruby-dnn-1.2.1 lib/dnn/core/layers/split_layers.rb
ruby-dnn-1.2.0 lib/dnn/core/layers/split_layers.rb