Sha256: b5173809e1e02488c209383106afe219decb066c8f32eae54897aa81d95635c6
Contents?: true
Size: 845 Bytes
Versions: 4
Compression:
Stored size: 845 Bytes
Contents
module Torch
  module NN
    module Init
      class << self
        def calculate_fan_in_and_fan_out(tensor)
          dimensions = tensor.dim
          if dimensions < 2
            raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
          end

          if dimensions == 2
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
          else
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim > 2
              receptive_field_size = tensor[0][0].numel
            end
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
          end

          [fan_in, fan_out]
        end
      end
    end
  end
end
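The same fan-in/fan-out rule can be re-expressed in plain Ruby on a shape array. This makes the 2-D and N-D cases easy to check without installing torch-rb. This is a minimal sketch under a few assumptions. The name `fan_in_and_fan_out` is hypothetical. Passing a shape array such as `[out, in, k_h, k_w]` instead of a tensor is an illustrative choice. Raising `ArgumentError` in place of torch-rb's `Error` is also a stand-in. The expression `shape.drop(2).reduce(1, :*)` plays the role of `tensor[0][0].numel`, which is the product of the trailing kernel dimensions.

```ruby
# Hypothetical helper (not part of torch-rb): mirrors calculate_fan_in_and_fan_out,
# but takes a plain shape array so it runs without a tensor library.
def fan_in_and_fan_out(shape)
  if shape.length < 2
    # ArgumentError is an assumption standing in for torch-rb's Error class.
    raise ArgumentError, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
  end

  num_output_fmaps = shape[0] # corresponds to tensor.size(0)
  num_input_fmaps = shape[1]  # corresponds to tensor.size(1)

  # Product of the trailing (kernel) dimensions; 1 for a 2-D weight matrix,
  # standing in for tensor[0][0].numel in the original.
  receptive_field_size = shape.drop(2).reduce(1, :*)

  [num_input_fmaps * receptive_field_size, num_output_fmaps * receptive_field_size]
end

# Linear-style weight: [out_features, in_features]
p fan_in_and_fan_out([20, 10])      # => [10, 20]
# Conv2d-style weight: [out_channels, in_channels, kernel_h, kernel_w]
p fan_in_and_fan_out([16, 3, 5, 5]) # => [75, 400]
```

For a 2-D weight, the receptive field is 1, so the result reduces to `[size(1), size(0)]`. For a convolution weight, both counts are scaled by the kernel area.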
Version data entries
4 entries across 4 versions & 1 rubygem
Version | Path |
---|---|
torch-rb-0.1.3 | lib/torch/nn/init.rb |
torch-rb-0.1.2 | lib/torch/nn/init.rb |
torch-rb-0.1.1 | lib/torch/nn/init.rb |
torch-rb-0.1.0 | lib/torch/nn/init.rb |