lib/torch/nn/conv2d.rb in torch-rb-0.1.1 vs lib/torch/nn/conv2d.rb in torch-rb-0.1.2

- old
+ new

@@ -1,17 +1,16 @@ module Torch module NN class Conv2d < Module attr_reader :bias, :weight - def initialize(in_channels, out_channels, kernel_size) #, stride: 1, padding: 0, dilation: 1, groups: 1) + def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0) #, dilation: 1, groups: 1) @in_channels = in_channels @out_channels = out_channels @kernel_size = pair(kernel_size) - @stride = pair(1) - # @stride = pair(stride) - # @padding = pair(padding) + @stride = pair(stride) + @padding = pair(padding) # @dilation = pair(dilation) # TODO divide by groups @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size)) @bias = Parameter.new(Tensor.new(out_channels)) @@ -27,10 +26,10 @@ Init.uniform_(@bias, -bound, bound) end end def call(input) - F.conv2d(input, @weight, @bias) # @stride, self.padding, self.dilation, self.groups) + F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding) #, @dilation, @groups) end def inspect "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})" end