lib/torch/nn/module.rb in torch-rb-0.5.2 vs lib/torch/nn/module.rb in torch-rb-0.5.3

- old
+ new

@@ -111,39 +111,57 @@
       def call(*input, **kwargs)
         forward(*input, **kwargs)
       end

-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
         destination ||= {}
-        named_parameters.each do |k, v|
-          destination[k] = v
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
         end
         destination
       end

-      # TODO add strict option
-      # TODO match PyTorch behavior
-      def load_state_dict(state_dict)
-        state_dict.each do |k, input_param|
-          k1, k2 = k.split(".", 2)
-          mod = named_modules[k1]
-          if mod.is_a?(Module)
-            param = mod.named_parameters[k2]
-            if param.is_a?(Parameter)
-              Torch.no_grad do
-                param.copy!(input_param)
-              end
-            else
-              raise Error, "Unknown parameter: #{k1}"
-            end
-          else
-            raise Error, "Unknown module: #{k1}"
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
           end
         end

-        # TODO return missing keys and unexpected keys
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
+          end
+        end
+
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
         nil
       end

       def parameters
         named_parameters.values
       end
@@ -297,9 +315,71 @@
       # used for format
       # remove tensors for performance
       # so we can skip call to inspect
       def dict
         instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
+      end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                "whose dimensions in the model are #{param.size} and " +
+                "whose dimensions in the checkpoint are #{input_param.size}, " +
+                "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
       end
     end
   end
 end
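For readers comparing the two versions, the sketch below shows how the reworked methods behave from the outside. It is illustrative only and not part of the diff: it assumes Torch::NN::Sequential, Torch::NN::Linear, Torch.zeros, and Torch::Error are available in this torch-rb version, and that Sequential names its children "0", "1", ..., so the exact key names may differ.

  # Usage sketch, not part of the diff; see the assumptions above.
  require "torch"

  model = Torch::NN::Sequential.new(
    Torch::NN::Linear.new(4, 3),
    Torch::NN::Linear.new(3, 2)
  )

  # In 0.5.3, state_dict recurses through named_children and prefixes each key
  # with the child's name, so nested parameters come back as keys such as
  # "0.weight", "0.bias", "1.weight", "1.bias".
  sd = model.state_dict
  p sd.keys

  # load_state_dict walks the same module tree, collecting missing and
  # unexpected keys; with strict: true (the only supported mode) the first
  # problem found is raised as an error.
  model.load_state_dict(sd)

  # A key the model does not have should now be reported instead of silently
  # ignored or mapped to a single-level "Unknown module" error.
  begin
    model.load_state_dict(sd.merge("3.weight" => Torch.zeros(2, 2)))
  rescue Torch::Error => e
    puts e.message # e.g. "Unexpected key(s) in state_dict: 3.weight"
  end

Splitting state_dict into save_to_state_dict plus recursion over named_children, and load_state_dict into a per-module load_from_state_dict pass, follows the structure PyTorch uses for its own state-dict handling; that split is what makes the dotted per-module key prefixes and the missing/unexpected key reporting possible.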