lib/torch/nn/module.rb in torch-rb-0.8.2 vs lib/torch/nn/module.rb in torch-rb-0.8.3
- old
+ new
@@ -278,10 +278,15 @@
end
str << ")"
end
end
+ def deep_dup
+ memo = {}
+ dup_value(self, memo)
+ end
+
def method_missing(method, *args, &block)
name = method.to_s
if named_parameters.key?(name)
named_parameters[name]
elsif named_buffers.key?(name)
@@ -384,9 +389,32 @@
named_parameters(recurse: false).each do |k, v|
destination[prefix + k] = v
end
named_buffers.each do |k, v|
destination[prefix + k] = v
+ end
+ end
+
+ # keep memo hash like Python deepcopy
+ # https://docs.python.org/3/library/copy.html
+ def dup_value(v, memo)
+ memo[v.object_id] ||= begin
+ case v
+ when Method, UnboundMethod
+ v
+ when Hash
+ v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] }
+ when Array
+ v.map { |v2| dup_value(v2, memo) }
+ when Torch::NN::Module
+ copy = v.dup
+ v.instance_variables.each do |var|
+ copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
+ end
+ copy
+ else
+ v.dup
+ end
end
end
end
end
end