lib/torch/nn/module.rb in torch-rb-0.17.0 vs lib/torch/nn/module.rb in torch-rb-0.17.1
- old
+ new
@@ -8,20 +8,27 @@
def initialize
@training = true
@parameters = {}
@buffers = {}
@modules = {}
+ @non_persistent_buffers_set = Set.new
end
def forward
raise NotImplementedError
end
- def register_buffer(name, tensor)
+ def register_buffer(name, tensor, persistent: true)
# TODO add checks
@buffers[name] = tensor
instance_variable_set("@#{name}", tensor)
+
+ if persistent
+ @non_persistent_buffers_set.delete(name)
+ else
+ @non_persistent_buffers_set << name
+ end
end
def register_parameter(name, param)
# TODO add checks
@parameters[name] = param
@@ -188,12 +195,22 @@
def buffers
named_buffers.values
end
- def named_buffers
- @buffers || {}
+ # TODO set recurse: true in 0.18.0
+ def named_buffers(prefix: "", recurse: false)
+ buffers = {}
+ if recurse
+ named_children.each do |name, mod|
+ buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
+ end
+ end
+ (@buffers || {}).each do |k, v|
+ buffers[[prefix, k].join] = v
+ end
+ buffers
end
def children
named_children.values
end
@@ -388,10 +405,13 @@
def save_to_state_dict(destination, prefix: "")
named_parameters(recurse: false).each do |k, v|
destination[prefix + k] = v
end
named_buffers.each do |k, v|
- destination[prefix + k] = v
+ # TODO exclude v.nil?
+ if !@non_persistent_buffers_set.include?(k)
+ destination[prefix + k] = v
+ end
end
end
# keep memo hash like Python deepcopy
# https://docs.python.org/3/library/copy.html