lib/torch/nn/module.rb in torch-rb-0.1.4 vs lib/torch/nn/module.rb in torch-rb-0.1.5
- old
+ new
@@ -83,31 +83,62 @@
def state_dict
raise NotImplementedYet
end
def parameters
- params = []
+ named_parameters.values
+ end
+
+ def named_parameters(prefix: "", recurse: true)
+ params = {}
+ if recurse
+ named_children.each do |name, mod|
+ params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+ end
+ end
instance_variables.each do |name|
param = instance_variable_get(name)
- params << param if param.is_a?(Parameter)
+ params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
end
- params + modules.flat_map { |_, mod| mod.parameters }
+ @parameters.each do |name, param|
+ params[[prefix, name].join] = param
+ end
+ params
end
+ def buffers
+ named_buffers.values
+ end
+
+ def named_buffers
+ @buffers || {}
+ end
+
def children
- @modules.values
+ named_children.values
end
- def modules
+ def named_children
modules = {}
instance_variables.each do |name|
mod = instance_variable_get(name)
modules[name[1..-1]] = mod if mod.is_a?(Module)
end
- @modules.merge(modules)
+ @modules.each do |name, mod|
+ modules[name] = mod
+ end
+ modules
end
+ def modules
+ named_modules.values
+ end
+
+ def named_modules
+ {"" => self}.merge(named_children)
+ end
+
def train(mode = true)
@training = mode
children.each do |mod|
mod.train(mode)
end
@@ -138,37 +169,53 @@
_apply ->(t) { t.share_memory! }
end
def inspect
name = self.class.name.split("::").last
- if modules.empty?
+ if children.empty?
"#{name}(#{extra_inspect})"
else
str = String.new
str << "#{name}(\n"
- modules.each do |name, mod|
+ children.each do |name, mod|
str << " (#{name}): #{mod.inspect}\n"
end
str << ")"
end
end
def method_missing(method, *args, &block)
- modules[method.to_s] || super
+ name = method.to_s
+ if named_parameters.key?(name)
+ named_parameters[name]
+ elsif named_buffers.key?(name)
+ named_buffers[name]
+ elsif named_modules.key?(name)
+ named_modules[name]
+ else
+ super
+ end
end
def respond_to?(method, include_private = false)
- modules.key?(method.to_s) || super
+ name = method.to_s
+ named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
end
private
def extra_inspect
nil
end
- def format(str, *vars)
- str % vars.map(&:inspect)
+ def format(str, *vars, **options)
+ vars =
+ if vars.any?
+ vars.map(&:inspect)
+ else
+ options.map { |k, v| [k, v.inspect] }.to_h
+ end
+ str % vars
end
end
end
end