Sha256: c8569861d09c05c0cbd0a7d6283ba6a4cfcbe08e1e124531fbe473c691403564

Contents?: true

Size: 1.3 KB

Versions: 1

Compression:

Stored size: 1.3 KB

Contents

module Torch
  module NN
    class Conv2d < Module
      attr_reader :bias, :weight

      def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0) #, dilation: 1, groups: 1)
        @in_channels = in_channels
        @out_channels = out_channels
        @kernel_size = pair(kernel_size)
        @stride = pair(stride)
        @padding = pair(padding)
        # @dilation = pair(dilation)

        # TODO divide by groups
        @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
        @bias = Parameter.new(Tensor.new(out_channels))

        reset_parameters
      end

      def reset_parameters
        Init.kaiming_uniform_(@weight, Math.sqrt(5))
        if @bias
          fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
          bound = 1 / Math.sqrt(fan_in)
          Init.uniform_(@bias, -bound, bound)
        end
      end

      def call(input)
        F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding) #, @dilation, @groups)
      end

      def inspect
        "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
      end

      private

      def pair(value)
        if value.is_a?(Array)
          value
        else
          [value] * 2
        end
      end
    end
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
torch-rb-0.1.2 lib/torch/nn/conv2d.rb