lib/torch/nn/conv2d.rb in torch-rb-0.1.2 vs lib/torch/nn/conv2d.rb in torch-rb-0.1.3
- old
+ new
@@ -1,37 +1,25 @@
module Torch
module NN
- class Conv2d < Module
+ class Conv2d < ConvNd
attr_reader :bias, :weight
- def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0) #, dilation: 1, groups: 1)
- @in_channels = in_channels
- @out_channels = out_channels
- @kernel_size = pair(kernel_size)
- @stride = pair(stride)
- @padding = pair(padding)
- # @dilation = pair(dilation)
-
- # TODO divide by groups
- @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
- @bias = Parameter.new(Tensor.new(out_channels))
-
- reset_parameters
+ def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+ kernel_size = pair(kernel_size)
+ stride = pair(stride)
+ padding = pair(padding)
+ dilation = pair(dilation)
+ super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode)
end
- def reset_parameters
- Init.kaiming_uniform_(@weight, Math.sqrt(5))
- if @bias
- fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
- bound = 1 / Math.sqrt(fan_in)
- Init.uniform_(@bias, -bound, bound)
+ def forward(input)
+ if @padding_mode == "circular"
+ raise NotImplementedError
end
+ F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups)
end
- def call(input)
- F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding) #, @dilation, @groups)
- end
-
+ # TODO add more parameters
def inspect
"Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
end
private