lib/torch/nn/conv2d.rb in torch-rb-0.1.5 vs lib/torch/nn/conv2d.rb in torch-rb-0.1.6

- old
+ new

@@ -1,35 +1,27 @@ module Torch module NN class Conv2d < ConvNd - def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros") - kernel_size = pair(kernel_size) - stride = pair(stride) - padding = pair(padding) - dilation = pair(dilation) - super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode) + def initialize(in_channels, out_channels, kernel_size, stride: 1, + padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros") + + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _pair(0), groups, bias, padding_mode) end def forward(input) if @padding_mode == "circular" raise NotImplementedError end - F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups) + F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups) end # TODO add more parameters def extra_inspect format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride) - end - - private - - def pair(value) - if value.is_a?(Array) - value - else - [value] * 2 - end end end end end