lib/torch/nn/conv2d.rb in torch-rb-0.1.1 vs lib/torch/nn/conv2d.rb in torch-rb-0.1.2
- old
+ new
@@ -1,17 +1,16 @@
module Torch
module NN
class Conv2d < Module
attr_reader :bias, :weight
- def initialize(in_channels, out_channels, kernel_size) #, stride: 1, padding: 0, dilation: 1, groups: 1)
+ def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0) #, dilation: 1, groups: 1)
@in_channels = in_channels
@out_channels = out_channels
@kernel_size = pair(kernel_size)
- @stride = pair(1)
- # @stride = pair(stride)
- # @padding = pair(padding)
+ @stride = pair(stride)
+ @padding = pair(padding)
# @dilation = pair(dilation)
# TODO divide by groups
@weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
@bias = Parameter.new(Tensor.new(out_channels))
@@ -27,10 +26,10 @@
Init.uniform_(@bias, -bound, bound)
end
end
def call(input)
- F.conv2d(input, @weight, @bias) # @stride, self.padding, self.dilation, self.groups)
+ F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding) #, @dilation, @groups)
end
def inspect
"Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
end