lib/torch/nn/module.rb in torch-rb-0.2.0 vs lib/torch/nn/module.rb in torch-rb-0.2.1
- old
+ new
@@ -65,12 +65,13 @@
end
fn.call(self)
self
end
- def cuda(device: nil)
- _apply ->(t) { t.cuda(device) }
+ # TODO add device
+ def cuda
+ _apply ->(t) { t.cuda }
end
def cpu
_apply ->(t) { t.cpu }
end
@@ -110,11 +111,31 @@
destination[k] = v
end
destination
end
+ # TODO add strict option
+ # TODO match PyTorch behavior
def load_state_dict(state_dict)
- raise NotImplementedYet
+ state_dict.each do |k, input_param|
+ k1, k2 = k.split(".", 2)
+ mod = named_modules[k1]
+ if mod.is_a?(Module)
+ param = mod.named_parameters[k2]
+ if param.is_a?(Parameter)
+ Torch.no_grad do
+ param.copy!(input_param)
+ end
+ else
+ raise Error, "Unknown parameter: #{k1}"
+ end
+ else
+ raise Error, "Unknown module: #{k1}"
+ end
+ end
+
+ # TODO return missing keys and unexpected keys
+ nil
end
def parameters
named_parameters.values
end