lib/torch/nn/module.rb in torch-rb-0.1.8 vs lib/torch/nn/module.rb in torch-rb-0.2.0

- old
+ new

@@ -32,10 +32,31 @@ def _apply(fn) children.each do |mod| mod._apply(fn) end + + instance_variables.each do |key| + param = instance_variable_get(key) + if param.is_a?(Parameter) + param_applied = nil + Torch.no_grad do + param_applied = fn.call(param) + end + # TODO should_use_set_data + instance_variable_set(key, Parameter.new(param_applied, requires_grad: param.requires_grad)) + + if param.grad + grad_applied = nil + Torch.no_grad do + grad_applied = fn.call(param.grad) + end + # TODO should_use_set_data + instance_variable_get(key).grad = grad_applied.requires_grad!(param.grad.requires_grad) + end + end + end # TODO apply to more objects self end def apply(fn) @@ -109,10 +130,10 @@ instance_variables.each do |name| param = instance_variable_get(name) params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter) end @parameters.each do |name, param| - params[[prefix, name].join] = param + params[[prefix, name].join] = param if param end params end def buffers