Sha256: b5173809e1e02488c209383106afe219decb066c8f32eae54897aa81d95635c6

Contents?: true

Size: 845 Bytes

Versions: 4

Compression:

Stored size: 845 Bytes

Contents

module Torch
  module NN
    module Init
      class << self
        def calculate_fan_in_and_fan_out(tensor)
          dimensions = tensor.dim
          if dimensions < 2
            raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
          end

          if dimensions == 2
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
          else
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            # This branch only runs when dimensions > 2, so the receptive
            # field size is the element count of one kernel slice.
            receptive_field_size = tensor[0][0].numel
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
          end

          [fan_in, fan_out]
        end
      end
    end
  end
end
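
The fan computation above depends only on the tensor's shape, so it can be illustrated without loading Torch. The following is a minimal sketch under that assumption: `fan_in_and_fan_out_from_shape` is a hypothetical helper (not part of torch-rb) that takes a plain array of dimension sizes, and it raises `ArgumentError` rather than the library's own `Error` class.

```ruby
# Hypothetical helper, not part of torch-rb: computes fan-in and fan-out
# from a plain shape array instead of a tensor.
def fan_in_and_fan_out_from_shape(shape)
  if shape.length < 2
    raise ArgumentError, "Fan in and fan out can not be computed for shape with fewer than 2 dimensions"
  end

  # Product of every dimension after the first two (1 for a 2-D shape).
  # This mirrors tensor[0][0].numel in the listing above.
  receptive_field_size = shape.drop(2).reduce(1, :*)

  fan_in = shape[1] * receptive_field_size
  fan_out = shape[0] * receptive_field_size
  [fan_in, fan_out]
end

# Linear-style weight: 10 outputs, 5 inputs.
p fan_in_and_fan_out_from_shape([10, 5])        # => [5, 10]
# Convolution-style weight: 16 output maps, 3 input maps, 5x5 kernel.
p fan_in_and_fan_out_from_shape([16, 3, 5, 5])  # => [75, 400]
```

For the 2-D case the receptive field size is 1, so the result reduces to the column and row counts, matching the `dimensions == 2` branch in the listing.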

Version data entries

4 entries across 4 versions and 1 rubygem

Version Path
torch-rb-0.1.3 lib/torch/nn/init.rb
torch-rb-0.1.2 lib/torch/nn/init.rb
torch-rb-0.1.1 lib/torch/nn/init.rb
torch-rb-0.1.0 lib/torch/nn/init.rb