lib/torch/nn/module.rb in torch-rb-0.17.0 vs lib/torch/nn/module.rb in torch-rb-0.17.1

- old
+ new

@@ -8,20 +8,27 @@ def initialize @training = true @parameters = {} @buffers = {} @modules = {} + @non_persistent_buffers_set = Set.new end def forward raise NotImplementedError end - def register_buffer(name, tensor) + def register_buffer(name, tensor, persistent: true) # TODO add checks @buffers[name] = tensor instance_variable_set("@#{name}", tensor) + + if persistent + @non_persistent_buffers_set.delete(name) + else + @non_persistent_buffers_set << name + end end def register_parameter(name, param) # TODO add checks @parameters[name] = param @@ -188,12 +195,22 @@ def buffers named_buffers.values end - def named_buffers - @buffers || {} + # TODO set recurse: true in 0.18.0 + def named_buffers(prefix: "", recurse: false) + buffers = {} + if recurse + named_children.each do |name, mod| + buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse)) + end + end + (@buffers || {}).each do |k, v| + buffers[[prefix, k].join] = v + end + buffers end def children named_children.values end @@ -388,10 +405,13 @@ def save_to_state_dict(destination, prefix: "") named_parameters(recurse: false).each do |k, v| destination[prefix + k] = v end named_buffers.each do |k, v| - destination[prefix + k] = v + # TODO exclude v.nil? + if !@non_persistent_buffers_set.include?(k) + destination[prefix + k] = v + end end end # keep memo hash like Python deepcopy # https://docs.python.org/3/library/copy.html