lib/torch/nn/module.rb in torch-rb-0.2.0 vs lib/torch/nn/module.rb in torch-rb-0.2.1

- old
+ new

@@ -65,12 +65,13 @@ end fn.call(self) self end - def cuda(device: nil) - _apply ->(t) { t.cuda(device) } + # TODO add device + def cuda + _apply ->(t) { t.cuda } end def cpu _apply ->(t) { t.cpu } end @@ -110,11 +111,31 @@ destination[k] = v end destination end + # TODO add strict option + # TODO match PyTorch behavior def load_state_dict(state_dict) - raise NotImplementedYet + state_dict.each do |k, input_param| + k1, k2 = k.split(".", 2) + mod = named_modules[k1] + if mod.is_a?(Module) + param = mod.named_parameters[k2] + if param.is_a?(Parameter) + Torch.no_grad do + param.copy!(input_param) + end + else + raise Error, "Unknown parameter: #{k1}" + end + else + raise Error, "Unknown module: #{k1}" + end + end + + # TODO return missing keys and unexpected keys + nil end def parameters named_parameters.values end