Sha256: 8672faac2b887f4e1c2485ac055a0b4dfac869f8cfeb6a4c23d360ac32e0d65c

Contents?: true

Size: 1.04 KB

Versions: 1

Compression:

Stored size: 1.04 KB

Contents

module Torch
  module NN
    # Two-dimensional convolution layer, implemented as a thin subclass of
    # ConvNd.
    #
    # Scalar hyper-parameters (kernel size, stride, padding, dilation) are
    # expanded to two-element arrays before being handed to ConvNd.
    #
    # NOTE(review): @weight, @bias, @stride, @padding, @dilation, @groups,
    # @padding_mode, @in_channels, @out_channels and @kernel_size are read
    # here but never assigned in this class; they are presumably set by
    # ConvNd#initialize -- confirm against ConvNd.
    class Conv2d < ConvNd
      attr_reader :bias, :weight

      # Builds the layer by normalizing the spatial arguments and delegating
      # to ConvNd#initialize.
      #
      # @param in_channels forwarded unchanged to ConvNd
      # @param out_channels forwarded unchanged to ConvNd
      # @param kernel_size a single value (used for both dimensions) or an Array
      # @param stride a single value or an Array (default 1)
      # @param padding a single value or an Array (default 0)
      # @param dilation a single value or an Array (default 1)
      # @param groups forwarded unchanged to ConvNd (default 1)
      # @param bias forwarded unchanged to ConvNd (default true)
      # @param padding_mode forwarded unchanged to ConvNd (default "zeros")
      def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
        # NOTE(review): the positional `false` and `pair(0)` presumably map to
        # ConvNd's "transposed" flag and "output_padding" -- confirm against
        # ConvNd#initialize.
        super(
          in_channels,
          out_channels,
          pair(kernel_size),
          pair(stride),
          pair(padding),
          pair(dilation),
          false,
          pair(0),
          groups,
          bias,
          padding_mode
        )
      end

      # Applies the convolution to +input+ via F.conv2d, using this layer's
      # weight, bias, stride, padding, dilation and groups.
      #
      # @param input the value passed straight through to F.conv2d
      # @return whatever F.conv2d returns
      # @raise [NotImplementedError] when the padding mode is "circular"
      def forward(input)
        # Circular padding is not implemented in this layer; say so explicitly
        # instead of raising an exception with no message.
        if @padding_mode == "circular"
          raise NotImplementedError, "padding_mode \"circular\" is not implemented for Conv2d"
        end

        F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups)
      end

      # Short, human-readable summary of the layer's configuration: input and
      # output channel counts, kernel size and stride.
      #
      # TODO: add more parameters
      #
      # @return [String]
      def extra_inspect
        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
      end

      private

      # Normalizes a spatial argument: an Array is returned as-is (its length
      # is not checked), anything else is duplicated into a two-element Array.
      #
      # @param value a single value or an Array
      # @return [Array]
      def pair(value)
        value.is_a?(Array) ? value : [value, value]
      end
    end
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
torch-rb-0.1.4 lib/torch/nn/conv2d.rb