lib/torch/nn/module.rb in torch-rb-0.8.2 vs lib/torch/nn/module.rb in torch-rb-0.8.3

- old
+ new

@@ -278,10 +278,15 @@ end str << ")" end end + def deep_dup + memo = {} + dup_value(self, memo) + end + def method_missing(method, *args, &block) name = method.to_s if named_parameters.key?(name) named_parameters[name] elsif named_buffers.key?(name) @@ -384,9 +389,32 @@ named_parameters(recurse: false).each do |k, v| destination[prefix + k] = v end named_buffers.each do |k, v| destination[prefix + k] = v + end + end + + # keep memo hash like Python deepcopy + # https://docs.python.org/3/library/copy.html + def dup_value(v, memo) + memo[v.object_id] ||= begin + case v + when Method, UnboundMethod + v + when Hash + v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] } + when Array + v.map { |v2| dup_value(v2, memo) } + when Torch::NN::Module + copy = v.dup + v.instance_variables.each do |var| + copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo)) + end + copy + else + v.dup + end end end end end end