lib/torch/nn/utils.rb in torch-rb-0.8.1 vs lib/torch/nn/utils.rb in torch-rb-0.8.2
- old
+ new
@@ -18,8 +18,24 @@
end
def _ntuple(n, value)
value.is_a?(Array) ? value : [value] * n
end
+
+ def _clones(mod, n)
+ state = mod.state_dict
+ layers = n.times.map do |i|
+ mod.clone.tap { |l| l.load_state_dict(state) }
+ end
+ ModuleList.new(layers)
+ end
+
+ def _activation_fn(activation)
+ case activation.to_sym
+ when :relu then F.method(:relu)
+ when :gelu then F.method(:gelu)
+ else raise ArgumentError, "Activation should be relu/gelu, not `#{activation}`"
+ end
+ end
end
end
end