lib/torch/nn/module.rb in torch-rb-0.2.3 vs lib/torch/nn/module.rb in torch-rb-0.2.4

- old
+ new

@@ -184,12 +184,26 @@ def modules named_modules.values end - def named_modules - {"" => self}.merge(named_children) + # TODO return enumerator? + def named_modules(memo: nil, prefix: "") + ret = {} + memo ||= Set.new + unless memo.include?(self) + memo << self + ret[prefix] = self + named_children.each do |name, mod| + next unless mod.is_a?(Module) + submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name + mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m| + ret[m[0]] = m[1] + end + end + end + ret end def train(mode = true) @training = mode children.each do |mod| @@ -228,10 +242,12 @@ "#{name}(#{extra_inspect})" else str = String.new str << "#{name}(\n" named_children.each do |name, mod| - str << " (#{name}): #{mod.inspect}\n" + mod_str = mod.inspect + mod_str = mod_str.lines.join(" ") + str << " (#{name}): #{mod_str}\n" end str << ")" end end