lib/torch/nn/module.rb in torch-rb-0.1.8 vs lib/torch/nn/module.rb in torch-rb-0.2.0
- old
+ new
@@ -32,10 +32,31 @@
def _apply(fn)
children.each do |mod|
mod._apply(fn)
end
+
+ instance_variables.each do |key|
+ param = instance_variable_get(key)
+ if param.is_a?(Parameter)
+ param_applied = nil
+ Torch.no_grad do
+ param_applied = fn.call(param)
+ end
+ # TODO should_use_set_data
+ instance_variable_set(key, Parameter.new(param_applied, requires_grad: param.requires_grad))
+
+ if param.grad
+ grad_applied = nil
+ Torch.no_grad do
+ grad_applied = fn.call(param.grad)
+ end
+ # TODO should_use_set_data
+ instance_variable_get(key).grad = grad_applied.requires_grad!(param.grad.requires_grad)
+ end
+ end
+ end
# TODO apply to more objects
self
end
def apply(fn)
@@ -109,10 +130,10 @@
instance_variables.each do |name|
param = instance_variable_get(name)
params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
end
@parameters.each do |name, param|
- params[[prefix, name].join] = param
+ params[[prefix, name].join] = param if param
end
params
end
def buffers