lib/torch/nn/conv2d.rb in torch-rb-0.1.2 vs lib/torch/nn/conv2d.rb in torch-rb-0.1.3

- old
+ new

@@ -1,37 +1,25 @@ module Torch module NN - class Conv2d < Module + class Conv2d < ConvNd attr_reader :bias, :weight - def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0) #, dilation: 1, groups: 1) - @in_channels = in_channels - @out_channels = out_channels - @kernel_size = pair(kernel_size) - @stride = pair(stride) - @padding = pair(padding) - # @dilation = pair(dilation) - - # TODO divide by groups - @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size)) - @bias = Parameter.new(Tensor.new(out_channels)) - - reset_parameters + def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros") + kernel_size = pair(kernel_size) + stride = pair(stride) + padding = pair(padding) + dilation = pair(dilation) + super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode) end - def reset_parameters - Init.kaiming_uniform_(@weight, Math.sqrt(5)) - if @bias - fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight) - bound = 1 / Math.sqrt(fan_in) - Init.uniform_(@bias, -bound, bound) + def forward(input) + if @padding_mode == "circular" + raise NotImplementedError end + F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups) end - def call(input) - F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding) #, @dilation, @groups) - end - + # TODO add more parameters def inspect "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})" end private