Sha256: b9fc83453ebdc759d4090ee915ec25a42dbc6db78bff122ecf76a9b9764cd9ae

Contents?: true

Size: 1.28 KB

Versions: 4

Compression:

Stored size: 1.28 KB

Contents

module Torch
  module NN
    # 2D convolution layer.
    #
    # Scalar-or-array arguments (kernel_size, stride, padding, dilation) are
    # normalized with _pair before being forwarded to ConvNd; the actual
    # computation is delegated to F.conv2d in #forward.
    class Conv2d < ConvNd
      # @param in_channels [Integer] forwarded to ConvNd
      # @param out_channels [Integer] forwarded to ConvNd
      # @param kernel_size [Integer, Array<Integer>] expanded with _pair
      # @param stride [Integer, Array<Integer>] expanded with _pair (default 1)
      # @param padding [Integer, Array<Integer>] expanded with _pair (default 0)
      # @param dilation [Integer, Array<Integer>] expanded with _pair (default 1)
      # @param groups [Integer] forwarded to ConvNd (default 1)
      # @param bias [Boolean] forwarded to ConvNd (default true)
      # @param padding_mode [String] forwarded to ConvNd (default "zeros");
      #   #forward raises NotImplementedError for "circular"
      def initialize(in_channels, out_channels, kernel_size, stride: 1,
        padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # The `false, _pair(0)` positional args are presumably ConvNd's
        # `transposed` flag and `output_padding` — confirm against ConvNd.
        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _pair(0), groups, bias, padding_mode)
      end

      # Runs the convolution through F.conv2d with this layer's weight, bias,
      # stride, padding, dilation and groups.
      #
      # @param input [Tensor] input to convolve (shape requirements are those of F.conv2d)
      # @return [Tensor] the result of F.conv2d
      # @raise [NotImplementedError] if padding_mode is "circular"
      def forward(input)
        # @padding is handed straight to F.conv2d, so only padding that
        # F.conv2d applies itself is honored; "circular" would need an explicit
        # pad step first, which is not implemented. Any other non-"circular"
        # mode falls through unchanged.
        if @padding_mode == "circular"
          raise NotImplementedError, "Conv2d does not support padding_mode: \"circular\" yet"
        end
        F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
      end

      # Builds the parameter summary for this layer, presumably embedded by
      # Module#inspect. Always lists channels, kernel_size and stride; other
      # settings are appended only when they differ from their defaults.
      #
      # @return [String] the summary with values interpolated from #dict
      # TODO add more parameters
      def extra_inspect
        # Mutable buffer appended with << (instead of +=) to avoid allocating
        # a new String on every optional segment.
        s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
        s << ", padding: %{padding}" if @padding != [0] * @padding.size
        s << ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
        # Presumably always all-zero for Conv2d since initialize passes _pair(0);
        # kept so the summary matches transposed variants.
        s << ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
        s << ", groups: %{groups}" if @groups != 1
        s << ", bias: false" unless @bias
        s << ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
        # dict is assumed to return symbol-keyed values for every %{...} above.
        format(s, **dict)
      end
    end
  end
end

Version data entries

4 entries across 4 versions & 1 rubygems

Version Path
torch-rb-0.2.7 lib/torch/nn/conv2d.rb
torch-rb-0.2.6 lib/torch/nn/conv2d.rb
torch-rb-0.2.5 lib/torch/nn/conv2d.rb
torch-rb-0.2.4 lib/torch/nn/conv2d.rb