lib/torch/nn/init.rb in torch-rb-0.1.3 vs lib/torch/nn/init.rb in torch-rb-0.1.4
- old
+ new
@@ -1,9 +1,66 @@
module Torch
module NN
module Init
class << self
- def calculate_fan_in_and_fan_out(tensor)
+ def calculate_gain(nonlinearity, param: 0.01)
+ _calculate_gain(nonlinearity, param)
+ end
+
+ def uniform!(tensor, a: 0.0, b: 1.0)
+ _uniform!(tensor, a, b)
+ end
+
+ def normal!(tensor, mean: 0.0, std: 1.0)
+ _normal!(tensor, mean, std)
+ end
+
+ def constant!(tensor, val)
+ _constant!(tensor, val)
+ end
+
+ def ones!(tensor)
+ _ones!(tensor)
+ end
+
+ def zeros!(tensor)
+ _zeros!(tensor)
+ end
+
+ def eye!(tensor)
+ _eye!(tensor)
+ end
+
+ def dirac!(tensor)
+ _dirac!(tensor)
+ end
+
+ def xavier_uniform!(tensor, gain: 1.0)
+ _xavier_uniform!(tensor, gain)
+ end
+
+ def xavier_normal!(tensor, gain: 1.0)
+ _xavier_normal!(tensor, gain)
+ end
+
+ def kaiming_uniform!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+ _kaiming_uniform!(tensor, a, mode, nonlinearity)
+ end
+
+ def kaiming_normal!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+ _kaiming_normal!(tensor, a, mode, nonlinearity)
+ end
+
+ def orthogonal!(tensor, gain: 1)
+ _orthogonal!(tensor, gain)
+ end
+
+ def sparse!(tensor, sparsity, std: 0.01)
+ _sparse!(tensor, sparsity, std)
+ end
+
+ # TODO move to C++ when released
+ def _calculate_fan_in_and_fan_out(tensor)
dimensions = tensor.dim
if dimensions < 2
raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
end