lib/torch/nn/module.rb in torch-rb-0.2.3 vs lib/torch/nn/module.rb in torch-rb-0.2.4
- old
+ new
@@ -184,12 +184,26 @@
def modules
named_modules.values
end
- def named_modules
- {"" => self}.merge(named_children)
+ # TODO return enumerator?
+ def named_modules(memo: nil, prefix: "")
+ ret = {}
+ memo ||= Set.new
+ unless memo.include?(self)
+ memo << self
+ ret[prefix] = self
+ named_children.each do |name, mod|
+ next unless mod.is_a?(Module)
+ submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
+ mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
+ ret[m[0]] = m[1]
+ end
+ end
+ end
+ ret
end
def train(mode = true)
@training = mode
children.each do |mod|
@@ -228,10 +242,12 @@
"#{name}(#{extra_inspect})"
else
str = String.new
str << "#{name}(\n"
named_children.each do |name, mod|
- str << " (#{name}): #{mod.inspect}\n"
+ mod_str = mod.inspect
+ mod_str = mod_str.lines.join(" ")
+ str << " (#{name}): #{mod_str}\n"
end
str << ")"
end
end