lib/torch/nn/module.rb in torch-rb-0.5.2 vs lib/torch/nn/module.rb in torch-rb-0.5.3
- old
+ new
@@ -111,39 +111,57 @@
      def call(*input, **kwargs)
        forward(*input, **kwargs)
      end
-       def state_dict(destination: nil)
+       def state_dict(destination: nil, prefix: "")
        destination ||= {}
-         named_parameters.each do |k, v|
-           destination[k] = v
+         save_to_state_dict(destination, prefix: prefix)
+
+         named_children.each do |name, mod|
+           next unless mod
+           mod.state_dict(destination: destination, prefix: prefix + name + ".")
        end
        destination
      end
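With the new prefix: keyword, state_dict now recurses into child modules and flattens their parameters (and buffers, via save_to_state_dict further down) into a single hash with dotted keys. A minimal sketch of the resulting keys, assuming child modules registered as instance variables; the Net class here is hypothetical, not part of the gem:

require "torch"

class Net < Torch::NN::Module
  def initialize
    super()
    @fc1 = Torch::NN::Linear.new(10, 5)
    @fc2 = Torch::NN::Linear.new(5, 1)
  end

  def forward(x)
    @fc2.call(Torch::NN::F.relu(@fc1.call(x)))
  end
end

net = Net.new
# children are visited with prefixes "fc1." and "fc2.", so the flat hash
# is expected to contain keys like:
net.state_dict.keys
# => ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]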
-       # TODO add strict option
-       # TODO match PyTorch behavior
-       def load_state_dict(state_dict)
-         state_dict.each do |k, input_param|
-           k1, k2 = k.split(".", 2)
-           mod = named_modules[k1]
-           if mod.is_a?(Module)
-             param = mod.named_parameters[k2]
-             if param.is_a?(Parameter)
-               Torch.no_grad do
-                 param.copy!(input_param)
-               end
-             else
-               raise Error, "Unknown parameter: #{k1}"
-             end
-           else
-             raise Error, "Unknown module: #{k1}"
+       def load_state_dict(state_dict, strict: true)
+         # TODO support strict: false
+         raise "strict: false not implemented yet" unless strict
+
+         missing_keys = []
+         unexpected_keys = []
+         error_msgs = []
+
+         # TODO handle metadata
+
+         _load = lambda do |mod, prefix = ""|
+           # TODO handle metadata
+           local_metadata = {}
+           mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+           mod.named_children.each do |name, child|
+             _load.call(child, prefix + name + ".") unless child.nil?
          end
        end
-         # TODO return missing keys and unexpected keys
+         _load.call(self)
+
+         if strict
+           if unexpected_keys.any?
+             error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+           end
+
+           if missing_keys.any?
+             error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
+           end
+         end
+
+         if error_msgs.any?
+           # just show first error
+           raise Error, error_msgs[0]
+         end
+
        nil
      end
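load_state_dict now follows the PyTorch contract more closely: it walks the module tree with the same dotted prefixes, collects missing_keys, unexpected_keys, and per-parameter error_msgs, and raises on the first problem when strict is true, which is the only mode supported so far. A round-trip sketch, reusing the hypothetical Net class from above:

trained = Net.new
fresh = Net.new

# copy every parameter from one instance into the other
fresh.load_state_dict(trained.state_dict)

# with strict: true (the default), an extra key becomes an error
bad = trained.state_dict.merge("fc3.weight" => Torch.zeros(1))
begin
  fresh.load_state_dict(bad)
rescue => e
  # expected message (raised as Torch::Error in this version):
  # Unexpected key(s) in state_dict: fc3.weight
  puts e.message
end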
      def parameters
        named_parameters.values
@@ -297,9 +315,71 @@
      # used for format
      # remove tensors for performance
      # so we can skip call to inspect
      def dict
        instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
+       end
+
+       def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+         # TODO add hooks
+
+         # TODO handle non-persistent buffers
+         persistent_buffers = named_buffers
+         local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+         local_state = local_name_params.select { |_, v| !v.nil? }
+
+         local_state.each do |name, param|
+           key = prefix + name
+           if state_dict.key?(key)
+             input_param = state_dict[key]
+
+             # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+             if param.shape.length == 0 && input_param.shape.length == 1
+               input_param = input_param[0]
+             end
+
+             if input_param.shape != param.shape
+               # local shape should match the one in checkpoint
+               error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                 "the shape in current model is #{param.shape}."
+               next
+             end
+
+             begin
+               Torch.no_grad do
+                 param.copy!(input_param)
+               end
+             rescue => e
+               error_msgs << "While copying the parameter named #{key.inspect}, " +
+                 "whose dimensions in the model are #{param.size} and " +
+                 "whose dimensions in the checkpoint are #{input_param.size}, " +
+                 "an exception occurred: #{e.inspect}"
+             end
+           elsif strict
+             missing_keys << key
+           end
+         end
+
+         if strict
+           state_dict.each_key do |key|
+             if key.start_with?(prefix)
+               input_name = key[prefix.length..-1]
+               input_name = input_name.split(".", 2)[0]
+               if !named_children.key?(input_name) && !local_state.key?(input_name)
+                 unexpected_keys << key
+               end
+             end
+           end
+         end
+       end
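load_from_state_dict is where the per-parameter work happens: each local parameter and buffer is looked up under its prefixed key, shape-checked, then copied inside Torch.no_grad, with problems accumulated into error_msgs rather than raised immediately. A sketch of the shape-mismatch path, again using the hypothetical Net class; the wording comes from the error string above:

model = Net.new
state = model.state_dict
state["fc1.weight"] = Torch.zeros(3, 3)  # deliberately wrong shape; fc1.weight is [5, 10]

begin
  model.load_state_dict(state)
rescue => e
  puts e.message
  # => size mismatch for fc1.weight: copying a param with shape [3, 3] from checkpoint,
  #    the shape in current model is [5, 10].
end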
+
+       def save_to_state_dict(destination, prefix: "")
+         named_parameters(recurse: false).each do |k, v|
+           destination[prefix + k] = v
+         end
+         named_buffers.each do |k, v|
+           destination[prefix + k] = v
+         end
      end
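save_to_state_dict is the non-recursive half: it writes only the receiver's own parameters and buffers under the current prefix. Including buffers is itself a change from 0.5.2, where state_dict contained parameters only, so batch-norm running statistics should now appear. An indicative sketch:

bn = Torch::NN::BatchNorm1d.new(4)
bn.state_dict.keys
# expected to include "weight", "bias", "running_mean", "running_var"
# (and "num_batches_tracked" if that buffer is registered in this version)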
    end
  end
end