lib/torch/nn/module.rb in torch-rb-0.1.4 vs lib/torch/nn/module.rb in torch-rb-0.1.5

- old
+ new

@@ -83,31 +83,62 @@ def state_dict raise NotImplementedYet end def parameters - params = [] + named_parameters.values + end + + def named_parameters(prefix: "", recurse: true) + params = {} + if recurse + named_children.each do |name, mod| + params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse)) + end + end instance_variables.each do |name| param = instance_variable_get(name) - params << param if param.is_a?(Parameter) + params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter) end - params + modules.flat_map { |_, mod| mod.parameters } + @parameters.each do |name, param| + params[[prefix, name].join] = param + end + params end + def buffers + named_buffers.values + end + + def named_buffers + @buffers || {} + end + def children - @modules.values + named_children.values end - def modules + def named_children modules = {} instance_variables.each do |name| mod = instance_variable_get(name) modules[name[1..-1]] = mod if mod.is_a?(Module) end - @modules.merge(modules) + @modules.each do |name, mod| + modules[name] = mod + end + modules end + def modules + named_modules.values + end + + def named_modules + {"" => self}.merge(named_children) + end + def train(mode = true) @training = mode children.each do |mod| mod.train(mode) end @@ -138,37 +169,53 @@ _apply ->(t) { t.share_memory! } end def inspect name = self.class.name.split("::").last - if modules.empty? + if children.empty? "#{name}(#{extra_inspect})" else str = String.new str << "#{name}(\n" - modules.each do |name, mod| + children.each do |name, mod| str << " (#{name}): #{mod.inspect}\n" end str << ")" end end def method_missing(method, *args, &block) - modules[method.to_s] || super + name = method.to_s + if named_parameters.key?(name) + named_parameters[name] + elsif named_buffers.key?(name) + named_buffers[name] + elsif named_modules.key?(name) + named_modules[name] + else + super + end end def respond_to?(method, include_private = false) - modules.key?(method.to_s) || super + name = method.to_s + named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super end private def extra_inspect nil end - def format(str, *vars) - str % vars.map(&:inspect) + def format(str, *vars, **options) + vars = + if vars.any? + vars.map(&:inspect) + else + options.map { |k, v| [k, v.inspect] }.to_h + end + str % vars end end end end