lib/torch/nn/utils.rb in torch-rb-0.8.1 vs lib/torch/nn/utils.rb in torch-rb-0.8.2

- old
+ new

@@ -18,8 +18,24 @@ end def _ntuple(n, value) value.is_a?(Array) ? value : [value] * n end + + def _clones(mod, n) + state = mod.state_dict + layers = n.times.map do |i| + mod.clone.tap { |l| l.load_state_dict(state) } + end + ModuleList.new(layers) + end + + def _activation_fn(activation) + case activation.to_sym + when :relu then F.method(:relu) + when :gelu then F.method(:gelu) + else raise ArgumentError, "Activation should be relu/gelu, not `#{activation}`" + end + end end end end